dsq_core/ops/
aggregate.rs

1//! Aggregation operations for dsq
2//!
3//! This module provides aggregation functions for `DataFrames` including:
4//! - Group by operations
5//! - Statistical aggregations (sum, mean, count, etc.)
6//! - Window functions
7//! - Pivot and unpivot operations
8//!
9//! These operations correspond to common SQL aggregations and jq's `group_by`
10//! functionality, adapted for tabular data processing.
11
12use crate::{Error, Result, TypeError, Value};
13use polars::prelude::*;
14use smallvec::SmallVec;
15use std::collections::HashMap;
16
17/// Helper function to convert `AnyValue` to Value
18fn any_value_to_value(any_val: &AnyValue) -> Result<Value> {
19    use serde_json::Value as JsonValue;
20    let json_val = match any_val {
21        AnyValue::Null => JsonValue::Null,
22        AnyValue::Boolean(b) => JsonValue::Bool(*b),
23        AnyValue::Int8(i) => JsonValue::Number(serde_json::Number::from(*i)),
24        AnyValue::Int16(i) => JsonValue::Number(serde_json::Number::from(*i)),
25        AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
26        AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
27        AnyValue::UInt8(i) => JsonValue::Number(serde_json::Number::from(*i)),
28        AnyValue::UInt16(i) => JsonValue::Number(serde_json::Number::from(*i)),
29        AnyValue::UInt32(i) => JsonValue::Number(serde_json::Number::from(*i)),
30        AnyValue::UInt64(i) => JsonValue::Number(serde_json::Number::from(*i)),
31        AnyValue::Float32(f) => JsonValue::Number(
32            serde_json::Number::from_f64(f64::from(*f))
33                .ok_or_else(|| Error::operation("Invalid float"))?,
34        ),
35        AnyValue::Float64(f) => JsonValue::Number(
36            serde_json::Number::from_f64(*f).ok_or_else(|| Error::operation("Invalid float"))?,
37        ),
38        AnyValue::String(s) => JsonValue::String((*s).to_string()),
39        _ => return Err(Error::operation("Unsupported AnyValue type")),
40    };
41    Ok(Value::from_json(json_val))
42}
43
44/// Helper function to convert `DataFrame` to Array of Objects
45fn df_to_array(df: &DataFrame) -> Result<Vec<Value>> {
46    let columns = df.get_column_names();
47    let mut result = Vec::with_capacity(df.height());
48
49    for row_idx in 0..df.height() {
50        let mut obj = std::collections::HashMap::new();
51        for col_name in &columns {
52            let series = df.column(col_name).map_err(Error::from)?;
53            let any_val = series.get(row_idx).map_err(Error::from)?;
54            let value = any_value_to_value(&any_val)?;
55            obj.insert(col_name.to_string(), value);
56        }
57        result.push(Value::Object(obj));
58    }
59
60    Ok(result)
61}
62
63pub fn group_by(value: &Value, columns: &[String]) -> Result<Value> {
64    if columns.is_empty() {
65        return Err(Error::operation("Group by requires at least one column"));
66    }
67
68    match value {
69        Value::DataFrame(df) => {
70            // Convert DataFrame to array of objects, then group
71            let arr = df_to_array(df)?;
72            group_by(&Value::Array(arr), columns)
73        }
74        Value::LazyFrame(lf) => {
75            let grouped = lf
76                .clone()
77                .group_by(columns.iter().map(col).collect::<Vec<_>>())
78                .agg([col("*").count().alias("count")]);
79            Ok(Value::LazyFrame(Box::new(grouped)))
80        }
81        Value::Array(arr) => {
82            // Group array of objects by specified fields
83            let mut groups: std::collections::BTreeMap<String, Vec<Value>> =
84                std::collections::BTreeMap::new();
85
86            for item in arr {
87                if let Value::Object(obj) = item {
88                    // Create group key from specified columns
89                    let mut key_parts = Vec::new();
90                    for col in columns {
91                        if let Some(val) = obj.get(col) {
92                            key_parts.push(format!("{val:?}"));
93                        } else {
94                            key_parts.push("null".to_string());
95                        }
96                    }
97                    let key = key_parts.join("|");
98
99                    groups.entry(key).or_default().push(item.clone());
100                } else {
101                    return Err(TypeError::UnsupportedOperation {
102                        operation: "group_by".to_string(),
103                        typ: item.type_name().to_string(),
104                    }
105                    .into());
106                }
107            }
108
109            // Convert groups to array of arrays
110            let grouped: Vec<Value> = groups.into_values().map(Value::Array).collect();
111
112            Ok(Value::Array(grouped))
113        }
114        _ => Err(TypeError::UnsupportedOperation {
115            operation: "group_by".to_string(),
116            typ: value.type_name().to_string(),
117        }
118        .into()),
119    }
120}
121
122/// Apply aggregation functions to grouped data
123///
124/// # Examples
125///
126/// ```rust,ignore
127/// use dsq_core::ops::aggregate::{group_by_agg, AggregationFunction};
128/// use dsq_core::value::Value;
129///
130/// let group_cols = vec!["department".to_string()];
131/// let agg_funcs = vec![
132///     AggregationFunction::Sum("salary".to_string()),
133///     AggregationFunction::Mean("age".to_string()),
134///     AggregationFunction::Count,
135/// ];
136/// let result = group_by_agg(&dataframe_value, &group_cols, &agg_funcs).unwrap();
137/// ```
138pub fn group_by_agg(
139    value: &Value,
140    group_columns: &[String],
141    aggregations: &[AggregationFunction],
142) -> Result<Value> {
143    if group_columns.is_empty() {
144        return Err(Error::operation("Group by requires at least one column"));
145    }
146
147    if aggregations.is_empty() {
148        return Err(Error::operation(
149            "Aggregation requires at least one function",
150        ));
151    }
152
153    match value {
154        Value::DataFrame(df) => {
155            let group_exprs: Vec<Expr> = group_columns.iter().map(col).collect();
156            let agg_exprs: Vec<Expr> = aggregations
157                .iter()
158                .map(AggregationFunction::to_polars_expr)
159                .collect::<crate::Result<Vec<_>>>()?;
160
161            let grouped = df
162                .clone()
163                .lazy()
164                .group_by(group_exprs)
165                .agg(agg_exprs)
166                .collect()
167                .map_err(Error::from)?;
168
169            Ok(Value::DataFrame(grouped))
170        }
171        Value::LazyFrame(lf) => {
172            let group_exprs: Vec<Expr> = group_columns.iter().map(col).collect();
173            let agg_exprs: Vec<Expr> = aggregations
174                .iter()
175                .map(AggregationFunction::to_polars_expr)
176                .collect::<crate::Result<Vec<_>>>()?;
177
178            let grouped = lf.clone().group_by(group_exprs).agg(agg_exprs);
179
180            Ok(Value::LazyFrame(Box::new(grouped)))
181        }
182        Value::Array(arr) => group_by_agg_array(arr, group_columns, aggregations),
183        _ => Err(TypeError::UnsupportedOperation {
184            operation: "group_by_agg".to_string(),
185            typ: value.type_name().to_string(),
186        }
187        .into()),
188    }
189}
190
/// Aggregation functions that can be applied to grouped data.
///
/// Every variant except [`AggregationFunction::Count`] names the input column it
/// aggregates. The name of the produced column is given by
/// [`AggregationFunction::output_column_name`].
///
/// The type derives `PartialEq`, `Eq` and `Hash`, so aggregation lists can be
/// compared, deduplicated, or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregationFunction {
    /// Count of rows in each group
    Count,
    /// Sum of values in the specified column
    Sum(String),
    /// Mean/average of values in the specified column
    Mean(String),
    /// Median of values in the specified column
    Median(String),
    /// Minimum value in the specified column
    Min(String),
    /// Maximum value in the specified column
    Max(String),
    /// Sample standard deviation of values in the specified column
    Std(String),
    /// Sample variance of values in the specified column
    Var(String),
    /// First value in the specified column (within each group)
    First(String),
    /// Last value in the specified column (within each group)
    Last(String),
    /// Collect all values in the specified column into a list
    List(String),
    /// Count unique values in the specified column
    CountUnique(String),
    /// Concatenate string values in the specified column.
    ///
    /// Fields are `(column, separator)`. A `None` separator means `","`.
    StringConcat(String, Option<String>),
}
221
222impl AggregationFunction {
223    /// Convert to Polars expression
224    pub fn to_polars_expr(&self) -> Result<Expr> {
225        match self {
226            AggregationFunction::Count => Ok(len().alias("count")),
227            AggregationFunction::Sum(col_name) => {
228                Ok(col(col_name).sum().alias(format!("{col_name}_sum")))
229            }
230            AggregationFunction::Mean(col_name) => {
231                Ok(col(col_name).mean().alias(format!("{col_name}_mean")))
232            }
233            AggregationFunction::Median(col_name) => {
234                Ok(col(col_name).median().alias(format!("{col_name}_median")))
235            }
236            AggregationFunction::Min(col_name) => {
237                Ok(col(col_name).min().alias(format!("{col_name}_min")))
238            }
239            AggregationFunction::Max(col_name) => {
240                Ok(col(col_name).max().alias(format!("{col_name}_max")))
241            }
242            AggregationFunction::Std(col_name) => {
243                Ok(col(col_name).std(1).alias(format!("{col_name}_std")))
244            }
245            AggregationFunction::Var(col_name) => {
246                Ok(col(col_name).var(1).alias(format!("{col_name}_var")))
247            }
248            AggregationFunction::First(col_name) => {
249                Ok(col(col_name).first().alias(format!("{col_name}_first")))
250            }
251            AggregationFunction::Last(col_name) => {
252                Ok(col(col_name).last().alias(format!("{col_name}_last")))
253            }
254            AggregationFunction::List(col_name) => {
255                Ok(col(col_name).alias(format!("{col_name}_list")))
256            }
257            AggregationFunction::CountUnique(col_name) => Ok(col(col_name)
258                .n_unique()
259                .alias(format!("{col_name}_nunique"))),
260            AggregationFunction::StringConcat(col_name, separator) => {
261                let _sep = separator.as_deref().unwrap_or(",");
262                // String concatenation in groupby context requires custom aggregation
263                // For now, we'll collect into a list and handle concatenation in array processing
264                Ok(col(col_name).alias(format!("{col_name}_concat")))
265            }
266        }
267    }
268
269    /// Get the output column name for this aggregation
270    #[must_use]
271    pub fn output_column_name(&self) -> String {
272        match self {
273            AggregationFunction::Count => "count".to_string(),
274            AggregationFunction::Sum(col_name) => format!("{col_name}_sum"),
275            AggregationFunction::Mean(col_name) => format!("{col_name}_mean"),
276            AggregationFunction::Median(col_name) => format!("{col_name}_median"),
277            AggregationFunction::Min(col_name) => format!("{col_name}_min"),
278            AggregationFunction::Max(col_name) => format!("{col_name}_max"),
279            AggregationFunction::Std(col_name) => format!("{col_name}_std"),
280            AggregationFunction::Var(col_name) => format!("{col_name}_var"),
281            AggregationFunction::First(col_name) => format!("{col_name}_first"),
282            AggregationFunction::Last(col_name) => format!("{col_name}_last"),
283            AggregationFunction::List(col_name) => format!("{col_name}_list"),
284            AggregationFunction::CountUnique(col_name) => format!("{col_name}_nunique"),
285            AggregationFunction::StringConcat(col_name, _) => format!("{col_name}_concat"),
286        }
287    }
288}
289
290/// Apply aggregations to array of objects (jq-style)
291fn group_by_agg_array(
292    arr: &[Value],
293    group_columns: &[String],
294    aggregations: &[AggregationFunction],
295) -> Result<Value> {
296    // First group the data
297    let mut groups: std::collections::BTreeMap<String, Vec<&Value>> =
298        std::collections::BTreeMap::new();
299
300    for item in arr {
301        match item {
302            Value::Object(obj) => {
303                // Create group key from specified columns
304                let mut key_parts: SmallVec<[String; 8]> = SmallVec::new();
305                for col in group_columns {
306                    if let Some(val) = obj.get(col) {
307                        let key_part = match val {
308                            Value::String(s) => s.clone(),
309                            Value::Int(i) => i.to_string(),
310                            Value::BigInt(bi) => bi.to_string(),
311                            Value::Float(f) => f.to_string(),
312                            Value::Bool(b) => b.to_string(),
313                            Value::Null => "null".to_string(),
314                            _ => format!("{val:?}"), // For complex types, use debug
315                        };
316                        key_parts.push(key_part);
317                    } else {
318                        key_parts.push("null".to_string());
319                    }
320                }
321                let key = key_parts.join("|");
322
323                groups.entry(key).or_default().push(item);
324            }
325            _ => {
326                return Err(TypeError::UnsupportedOperation {
327                    operation: "group_by_agg".to_string(),
328                    typ: item.type_name().to_string(),
329                }
330                .into());
331            }
332        }
333    }
334
335    // Apply aggregations to each group
336    let mut result_rows = Vec::new();
337
338    for (group_key, group_items) in groups {
339        let mut result_row = HashMap::new();
340
341        // Add group key columns
342        let key_parts: Vec<&str> = group_key.split('|').collect();
343        for (i, col) in group_columns.iter().enumerate() {
344            if let Some(key_part) = key_parts.get(i) {
345                // Try to parse back the original value type
346                let value = if *key_part == "null" {
347                    Value::Null
348                } else if let Ok(int_val) = key_part.parse::<i64>() {
349                    Value::Int(int_val)
350                } else if let Ok(float_val) = key_part.parse::<f64>() {
351                    Value::Float(float_val)
352                } else if *key_part == "true" {
353                    Value::Bool(true)
354                } else if *key_part == "false" {
355                    Value::Bool(false)
356                } else {
357                    // Remove quotes if present
358                    let cleaned = key_part.trim_matches('"');
359                    Value::String(cleaned.to_string())
360                };
361                result_row.insert(col.clone(), value);
362            }
363        }
364
365        // Apply each aggregation
366        for agg in aggregations {
367            let agg_result = apply_aggregation_to_group(agg, &group_items)?;
368            let col_name = agg.output_column_name();
369            result_row.insert(col_name, agg_result);
370        }
371
372        result_rows.push(Value::Object(result_row));
373    }
374
375    Ok(Value::Array(result_rows))
376}
377
378/// Apply a single aggregation function to a group of objects
379fn apply_aggregation_to_group(agg: &AggregationFunction, group_items: &[&Value]) -> Result<Value> {
380    match agg {
381        AggregationFunction::Count => Ok(Value::Int(
382            i64::try_from(group_items.len()).unwrap_or(i64::MAX),
383        )),
384        AggregationFunction::Sum(col_name) => {
385            let mut sum = 0.0;
386            let mut count = 0;
387
388            for item in group_items {
389                if let Value::Object(obj) = item {
390                    if let Some(val) = obj.get(col_name) {
391                        match val {
392                            Value::Int(i) => {
393                                #[allow(clippy::cast_precision_loss)]
394                                {
395                                    sum += *i as f64;
396                                }
397                                count += 1;
398                            }
399                            Value::Float(f) => {
400                                sum += f;
401                                count += 1;
402                            }
403                            Value::Null => {} // Skip nulls
404                            _ => {
405                                return Err(TypeError::UnsupportedOperation {
406                                    operation: "sum".to_string(),
407                                    typ: val.type_name().to_string(),
408                                }
409                                .into());
410                            }
411                        }
412                    }
413                }
414            }
415
416            if count == 0 {
417                Ok(Value::Null)
418            } else {
419                #[allow(clippy::cast_precision_loss)]
420                if sum.fract() == 0.0 && sum <= i64::MAX as f64 && sum >= i64::MIN as f64 {
421                    #[allow(clippy::cast_possible_truncation)]
422                    Ok(Value::Int(sum as i64))
423                } else {
424                    Ok(Value::Float(sum))
425                }
426            }
427        }
428        AggregationFunction::Mean(col_name) => {
429            let mut sum = 0.0;
430            let mut count = 0;
431
432            for item in group_items {
433                if let Value::Object(obj) = item {
434                    if let Some(val) = obj.get(col_name) {
435                        match val {
436                            Value::Int(i) => {
437                                #[allow(clippy::cast_precision_loss)]
438                                {
439                                    sum += *i as f64;
440                                }
441                                count += 1;
442                            }
443                            Value::Float(f) => {
444                                sum += f;
445                                count += 1;
446                            }
447                            Value::Null => {} // Skip nulls
448                            _ => {
449                                return Err(TypeError::UnsupportedOperation {
450                                    operation: "mean".to_string(),
451                                    typ: val.type_name().to_string(),
452                                }
453                                .into());
454                            }
455                        }
456                    }
457                }
458            }
459
460            if count == 0 {
461                Ok(Value::Null)
462            } else {
463                Ok(Value::Float(sum / f64::from(count)))
464            }
465        }
466        AggregationFunction::Min(col_name) => {
467            let mut min_val: Option<&Value> = None;
468
469            for item in group_items {
470                if let Value::Object(obj) = item {
471                    if let Some(val) = obj.get(col_name) {
472                        if !matches!(val, Value::Null) {
473                            match min_val {
474                                None => min_val = Some(val),
475                                Some(current_min) => {
476                                    if compare_values_for_ordering(val, current_min)
477                                        == std::cmp::Ordering::Less
478                                    {
479                                        min_val = Some(val);
480                                    }
481                                }
482                            }
483                        }
484                    }
485                }
486            }
487
488            Ok(min_val.map_or(Value::Null, Clone::clone))
489        }
490        AggregationFunction::Max(col_name) => {
491            let mut max_val: Option<&Value> = None;
492
493            for item in group_items {
494                if let Value::Object(obj) = item {
495                    if let Some(val) = obj.get(col_name) {
496                        if !matches!(val, Value::Null) {
497                            match max_val {
498                                None => max_val = Some(val),
499                                Some(current_max) => {
500                                    if compare_values_for_ordering(val, current_max)
501                                        == std::cmp::Ordering::Greater
502                                    {
503                                        max_val = Some(val);
504                                    }
505                                }
506                            }
507                        }
508                    }
509                }
510            }
511
512            Ok(max_val.map_or(Value::Null, Clone::clone))
513        }
514        AggregationFunction::First(col_name) => {
515            for item in group_items {
516                if let Value::Object(obj) = item {
517                    if let Some(val) = obj.get(col_name) {
518                        return Ok(val.clone());
519                    }
520                }
521            }
522            Ok(Value::Null)
523        }
524        AggregationFunction::Last(col_name) => {
525            for item in group_items.iter().rev() {
526                if let Value::Object(obj) = item {
527                    if let Some(val) = obj.get(col_name) {
528                        return Ok(val.clone());
529                    }
530                }
531            }
532            Ok(Value::Null)
533        }
534        AggregationFunction::List(col_name) => {
535            let mut values: SmallVec<[Value; 16]> = SmallVec::new();
536
537            for item in group_items {
538                if let Value::Object(obj) = item {
539                    if let Some(val) = obj.get(col_name) {
540                        values.push(val.clone());
541                    } else {
542                        values.push(Value::Null);
543                    }
544                }
545            }
546
547            Ok(Value::Array(values.into_vec()))
548        }
549        AggregationFunction::CountUnique(col_name) => {
550            let mut unique_values = std::collections::HashSet::new();
551
552            for item in group_items {
553                if let Value::Object(obj) = item {
554                    if let Some(val) = obj.get(col_name) {
555                        unique_values.insert(format!("{val:?}"));
556                    }
557                }
558            }
559
560            #[allow(clippy::cast_possible_wrap)]
561            {
562                Ok(Value::Int(unique_values.len() as i64))
563            }
564        }
565        AggregationFunction::StringConcat(col_name, separator) => {
566            let mut string_values: SmallVec<[String; 16]> = SmallVec::new();
567            let sep = separator.as_deref().unwrap_or(",");
568
569            for item in group_items {
570                if let Value::Object(obj) = item {
571                    if let Some(val) = obj.get(col_name) {
572                        match val {
573                            Value::String(s) => string_values.push(s.clone()),
574                            Value::Null => {} // Skip nulls
575                            _ => string_values.push(val.to_string()),
576                        }
577                    }
578                }
579            }
580
581            Ok(Value::String(string_values.join(sep)))
582        }
583        AggregationFunction::Median(col_name) => {
584            let mut numeric_values = Vec::with_capacity(group_items.len());
585
586            for item in group_items {
587                if let Value::Object(obj) = item {
588                    if let Some(val) = obj.get(col_name) {
589                        match val {
590                            Value::Int(i) => {
591                                #[allow(clippy::cast_precision_loss)]
592                                {
593                                    numeric_values.push(*i as f64);
594                                }
595                            }
596                            Value::Float(f) => numeric_values.push(*f),
597                            Value::Null => {} // Skip nulls
598                            _ => {
599                                return Err(TypeError::UnsupportedOperation {
600                                    operation: "median".to_string(),
601                                    typ: val.type_name().to_string(),
602                                }
603                                .into());
604                            }
605                        }
606                    }
607                }
608            }
609
610            if numeric_values.is_empty() {
611                return Ok(Value::Null);
612            }
613
614            numeric_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
615
616            let median = if numeric_values.len() % 2 == 0 {
617                let mid = numeric_values.len() / 2;
618                f64::midpoint(numeric_values[mid - 1], numeric_values[mid])
619            } else {
620                numeric_values[numeric_values.len() / 2]
621            };
622
623            Ok(Value::Float(median))
624        }
625        AggregationFunction::Std(col_name) => {
626            let mut numeric_values = Vec::with_capacity(group_items.len());
627
628            for item in group_items {
629                if let Value::Object(obj) = item {
630                    if let Some(val) = obj.get(col_name) {
631                        match val {
632                            Value::Int(i) => {
633                                #[allow(clippy::cast_precision_loss)]
634                                {
635                                    numeric_values.push(*i as f64);
636                                }
637                            }
638                            Value::Float(f) => numeric_values.push(*f),
639                            Value::Null => {} // Skip nulls
640                            _ => {
641                                return Err(TypeError::UnsupportedOperation {
642                                    operation: "std".to_string(),
643                                    typ: val.type_name().to_string(),
644                                }
645                                .into());
646                            }
647                        }
648                    }
649                }
650            }
651
652            if numeric_values.len() <= 1 {
653                return Ok(Value::Null);
654            }
655
656            #[allow(clippy::cast_precision_loss)]
657            let mean = numeric_values.iter().sum::<f64>() / numeric_values.len() as f64;
658            #[allow(clippy::cast_precision_loss)]
659            let variance = numeric_values
660                .iter()
661                .map(|x| (x - mean).powi(2))
662                .sum::<f64>()
663                / (numeric_values.len() - 1) as f64;
664
665            Ok(Value::Float(variance.sqrt()))
666        }
667        AggregationFunction::Var(col_name) => {
668            let mut numeric_values = Vec::with_capacity(group_items.len());
669
670            for item in group_items {
671                if let Value::Object(obj) = item {
672                    if let Some(val) = obj.get(col_name) {
673                        match val {
674                            Value::Int(i) => {
675                                #[allow(clippy::cast_precision_loss)]
676                                {
677                                    numeric_values.push(*i as f64);
678                                }
679                            }
680                            Value::Float(f) => numeric_values.push(*f),
681                            Value::Null => {} // Skip nulls
682                            _ => {
683                                return Err(TypeError::UnsupportedOperation {
684                                    operation: "var".to_string(),
685                                    typ: val.type_name().to_string(),
686                                }
687                                .into());
688                            }
689                        }
690                    }
691                }
692            }
693
694            if numeric_values.len() <= 1 {
695                return Ok(Value::Null);
696            }
697
698            #[allow(clippy::cast_precision_loss)]
699            let mean = numeric_values.iter().sum::<f64>() / numeric_values.len() as f64;
700            #[allow(clippy::cast_precision_loss)]
701            let variance = numeric_values
702                .iter()
703                .map(|x| (x - mean).powi(2))
704                .sum::<f64>()
705                / (numeric_values.len() - 1) as f64;
706
707            Ok(Value::Float(variance))
708        }
709    }
710}
711
712/// Compare values for ordering (used in min/max)
713fn compare_values_for_ordering(a: &Value, b: &Value) -> std::cmp::Ordering {
714    use std::cmp::Ordering;
715
716    match (a, b) {
717        (Value::Null, Value::Null) => Ordering::Equal,
718        (Value::Null, _) => Ordering::Less,
719        (_, Value::Null) => Ordering::Greater,
720
721        (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
722        (Value::Int(a), Value::Int(b)) => a.cmp(b),
723        (Value::Float(a), Value::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
724        (Value::String(a), Value::String(b)) => a.cmp(b),
725
726        // Cross-type numeric comparisons
727        #[allow(clippy::cast_precision_loss)]
728        (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal),
729        #[allow(clippy::cast_precision_loss)]
730        (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal),
731
732        // For complex types, compare string representations
733        _ => a.to_string().cmp(&b.to_string()),
734    }
735}
736
737/// Pivot a `DataFrame` (convert rows to columns)
738///
739/// Equivalent to SQL's PIVOT operation or Excel's pivot tables.
740///
741/// # Examples
742///
743/// ```rust,ignore
744/// use dsq_core::ops::aggregate::pivot;
745/// use dsq_core::value::Value;
746///
747/// let result = pivot(
748///     &dataframe_value,
749///     &["id".to_string()],           // index columns
750///     "category",                     // column to pivot
751///     "value",                       // values to aggregate
752///     Some("sum")                    // aggregation function
753/// ).unwrap();
754/// ```
755pub fn pivot(
756    value: &Value,
757    index_columns: &[String],
758    _pivot_column: &str,
759    value_column: &str,
760    agg_function: Option<&str>,
761) -> Result<Value> {
762    match value {
763        Value::DataFrame(df) => {
764            let agg_expr = match agg_function {
765                Some("sum") => col(value_column).sum().alias("value_sum"),
766                Some("mean") => col(value_column).mean().alias("value_mean"),
767                Some("count") => col(value_column).count().alias("value_count"),
768                Some("min") => col(value_column).min().alias("value_min"),
769                Some("max") => col(value_column).max().alias("value_max"),
770                Some("first") | None => col(value_column).first().alias("value_first"),
771                Some("last") => col(value_column).last().alias("value_last"),
772                _ => {
773                    return Err(Error::operation(format!(
774                        "Unsupported aggregation function: {}",
775                        agg_function.unwrap_or("")
776                    )));
777                }
778            };
779
780            // Pivot operation using group_by and aggregation
781            // This is a simplified implementation - full pivot would require more complex logic
782            let pivoted = df
783                .clone()
784                .lazy()
785                .group_by(index_columns.iter().map(col).collect::<Vec<_>>())
786                .agg([agg_expr])
787                .collect()
788                .map_err(Error::from)?;
789
790            Ok(Value::DataFrame(pivoted))
791        }
792        Value::LazyFrame(lf) => {
793            let agg_expr = match agg_function {
794                Some("sum") => col(value_column).sum().alias("value_sum"),
795                Some("mean") => col(value_column).mean(),
796                Some("count") => col(value_column).count(),
797                Some("min") => col(value_column).min(),
798                Some("max") => col(value_column).max(),
799                Some("first") | None => col(value_column).first(),
800                Some("last") => col(value_column).last(),
801                _ => {
802                    return Err(Error::operation(format!(
803                        "Unsupported aggregation function: {}",
804                        agg_function.unwrap_or("")
805                    )));
806                }
807            };
808
809            // Pivot operation using group_by and aggregation
810            // This is a simplified implementation - full pivot would require more complex logic
811            let pivoted = lf
812                .clone()
813                .group_by(index_columns.iter().map(col).collect::<Vec<_>>())
814                .agg([agg_expr]);
815
816            Ok(Value::LazyFrame(Box::new(pivoted)))
817        }
818        _ => Err(TypeError::UnsupportedOperation {
819            operation: "pivot".to_string(),
820            typ: value.type_name().to_string(),
821        }
822        .into()),
823    }
824}
825
826/// Unpivot a `DataFrame` (convert columns to rows)
827///
828/// Equivalent to SQL's UNPIVOT operation or pandas' melt function.
829///
830/// # Examples
831///
832/// ```rust,ignore
833/// use dsq_core::ops::aggregate::unpivot;
834/// use dsq_core::value::Value;
835///
836/// let result = unpivot(
837///     &dataframe_value,
838///     &["id".to_string()],           // columns to keep as identifiers
839///     &["col1".to_string(), "col2".to_string()], // columns to unpivot
840///     "variable",                    // name for the variable column
841///     "value"                        // name for the value column
842/// ).unwrap();
843/// ```
844pub fn unpivot(
845    value: &Value,
846    id_columns: &[String],
847    value_columns: &[String],
848    variable_name: &str,
849    value_name: &str,
850) -> Result<Value> {
851    match value {
852        Value::DataFrame(df) => {
853            // Use unpivot method from UnpivotDF trait
854            let mut unpivoted = if id_columns.is_empty() {
855                df.clone()
856                    .unpivot([] as [&str; 0], value_columns)
857                    .map_err(Error::from)?
858            } else {
859                df.clone()
860                    .unpivot(id_columns, value_columns)
861                    .map_err(Error::from)?
862            };
863            unpivoted
864                .rename("variable", variable_name.into())
865                .map_err(Error::from)?;
866            unpivoted
867                .rename("value", value_name.into())
868                .map_err(Error::from)?;
869
870            Ok(Value::DataFrame(unpivoted))
871        }
872        Value::LazyFrame(lf) => {
873            let df = lf.clone().collect().map_err(Error::from)?;
874            unpivot(
875                &Value::DataFrame(df),
876                id_columns,
877                value_columns,
878                variable_name,
879                value_name,
880            )
881        }
882        _ => Err(TypeError::UnsupportedOperation {
883            operation: "unpivot".to_string(),
884            typ: value.type_name().to_string(),
885        }
886        .into()),
887    }
888}
889
890/// Rolling window aggregations
891///
892/// Apply aggregation functions over a rolling window of rows.
893///
894/// # Examples
895///
896/// ```rust,ignore
897/// use dsq_core::ops::aggregate::{rolling_agg, WindowFunction};
898/// use dsq_core::value::Value;
899///
900/// let result = rolling_agg(
901///     &dataframe_value,
902///     "value",                       // column to aggregate
903///     WindowFunction::Sum,           // aggregation function
904///     3,                            // window size
905///     None                          // min_periods (optional)
906/// ).unwrap();
907/// ```
908pub fn rolling_agg(
909    value: &Value,
910    _column: &str,
911    _function: WindowFunction,
912    window_size: usize,
913    min_periods: Option<usize>,
914) -> Result<Value> {
915    let _min_periods = min_periods.unwrap_or(window_size);
916
917    match value {
918        Value::DataFrame(_df) => {
919            // Rolling functions are not available in Polars 0.35 Expr API
920            // Use a simple implementation for now
921            Err(Error::operation(
922                "Rolling window functions not yet implemented",
923            ))
924        }
925        Value::LazyFrame(_lf) => {
926            // Rolling functions are not available in Polars 0.35 Expr API
927            // Use a simple implementation for now
928            Err(Error::operation(
929                "Rolling window functions not yet implemented",
930            ))
931        }
932        _ => Err(TypeError::UnsupportedOperation {
933            operation: "rolling_agg".to_string(),
934            typ: value.type_name().to_string(),
935        }
936        .into()),
937    }
938}
939
/// Aggregation functions usable by the window operations
/// ([`rolling_agg`] and [`cumulative_agg`]).
///
/// Derives `PartialEq`/`Eq`/`Hash` so values can be compared, asserted on in
/// tests, and used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowFunction {
    /// Sum of values
    Sum,
    /// Mean (average) of values
    Mean,
    /// Minimum value
    Min,
    /// Maximum value
    Max,
    /// Count of values
    Count,
    /// Standard deviation
    Std,
    /// Variance
    Var,
}

impl WindowFunction {
    /// Lowercase name of the function (e.g. `"sum"`), as embedded in error
    /// messages such as the one produced by [`cumulative_agg`].
    #[must_use]
    pub fn name(&self) -> &'static str {
        use WindowFunction as W;
        match self {
            W::Sum => "sum",
            W::Mean => "mean",
            W::Min => "min",
            W::Max => "max",
            W::Count => "count",
            W::Std => "std",
            W::Var => "var",
        }
    }
}
974
975/// Cumulative aggregations
976///
977/// Apply cumulative aggregation functions (running totals, etc.).
978///
979/// # Examples
980///
981/// ```rust,ignore
982/// use dsq_core::ops::aggregate::{cumulative_agg, WindowFunction};
983/// use dsq_core::value::Value;
984///
985/// let result = cumulative_agg(
986///     &dataframe_value,
987///     "value",                       // column to aggregate
988///     WindowFunction::Sum            // cumulative sum
989/// ).unwrap();
990/// ```
991#[allow(clippy::needless_pass_by_value)]
992pub fn cumulative_agg(value: &Value, _column: &str, function: WindowFunction) -> Result<Value> {
993    match value {
994        Value::DataFrame(_df) => {
995            // Cumulative functions need special window handling in polars
996            // For now, return an error indicating they're not implemented
997            Err(Error::operation(format!(
998                "Cumulative {} not yet implemented",
999                function.name()
1000            )))
1001        }
1002        Value::LazyFrame(_lf) => {
1003            // Cumulative functions need special window handling in polars
1004            // For now, return an error indicating they're not implemented
1005            Err(Error::operation(format!(
1006                "Cumulative {} not yet implemented",
1007                function.name()
1008            )))
1009        }
1010        _ => Err(TypeError::UnsupportedOperation {
1011            operation: "cumulative_agg".to_string(),
1012            typ: value.type_name().to_string(),
1013        }
1014        .into()),
1015    }
1016}
1017
1018#[cfg(test)]
1019mod tests {
1020    use super::*;
1021    use std::collections::HashMap;
1022
1023    fn create_test_dataframe() -> DataFrame {
1024        df! {
1025            "department" => ["Sales", "Sales", "Marketing", "Marketing", "Engineering"],
1026            "employee" => ["Alice", "Bob", "Charlie", "Dave", "Eve"],
1027            "salary" => [50000, 55000, 60000, 65000, 80000],
1028            "age" => [25, 30, 35, 28, 32]
1029        }
1030        .unwrap()
1031    }
1032
1033    fn create_test_object(key: &str, value: Value) -> Value {
1034        Value::Object(HashMap::from([(key.to_string(), value)]))
1035    }
1036
1037    #[test]
1038    fn test_aggregation_functions() {
1039        // Test min/max with different types
1040        let test_values = vec![
1041            &Value::Int(10),
1042            &Value::Int(5),
1043            &Value::Int(20),
1044            &Value::Int(15),
1045        ];
1046
1047        // Test finding minimum
1048        let mut min_val: Option<&Value> = None;
1049        for val in &test_values {
1050            match min_val {
1051                None => min_val = Some(val),
1052                Some(current_min) => {
1053                    if compare_values_for_ordering(val, current_min) == std::cmp::Ordering::Less {
1054                        min_val = Some(val);
1055                    }
1056                }
1057            }
1058        }
1059
1060        assert_eq!(min_val, Some(&Value::Int(5)));
1061    }
1062
1063    #[test]
1064    fn test_pivot_unpivot() {
1065        let df = df! {
1066            "id" => [1, 2, 3],
1067            "category" => ["A", "B", "A"],
1068            "value" => [10, 20, 30]
1069        }
1070        .unwrap();
1071
1072        let value = Value::DataFrame(df);
1073
1074        // Test pivot
1075        let pivoted = pivot(
1076            &value,
1077            &["id".to_string()],
1078            "category",
1079            "value",
1080            Some("sum"),
1081        )
1082        .unwrap();
1083
1084        match pivoted {
1085            Value::DataFrame(df) => {
1086                assert!(df.width() >= 2); // At least id column and pivoted columns
1087            }
1088            _ => panic!("Expected DataFrame"),
1089        }
1090    }
1091
1092    #[test]
1093    fn test_aggregation_function_names() {
1094        let agg = AggregationFunction::Sum("salary".to_string());
1095        assert_eq!(agg.output_column_name(), "salary_sum");
1096
1097        let agg = AggregationFunction::Mean("age".to_string());
1098        assert_eq!(agg.output_column_name(), "age_mean");
1099
1100        let agg = AggregationFunction::Count;
1101        assert_eq!(agg.output_column_name(), "count");
1102    }
1103
1104    // #[test]
1105    // fn test_group_by_with_map_and_aggregation() {
1106    //     // Test the pattern from example_081: group_by(.department) | map({dept: .[0].department, count: length, avg_salary: (map(.salary) | add / length)})
1107    //     let df = df! {
1108    //         "id" => [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
1109    //         "name" => ["Alice Johnson", "Bob Smith", "Carol Williams", "David Brown", "Eve Davis", "Frank Miller", "Grace Wilson", "Henry Moore", "Ivy Taylor", "Jack Anderson"],
1110    //         "age" => [28, 34, 29, 41, 26, 38, 31, 45, 27, 33],
1111    //         "city" => ["New York", "Los Angeles", "Chicago", "Boston", "Miami", "Seattle", "Denver", "Austin", "Nashville", "Portland"],
1112    //         "salary" => [75000, 82000, 68000, 95000, 62000, 88000, 71000, 102000, 65000, 79000],
1113    //         "department" => ["Engineering", "Sales", "Marketing", "Engineering", "HR", "Sales", "Marketing", "Engineering", "HR", "Sales"]
1114    //     }.unwrap();
1115
1116    //     let value = Value::DataFrame(df);
1117
1118    //     // First, group by department
1119    //     let columns = vec!["department".to_string()];
1120    //     let grouped = group_by(&value, &columns).unwrap();
1121
1122    //     match grouped {
1123    //         Value::Array(groups) => {
1124    //             assert_eq!(groups.len(), 4); // Engineering, Sales, Marketing, HR
1125
1126    //             // For each group, simulate the map operation: {dept: .[0].department, count: length, avg_salary: (map(.salary) | add / length)}
1127    //             let mut results = Vec::new();
1128    //             for group in groups {
1129    //                 if let Value::Array(items) = group {
1130    //                     // Get department from first item
1131    //                     let dept = if let Some(Value::Object(first_obj)) = items.first() {
1132    //                         if let Some(Value::String(dept_str)) = first_obj.get("department") {
1133    //                             dept_str.clone()
1134    //                         } else {
1135    //                             continue;
1136    //                         }
1137    //                     } else {
1138    //                         continue;
1139    //                     };
1140
1141    //                     let count = items.len();
1142
1143    //                     // Calculate average salary
1144    //                     let mut total_salary = 0.0;
1145    //                     for item in &items {
1146    //                         if let Value::Object(obj) = item {
1147    //                             if let Some(Value::Int(salary)) = obj.get("salary") {
1148    //                                 total_salary += *salary as f64;
1149    //                             }
1150    //                         }
1151    //                     }
1152    //                     let avg_salary = total_salary / count as f64;
1153
1154    //                     results.push((dept, count, avg_salary));
1155    //                 }
1156    //             }
1157
1158    //             // Sort results by department for consistent testing
1159    //             results.sort_by(|a, b| a.0.cmp(&b.0));
1160
1161    //             // Verify results have correct structure and departments
1162    //             assert_eq!(results.len(), 4);
1163    //             let depts: Vec<&str> = results.iter().map(|(dept, _, _)| dept.as_str()).collect();
1164    //             assert!(depts.contains(&"Engineering".into()));
1165    //             assert!(depts.contains(&"HR".into()));
1166    //             assert!(depts.contains(&"Marketing".into()));
1167    //             assert!(depts.contains(&"Sales".into()));
1168
1169    //             // Check counts
1170    //             let eng_result = results
1171    //                 .iter()
1172    //                 .find(|(dept, _, _)| dept == "Engineering")
1173    //                 .unwrap();
1174    //             assert_eq!(eng_result.1, 3); // 3 engineers
1175    //             let hr_result = results.iter().find(|(dept, _, _)| dept == "HR").unwrap();
1176    //             assert_eq!(hr_result.1, 2); // 2 HR
1177    //         }
1178    //         _ => panic!("Expected Array"),
1179    //     }
1180    // }
1181
1182    #[test]
1183    fn test_string_concatenation() {
1184        let alice = Value::Object(HashMap::from([(
1185            "name".to_string(),
1186            Value::String("Alice".to_string()),
1187        )]));
1188        let bob = Value::Object(HashMap::from([(
1189            "name".to_string(),
1190            Value::String("Bob".to_string()),
1191        )]));
1192        let charlie = Value::Object(HashMap::from([(
1193            "name".to_string(),
1194            Value::String("Charlie".to_string()),
1195        )]));
1196
1197        let group_items = vec![&alice, &bob, &charlie];
1198
1199        let agg = AggregationFunction::StringConcat("name".to_string(), Some(", ".to_string()));
1200        let result = apply_aggregation_to_group(&agg, &group_items).unwrap();
1201
1202        assert_eq!(result, Value::String("Alice, Bob, Charlie".to_string()));
1203    }
1204
1205    #[test]
1206    fn test_median_aggregation() {
1207        let obj1 = Value::Object(HashMap::from([("value".to_string(), Value::Int(1))]));
1208        let obj2 = Value::Object(HashMap::from([("value".to_string(), Value::Int(3))]));
1209        let obj3 = Value::Object(HashMap::from([("value".to_string(), Value::Int(2))]));
1210        let items = vec![&obj1, &obj2, &obj3];
1211
1212        let agg = AggregationFunction::Median("value".to_string());
1213        let result = apply_aggregation_to_group(&agg, &items).unwrap();
1214        assert_eq!(result, Value::Float(2.0));
1215
1216        // Even number of items
1217        let obj4 = create_test_object("value", Value::Int(1));
1218        let obj5 = create_test_object("value", Value::Int(2));
1219        let obj6 = create_test_object("value", Value::Int(3));
1220        let obj7 = create_test_object("value", Value::Int(4));
1221        let items_even = vec![&obj4, &obj5, &obj6, &obj7];
1222
1223        let agg_even = AggregationFunction::Median("value".to_string());
1224        let result_even = apply_aggregation_to_group(&agg_even, &items_even).unwrap();
1225        assert_eq!(result_even, Value::Float(2.5));
1226
1227        let first_agg = AggregationFunction::First("value".to_string());
1228        let first_result = apply_aggregation_to_group(&first_agg, &items).unwrap();
1229        assert_eq!(first_result, Value::Int(1));
1230
1231        let last_agg = AggregationFunction::Last("value".to_string());
1232        let last_result = apply_aggregation_to_group(&last_agg, &items).unwrap();
1233        assert_eq!(last_result, Value::Int(2)); // Last item in [1, 3, 2] is 2
1234
1235        // Empty group
1236        let empty_items: Vec<&Value> = vec![];
1237        let first_empty = apply_aggregation_to_group(&first_agg, &empty_items).unwrap();
1238        assert_eq!(first_empty, Value::Null);
1239
1240        let last_empty = apply_aggregation_to_group(&last_agg, &empty_items).unwrap();
1241        assert_eq!(last_empty, Value::Null);
1242    }
1243
1244    #[test]
1245    fn test_list_aggregation() {
1246        let obj1 = create_test_object("value", Value::Int(1));
1247        let obj2 = create_test_object("value", Value::Int(2));
1248        let obj3 = create_test_object("value", Value::Null);
1249        let items = vec![&obj1, &obj2, &obj3];
1250
1251        let list_agg = AggregationFunction::List("value".to_string());
1252        let result = apply_aggregation_to_group(&list_agg, &items).unwrap();
1253
1254        match result {
1255            Value::Array(arr) => {
1256                assert_eq!(arr.len(), 3);
1257                assert_eq!(arr[0], Value::Int(1));
1258                assert_eq!(arr[1], Value::Int(2));
1259                assert_eq!(arr[2], Value::Null);
1260            }
1261            _ => panic!("Expected Array"),
1262        }
1263
1264        // Missing column
1265        let missing_obj = Value::Object(HashMap::from([("other".to_string(), Value::Int(1))]));
1266        let items_missing = vec![&missing_obj];
1267        let result_missing = apply_aggregation_to_group(&list_agg, &items_missing).unwrap();
1268        match result_missing {
1269            Value::Array(arr) => {
1270                assert_eq!(arr.len(), 1);
1271                assert_eq!(arr[0], Value::Null);
1272            }
1273            _ => panic!("Expected Array"),
1274        }
1275    }
1276
1277    #[test]
1278    fn test_count_unique_aggregation() {
1279        let obj1 = Value::Object(HashMap::from([("value".to_string(), Value::Int(1))]));
1280        let obj2 = Value::Object(HashMap::from([("value".to_string(), Value::Int(2))]));
1281        let obj3 = Value::Object(HashMap::from([("value".to_string(), Value::Int(1))]));
1282        let obj4 = Value::Object(HashMap::from([(
1283            "value".to_string(),
1284            Value::String("test".to_string()),
1285        )]));
1286        let items = vec![&obj1, &obj2, &obj3, &obj4];
1287
1288        let count_unique_agg = AggregationFunction::CountUnique("value".to_string());
1289        let result = apply_aggregation_to_group(&count_unique_agg, &items).unwrap();
1290        assert_eq!(result, Value::Int(3)); // 1, 2, "test"
1291
1292        // Empty group
1293        let empty_items: Vec<&Value> = vec![];
1294        let result_empty = apply_aggregation_to_group(&count_unique_agg, &empty_items).unwrap();
1295        assert_eq!(result_empty, Value::Int(0));
1296    }
1297
1298    #[test]
1299    fn test_sum_mean_with_nulls_and_mixed_types() {
1300        let v1 = Value::Object(HashMap::from([("value".to_string(), Value::Int(10))]));
1301        let v2 = Value::Object(HashMap::from([("value".to_string(), Value::Null)]));
1302        let v3 = Value::Object(HashMap::from([("value".to_string(), Value::Float(20.5))]));
1303        let v4 = Value::Object(HashMap::from([("value".to_string(), Value::Int(5))]));
1304        let items = vec![&v1, &v2, &v3, &v4];
1305
1306        let sum_agg = AggregationFunction::Sum("value".to_string());
1307        let sum_result = apply_aggregation_to_group(&sum_agg, &items).unwrap();
1308        assert_eq!(sum_result, Value::Float(35.5)); // 10 + 20.5 + 5
1309
1310        let mean_agg = AggregationFunction::Mean("value".to_string());
1311        let mean_result = apply_aggregation_to_group(&mean_agg, &items).unwrap();
1312        assert_eq!(mean_result, Value::Float(11.833333333333334)); // 35.5 / 3
1313
1314        // All nulls
1315        let null1 = Value::Object(HashMap::from([("value".to_string(), Value::Null)]));
1316        let null2 = Value::Object(HashMap::from([("value".to_string(), Value::Null)]));
1317        let null_items = vec![&null1, &null2];
1318        let sum_null = apply_aggregation_to_group(&sum_agg, &null_items).unwrap();
1319        assert_eq!(sum_null, Value::Null);
1320
1321        let mean_null = apply_aggregation_to_group(&mean_agg, &null_items).unwrap();
1322        assert_eq!(mean_null, Value::Null);
1323    }
1324
1325    #[test]
1326    fn test_min_max_with_different_types() {
1327        let v1 = Value::Object(HashMap::from([("int_val".to_string(), Value::Int(10))]));
1328        let v2 = Value::Object(HashMap::from([("int_val".to_string(), Value::Int(5))]));
1329        let v3 = Value::Object(HashMap::from([(
1330            "float_val".to_string(),
1331            Value::Float(7.5),
1332        )]));
1333        let v4 = Value::Object(HashMap::from([(
1334            "float_val".to_string(),
1335            Value::Float(12.3),
1336        )]));
1337        let v5 = Value::Object(HashMap::from([(
1338            "str_val".to_string(),
1339            Value::String("apple".to_string()),
1340        )]));
1341        let v6 = Value::Object(HashMap::from([(
1342            "str_val".to_string(),
1343            Value::String("banana".to_string()),
1344        )]));
1345        let items = vec![&v1, &v2, &v3, &v4, &v5, &v6];
1346
1347        let min_int = AggregationFunction::Min("int_val".to_string());
1348        let min_int_result = apply_aggregation_to_group(&min_int, &items).unwrap();
1349        assert_eq!(min_int_result, Value::Int(5));
1350
1351        let max_float = AggregationFunction::Max("float_val".to_string());
1352        let max_float_result = apply_aggregation_to_group(&max_float, &items).unwrap();
1353        assert_eq!(max_float_result, Value::Float(12.3));
1354
1355        let min_str = AggregationFunction::Min("str_val".to_string());
1356        let min_str_result = apply_aggregation_to_group(&min_str, &items).unwrap();
1357        assert_eq!(min_str_result, Value::String("apple".to_string()));
1358
1359        let max_str = AggregationFunction::Max("str_val".to_string());
1360        let max_str_result = apply_aggregation_to_group(&max_str, &items).unwrap();
1361        assert_eq!(max_str_result, Value::String("banana".to_string()));
1362    }
1363
1364    #[test]
1365    fn test_group_by_multiple_columns() {
1366        let array_value = Value::Array(vec![
1367            Value::Object(HashMap::from([
1368                ("dept".to_string(), Value::String("Sales".to_string())),
1369                ("region".to_string(), Value::String("North".to_string())),
1370                ("salary".to_string(), Value::Int(50000)),
1371            ])),
1372            Value::Object(HashMap::from([
1373                ("dept".to_string(), Value::String("Sales".to_string())),
1374                ("region".to_string(), Value::String("South".to_string())),
1375                ("salary".to_string(), Value::Int(55000)),
1376            ])),
1377            Value::Object(HashMap::from([
1378                ("dept".to_string(), Value::String("Sales".to_string())),
1379                ("region".to_string(), Value::String("North".to_string())),
1380                ("salary".to_string(), Value::Int(60000)),
1381            ])),
1382        ]);
1383
1384        let group_cols = vec!["dept".to_string(), "region".to_string()];
1385        let agg_funcs = vec![AggregationFunction::Sum("salary".to_string())];
1386
1387        let result = group_by_agg(&array_value, &group_cols, &agg_funcs).unwrap();
1388
1389        match result {
1390            Value::Array(arr) => {
1391                assert_eq!(arr.len(), 2); // Two groups: Sales-North and Sales-South
1392
1393                let mut found_north = false;
1394                let mut found_south = false;
1395
1396                for item in &arr {
1397                    if let Value::Object(obj) = item {
1398                        if let Some(Value::String(dept)) = obj.get("dept") {
1399                            if let Some(Value::String(region)) = obj.get("region") {
1400                                if let Some(Value::Int(sum)) = obj.get("salary_sum") {
1401                                    if *dept == "Sales" && *region == "North" && *sum == 110000 {
1402                                        found_north = true;
1403                                    } else if *dept == "Sales"
1404                                        && *region == "South"
1405                                        && *sum == 55000
1406                                    {
1407                                        found_south = true;
1408                                    }
1409                                }
1410                            }
1411                        }
1412                    }
1413                }
1414
1415                assert!(found_north, "North group not found or incorrect");
1416                assert!(found_south, "South group not found or incorrect");
1417            }
1418            _ => panic!("Expected Array"),
1419        }
1420    }
1421
1422    #[test]
1423    fn test_error_conditions() {
1424        // Empty group columns
1425        let array_value = Value::Array(vec![Value::Object(HashMap::from([(
1426            "value".to_string(),
1427            Value::Int(1),
1428        )]))]);
1429
1430        let result = group_by(&array_value, &[]);
1431        assert!(result.is_err());
1432        assert!(result
1433            .unwrap_err()
1434            .to_string()
1435            .contains("at least one column"));
1436
1437        let result_agg = group_by_agg(&array_value, &[], &[]);
1438        assert!(result_agg.is_err());
1439
1440        // Empty aggregations
1441        let result_agg_empty = group_by_agg(&array_value, &["value".to_string()], &[]);
1442        assert!(result_agg_empty.is_err());
1443
1444        // Unsupported type for group_by
1445        let int_value = Value::Int(42);
1446        let result_unsupported = group_by(&int_value, &["test".to_string()]);
1447        assert!(result_unsupported.is_err());
1448
1449        // Unsupported aggregation type
1450        let bool_val = Value::Object(HashMap::from([("value".to_string(), Value::Bool(true))]));
1451        let items = vec![&bool_val];
1452        let sum_agg = AggregationFunction::Sum("value".to_string());
1453        let result_type_error = apply_aggregation_to_group(&sum_agg, &items);
1454        assert!(result_type_error.is_err());
1455    }
1456
1457    #[test]
1458    fn test_pivot_current_behavior() {
1459        // Test that pivot currently just does group_by with aggregation
1460        let df = df! {
1461            "id" => [1, 2, 3],
1462            "category" => ["A", "B", "A"],
1463            "value" => [10, 20, 30]
1464        }
1465        .unwrap();
1466
1467        let value = Value::DataFrame(df);
1468
1469        let pivoted = pivot(
1470            &value,
1471            &["id".to_string()],
1472            "category",
1473            "value",
1474            Some("sum"),
1475        )
1476        .unwrap();
1477
1478        // Currently just returns grouped data, not actually pivoted
1479        match pivoted {
1480            Value::DataFrame(df) => {
1481                // Should have id and value_sum columns
1482                assert!(df
1483                    .get_column_names()
1484                    .iter()
1485                    .any(|name| name.as_str() == "id"));
1486                assert!(df
1487                    .get_column_names()
1488                    .iter()
1489                    .any(|name| name.as_str() == "value_sum"));
1490            }
1491            _ => panic!("Expected DataFrame"),
1492        }
1493    }
1494
1495    #[test]
1496    fn test_unpivot() {
1497        let df = df! {
1498            "id" => [1, 2],
1499            "A" => [10, 20],
1500            "B" => [30, 40]
1501        }
1502        .unwrap();
1503
1504        let value = Value::DataFrame(df);
1505
1506        let unpivoted = unpivot(
1507            &value,
1508            &["id".to_string()],
1509            &["A".to_string(), "B".to_string()],
1510            "category",
1511            "value",
1512        )
1513        .unwrap();
1514
1515        match unpivoted {
1516            Value::DataFrame(df) => {
1517                assert_eq!(df.height(), 2); // Current unpivot behavior
1518                assert!(df
1519                    .get_column_names()
1520                    .contains(&&PlSmallStr::from("category")));
1521                assert!(df.get_column_names().contains(&&PlSmallStr::from("value")));
1522            }
1523            _ => panic!("Expected DataFrame"),
1524        }
1525    }
1526
1527    #[test]
1528    fn test_rolling_agg_not_implemented() {
1529        let df = create_test_dataframe();
1530        let value = Value::DataFrame(df);
1531
1532        let result = rolling_agg(&value, "salary", WindowFunction::Sum, 3, None);
1533
1534        assert!(result.is_err());
1535        assert!(result
1536            .unwrap_err()
1537            .to_string()
1538            .contains("not yet implemented"));
1539    }
1540
1541    #[test]
1542    fn test_cumulative_agg_not_implemented() {
1543        let df = create_test_dataframe();
1544        let value = Value::DataFrame(df);
1545
1546        let result = cumulative_agg(&value, "salary", WindowFunction::Sum);
1547
1548        assert!(result.is_err());
1549        assert!(result
1550            .unwrap_err()
1551            .to_string()
1552            .contains("not yet implemented"));
1553    }
1554
1555    #[test]
1556    fn test_aggregation_function_to_polars_expr() {
1557        let sum_agg = AggregationFunction::Sum("salary".to_string());
1558        let _expr = sum_agg.to_polars_expr().unwrap();
1559        // Just check it doesn't panic and returns an expr
1560
1561        let count_agg = AggregationFunction::Count;
1562        let _expr_count = count_agg.to_polars_expr().unwrap();
1563
1564        let string_concat_agg =
1565            AggregationFunction::StringConcat("name".to_string(), Some(",".to_string()));
1566        let _expr_concat = string_concat_agg.to_polars_expr().unwrap();
1567    }
1568
1569    #[test]
1570    fn test_compare_values_for_ordering() {
1571        assert_eq!(
1572            compare_values_for_ordering(&Value::Int(1), &Value::Int(2)),
1573            std::cmp::Ordering::Less
1574        );
1575        assert_eq!(
1576            compare_values_for_ordering(&Value::Float(1.0), &Value::Float(2.0)),
1577            std::cmp::Ordering::Less
1578        );
1579        assert_eq!(
1580            compare_values_for_ordering(
1581                &Value::String("a".to_string()),
1582                &Value::String("b".to_string())
1583            ),
1584            std::cmp::Ordering::Less
1585        );
1586        assert_eq!(
1587            compare_values_for_ordering(&Value::Bool(false), &Value::Bool(true)),
1588            std::cmp::Ordering::Less
1589        );
1590        assert_eq!(
1591            compare_values_for_ordering(&Value::Null, &Value::Int(1)),
1592            std::cmp::Ordering::Less
1593        );
1594        assert_eq!(
1595            compare_values_for_ordering(&Value::Int(1), &Value::Null),
1596            std::cmp::Ordering::Greater
1597        );
1598        assert_eq!(
1599            compare_values_for_ordering(&Value::Null, &Value::Null),
1600            std::cmp::Ordering::Equal
1601        );
1602        assert_eq!(
1603            compare_values_for_ordering(&Value::Int(1), &Value::Float(1.0)),
1604            std::cmp::Ordering::Equal
1605        );
1606    }
1607}