1use crate::{Error, Result, TypeError, Value};
13use polars::prelude::*;
14use smallvec::SmallVec;
15use std::collections::HashMap;
16
17fn any_value_to_value(any_val: &AnyValue) -> Result<Value> {
19 use serde_json::Value as JsonValue;
20 let json_val = match any_val {
21 AnyValue::Null => JsonValue::Null,
22 AnyValue::Boolean(b) => JsonValue::Bool(*b),
23 AnyValue::Int8(i) => JsonValue::Number(serde_json::Number::from(*i)),
24 AnyValue::Int16(i) => JsonValue::Number(serde_json::Number::from(*i)),
25 AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
26 AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
27 AnyValue::UInt8(i) => JsonValue::Number(serde_json::Number::from(*i)),
28 AnyValue::UInt16(i) => JsonValue::Number(serde_json::Number::from(*i)),
29 AnyValue::UInt32(i) => JsonValue::Number(serde_json::Number::from(*i)),
30 AnyValue::UInt64(i) => JsonValue::Number(serde_json::Number::from(*i)),
31 AnyValue::Float32(f) => JsonValue::Number(
32 serde_json::Number::from_f64(f64::from(*f))
33 .ok_or_else(|| Error::operation("Invalid float"))?,
34 ),
35 AnyValue::Float64(f) => JsonValue::Number(
36 serde_json::Number::from_f64(*f).ok_or_else(|| Error::operation("Invalid float"))?,
37 ),
38 AnyValue::String(s) => JsonValue::String((*s).to_string()),
39 _ => return Err(Error::operation("Unsupported AnyValue type")),
40 };
41 Ok(Value::from_json(json_val))
42}
43
44fn df_to_array(df: &DataFrame) -> Result<Vec<Value>> {
46 let columns = df.get_column_names();
47 let mut result = Vec::with_capacity(df.height());
48
49 for row_idx in 0..df.height() {
50 let mut obj = std::collections::HashMap::new();
51 for col_name in &columns {
52 let series = df.column(col_name).map_err(Error::from)?;
53 let any_val = series.get(row_idx).map_err(Error::from)?;
54 let value = any_value_to_value(&any_val)?;
55 obj.insert(col_name.to_string(), value);
56 }
57 result.push(Value::Object(obj));
58 }
59
60 Ok(result)
61}
62
63pub fn group_by(value: &Value, columns: &[String]) -> Result<Value> {
64 if columns.is_empty() {
65 return Err(Error::operation("Group by requires at least one column"));
66 }
67
68 match value {
69 Value::DataFrame(df) => {
70 let arr = df_to_array(df)?;
72 group_by(&Value::Array(arr), columns)
73 }
74 Value::LazyFrame(lf) => {
75 let grouped = lf
76 .clone()
77 .group_by(columns.iter().map(col).collect::<Vec<_>>())
78 .agg([col("*").count().alias("count")]);
79 Ok(Value::LazyFrame(Box::new(grouped)))
80 }
81 Value::Array(arr) => {
82 let mut groups: std::collections::BTreeMap<String, Vec<Value>> =
84 std::collections::BTreeMap::new();
85
86 for item in arr {
87 if let Value::Object(obj) = item {
88 let mut key_parts = Vec::new();
90 for col in columns {
91 if let Some(val) = obj.get(col) {
92 key_parts.push(format!("{val:?}"));
93 } else {
94 key_parts.push("null".to_string());
95 }
96 }
97 let key = key_parts.join("|");
98
99 groups.entry(key).or_default().push(item.clone());
100 } else {
101 return Err(TypeError::UnsupportedOperation {
102 operation: "group_by".to_string(),
103 typ: item.type_name().to_string(),
104 }
105 .into());
106 }
107 }
108
109 let grouped: Vec<Value> = groups.into_values().map(Value::Array).collect();
111
112 Ok(Value::Array(grouped))
113 }
114 _ => Err(TypeError::UnsupportedOperation {
115 operation: "group_by".to_string(),
116 typ: value.type_name().to_string(),
117 }
118 .into()),
119 }
120}
121
122pub fn group_by_agg(
139 value: &Value,
140 group_columns: &[String],
141 aggregations: &[AggregationFunction],
142) -> Result<Value> {
143 if group_columns.is_empty() {
144 return Err(Error::operation("Group by requires at least one column"));
145 }
146
147 if aggregations.is_empty() {
148 return Err(Error::operation(
149 "Aggregation requires at least one function",
150 ));
151 }
152
153 match value {
154 Value::DataFrame(df) => {
155 let group_exprs: Vec<Expr> = group_columns.iter().map(col).collect();
156 let agg_exprs: Vec<Expr> = aggregations
157 .iter()
158 .map(AggregationFunction::to_polars_expr)
159 .collect::<crate::Result<Vec<_>>>()?;
160
161 let grouped = df
162 .clone()
163 .lazy()
164 .group_by(group_exprs)
165 .agg(agg_exprs)
166 .collect()
167 .map_err(Error::from)?;
168
169 Ok(Value::DataFrame(grouped))
170 }
171 Value::LazyFrame(lf) => {
172 let group_exprs: Vec<Expr> = group_columns.iter().map(col).collect();
173 let agg_exprs: Vec<Expr> = aggregations
174 .iter()
175 .map(AggregationFunction::to_polars_expr)
176 .collect::<crate::Result<Vec<_>>>()?;
177
178 let grouped = lf.clone().group_by(group_exprs).agg(agg_exprs);
179
180 Ok(Value::LazyFrame(Box::new(grouped)))
181 }
182 Value::Array(arr) => group_by_agg_array(arr, group_columns, aggregations),
183 _ => Err(TypeError::UnsupportedOperation {
184 operation: "group_by_agg".to_string(),
185 typ: value.type_name().to_string(),
186 }
187 .into()),
188 }
189}
190
/// One aggregation evaluated per group by `group_by_agg`.
///
/// Every variant except `Count` names the column it reads.
#[derive(Debug, Clone)]
pub enum AggregationFunction {
    /// Number of rows in the group.
    Count,
    /// Sum of the column.
    Sum(String),
    /// Arithmetic mean of the column.
    Mean(String),
    /// Median of the column.
    Median(String),
    /// Smallest value of the column.
    Min(String),
    /// Largest value of the column.
    Max(String),
    /// Sample standard deviation (divisor `n - 1`).
    Std(String),
    /// Sample variance (divisor `n - 1`).
    Var(String),
    /// First value of the column within the group.
    First(String),
    /// Last value of the column within the group.
    Last(String),
    /// All values of the column gathered into a list.
    List(String),
    /// Number of distinct values of the column.
    CountUnique(String),
    /// Values of the column joined into one string; the separator defaults to `","`.
    StringConcat(String, Option<String>),
}
221
222impl AggregationFunction {
223 pub fn to_polars_expr(&self) -> Result<Expr> {
225 match self {
226 AggregationFunction::Count => Ok(len().alias("count")),
227 AggregationFunction::Sum(col_name) => {
228 Ok(col(col_name).sum().alias(format!("{col_name}_sum")))
229 }
230 AggregationFunction::Mean(col_name) => {
231 Ok(col(col_name).mean().alias(format!("{col_name}_mean")))
232 }
233 AggregationFunction::Median(col_name) => {
234 Ok(col(col_name).median().alias(format!("{col_name}_median")))
235 }
236 AggregationFunction::Min(col_name) => {
237 Ok(col(col_name).min().alias(format!("{col_name}_min")))
238 }
239 AggregationFunction::Max(col_name) => {
240 Ok(col(col_name).max().alias(format!("{col_name}_max")))
241 }
242 AggregationFunction::Std(col_name) => {
243 Ok(col(col_name).std(1).alias(format!("{col_name}_std")))
244 }
245 AggregationFunction::Var(col_name) => {
246 Ok(col(col_name).var(1).alias(format!("{col_name}_var")))
247 }
248 AggregationFunction::First(col_name) => {
249 Ok(col(col_name).first().alias(format!("{col_name}_first")))
250 }
251 AggregationFunction::Last(col_name) => {
252 Ok(col(col_name).last().alias(format!("{col_name}_last")))
253 }
254 AggregationFunction::List(col_name) => {
255 Ok(col(col_name).alias(format!("{col_name}_list")))
256 }
257 AggregationFunction::CountUnique(col_name) => Ok(col(col_name)
258 .n_unique()
259 .alias(format!("{col_name}_nunique"))),
260 AggregationFunction::StringConcat(col_name, separator) => {
261 let _sep = separator.as_deref().unwrap_or(",");
262 Ok(col(col_name).alias(format!("{col_name}_concat")))
265 }
266 }
267 }
268
269 #[must_use]
271 pub fn output_column_name(&self) -> String {
272 match self {
273 AggregationFunction::Count => "count".to_string(),
274 AggregationFunction::Sum(col_name) => format!("{col_name}_sum"),
275 AggregationFunction::Mean(col_name) => format!("{col_name}_mean"),
276 AggregationFunction::Median(col_name) => format!("{col_name}_median"),
277 AggregationFunction::Min(col_name) => format!("{col_name}_min"),
278 AggregationFunction::Max(col_name) => format!("{col_name}_max"),
279 AggregationFunction::Std(col_name) => format!("{col_name}_std"),
280 AggregationFunction::Var(col_name) => format!("{col_name}_var"),
281 AggregationFunction::First(col_name) => format!("{col_name}_first"),
282 AggregationFunction::Last(col_name) => format!("{col_name}_last"),
283 AggregationFunction::List(col_name) => format!("{col_name}_list"),
284 AggregationFunction::CountUnique(col_name) => format!("{col_name}_nunique"),
285 AggregationFunction::StringConcat(col_name, _) => format!("{col_name}_concat"),
286 }
287 }
288}
289
290fn group_by_agg_array(
292 arr: &[Value],
293 group_columns: &[String],
294 aggregations: &[AggregationFunction],
295) -> Result<Value> {
296 let mut groups: std::collections::BTreeMap<String, Vec<&Value>> =
298 std::collections::BTreeMap::new();
299
300 for item in arr {
301 match item {
302 Value::Object(obj) => {
303 let mut key_parts: SmallVec<[String; 8]> = SmallVec::new();
305 for col in group_columns {
306 if let Some(val) = obj.get(col) {
307 let key_part = match val {
308 Value::String(s) => s.clone(),
309 Value::Int(i) => i.to_string(),
310 Value::BigInt(bi) => bi.to_string(),
311 Value::Float(f) => f.to_string(),
312 Value::Bool(b) => b.to_string(),
313 Value::Null => "null".to_string(),
314 _ => format!("{val:?}"), };
316 key_parts.push(key_part);
317 } else {
318 key_parts.push("null".to_string());
319 }
320 }
321 let key = key_parts.join("|");
322
323 groups.entry(key).or_default().push(item);
324 }
325 _ => {
326 return Err(TypeError::UnsupportedOperation {
327 operation: "group_by_agg".to_string(),
328 typ: item.type_name().to_string(),
329 }
330 .into());
331 }
332 }
333 }
334
335 let mut result_rows = Vec::new();
337
338 for (group_key, group_items) in groups {
339 let mut result_row = HashMap::new();
340
341 let key_parts: Vec<&str> = group_key.split('|').collect();
343 for (i, col) in group_columns.iter().enumerate() {
344 if let Some(key_part) = key_parts.get(i) {
345 let value = if *key_part == "null" {
347 Value::Null
348 } else if let Ok(int_val) = key_part.parse::<i64>() {
349 Value::Int(int_val)
350 } else if let Ok(float_val) = key_part.parse::<f64>() {
351 Value::Float(float_val)
352 } else if *key_part == "true" {
353 Value::Bool(true)
354 } else if *key_part == "false" {
355 Value::Bool(false)
356 } else {
357 let cleaned = key_part.trim_matches('"');
359 Value::String(cleaned.to_string())
360 };
361 result_row.insert(col.clone(), value);
362 }
363 }
364
365 for agg in aggregations {
367 let agg_result = apply_aggregation_to_group(agg, &group_items)?;
368 let col_name = agg.output_column_name();
369 result_row.insert(col_name, agg_result);
370 }
371
372 result_rows.push(Value::Object(result_row));
373 }
374
375 Ok(Value::Array(result_rows))
376}
377
378fn apply_aggregation_to_group(agg: &AggregationFunction, group_items: &[&Value]) -> Result<Value> {
380 match agg {
381 AggregationFunction::Count => Ok(Value::Int(
382 i64::try_from(group_items.len()).unwrap_or(i64::MAX),
383 )),
384 AggregationFunction::Sum(col_name) => {
385 let mut sum = 0.0;
386 let mut count = 0;
387
388 for item in group_items {
389 if let Value::Object(obj) = item {
390 if let Some(val) = obj.get(col_name) {
391 match val {
392 Value::Int(i) => {
393 #[allow(clippy::cast_precision_loss)]
394 {
395 sum += *i as f64;
396 }
397 count += 1;
398 }
399 Value::Float(f) => {
400 sum += f;
401 count += 1;
402 }
403 Value::Null => {} _ => {
405 return Err(TypeError::UnsupportedOperation {
406 operation: "sum".to_string(),
407 typ: val.type_name().to_string(),
408 }
409 .into());
410 }
411 }
412 }
413 }
414 }
415
416 if count == 0 {
417 Ok(Value::Null)
418 } else {
419 #[allow(clippy::cast_precision_loss)]
420 if sum.fract() == 0.0 && sum <= i64::MAX as f64 && sum >= i64::MIN as f64 {
421 #[allow(clippy::cast_possible_truncation)]
422 Ok(Value::Int(sum as i64))
423 } else {
424 Ok(Value::Float(sum))
425 }
426 }
427 }
428 AggregationFunction::Mean(col_name) => {
429 let mut sum = 0.0;
430 let mut count = 0;
431
432 for item in group_items {
433 if let Value::Object(obj) = item {
434 if let Some(val) = obj.get(col_name) {
435 match val {
436 Value::Int(i) => {
437 #[allow(clippy::cast_precision_loss)]
438 {
439 sum += *i as f64;
440 }
441 count += 1;
442 }
443 Value::Float(f) => {
444 sum += f;
445 count += 1;
446 }
447 Value::Null => {} _ => {
449 return Err(TypeError::UnsupportedOperation {
450 operation: "mean".to_string(),
451 typ: val.type_name().to_string(),
452 }
453 .into());
454 }
455 }
456 }
457 }
458 }
459
460 if count == 0 {
461 Ok(Value::Null)
462 } else {
463 Ok(Value::Float(sum / f64::from(count)))
464 }
465 }
466 AggregationFunction::Min(col_name) => {
467 let mut min_val: Option<&Value> = None;
468
469 for item in group_items {
470 if let Value::Object(obj) = item {
471 if let Some(val) = obj.get(col_name) {
472 if !matches!(val, Value::Null) {
473 match min_val {
474 None => min_val = Some(val),
475 Some(current_min) => {
476 if compare_values_for_ordering(val, current_min)
477 == std::cmp::Ordering::Less
478 {
479 min_val = Some(val);
480 }
481 }
482 }
483 }
484 }
485 }
486 }
487
488 Ok(min_val.map_or(Value::Null, Clone::clone))
489 }
490 AggregationFunction::Max(col_name) => {
491 let mut max_val: Option<&Value> = None;
492
493 for item in group_items {
494 if let Value::Object(obj) = item {
495 if let Some(val) = obj.get(col_name) {
496 if !matches!(val, Value::Null) {
497 match max_val {
498 None => max_val = Some(val),
499 Some(current_max) => {
500 if compare_values_for_ordering(val, current_max)
501 == std::cmp::Ordering::Greater
502 {
503 max_val = Some(val);
504 }
505 }
506 }
507 }
508 }
509 }
510 }
511
512 Ok(max_val.map_or(Value::Null, Clone::clone))
513 }
514 AggregationFunction::First(col_name) => {
515 for item in group_items {
516 if let Value::Object(obj) = item {
517 if let Some(val) = obj.get(col_name) {
518 return Ok(val.clone());
519 }
520 }
521 }
522 Ok(Value::Null)
523 }
524 AggregationFunction::Last(col_name) => {
525 for item in group_items.iter().rev() {
526 if let Value::Object(obj) = item {
527 if let Some(val) = obj.get(col_name) {
528 return Ok(val.clone());
529 }
530 }
531 }
532 Ok(Value::Null)
533 }
534 AggregationFunction::List(col_name) => {
535 let mut values: SmallVec<[Value; 16]> = SmallVec::new();
536
537 for item in group_items {
538 if let Value::Object(obj) = item {
539 if let Some(val) = obj.get(col_name) {
540 values.push(val.clone());
541 } else {
542 values.push(Value::Null);
543 }
544 }
545 }
546
547 Ok(Value::Array(values.into_vec()))
548 }
549 AggregationFunction::CountUnique(col_name) => {
550 let mut unique_values = std::collections::HashSet::new();
551
552 for item in group_items {
553 if let Value::Object(obj) = item {
554 if let Some(val) = obj.get(col_name) {
555 unique_values.insert(format!("{val:?}"));
556 }
557 }
558 }
559
560 #[allow(clippy::cast_possible_wrap)]
561 {
562 Ok(Value::Int(unique_values.len() as i64))
563 }
564 }
565 AggregationFunction::StringConcat(col_name, separator) => {
566 let mut string_values: SmallVec<[String; 16]> = SmallVec::new();
567 let sep = separator.as_deref().unwrap_or(",");
568
569 for item in group_items {
570 if let Value::Object(obj) = item {
571 if let Some(val) = obj.get(col_name) {
572 match val {
573 Value::String(s) => string_values.push(s.clone()),
574 Value::Null => {} _ => string_values.push(val.to_string()),
576 }
577 }
578 }
579 }
580
581 Ok(Value::String(string_values.join(sep)))
582 }
583 AggregationFunction::Median(col_name) => {
584 let mut numeric_values = Vec::with_capacity(group_items.len());
585
586 for item in group_items {
587 if let Value::Object(obj) = item {
588 if let Some(val) = obj.get(col_name) {
589 match val {
590 Value::Int(i) => {
591 #[allow(clippy::cast_precision_loss)]
592 {
593 numeric_values.push(*i as f64);
594 }
595 }
596 Value::Float(f) => numeric_values.push(*f),
597 Value::Null => {} _ => {
599 return Err(TypeError::UnsupportedOperation {
600 operation: "median".to_string(),
601 typ: val.type_name().to_string(),
602 }
603 .into());
604 }
605 }
606 }
607 }
608 }
609
610 if numeric_values.is_empty() {
611 return Ok(Value::Null);
612 }
613
614 numeric_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
615
616 let median = if numeric_values.len() % 2 == 0 {
617 let mid = numeric_values.len() / 2;
618 f64::midpoint(numeric_values[mid - 1], numeric_values[mid])
619 } else {
620 numeric_values[numeric_values.len() / 2]
621 };
622
623 Ok(Value::Float(median))
624 }
625 AggregationFunction::Std(col_name) => {
626 let mut numeric_values = Vec::with_capacity(group_items.len());
627
628 for item in group_items {
629 if let Value::Object(obj) = item {
630 if let Some(val) = obj.get(col_name) {
631 match val {
632 Value::Int(i) => {
633 #[allow(clippy::cast_precision_loss)]
634 {
635 numeric_values.push(*i as f64);
636 }
637 }
638 Value::Float(f) => numeric_values.push(*f),
639 Value::Null => {} _ => {
641 return Err(TypeError::UnsupportedOperation {
642 operation: "std".to_string(),
643 typ: val.type_name().to_string(),
644 }
645 .into());
646 }
647 }
648 }
649 }
650 }
651
652 if numeric_values.len() <= 1 {
653 return Ok(Value::Null);
654 }
655
656 #[allow(clippy::cast_precision_loss)]
657 let mean = numeric_values.iter().sum::<f64>() / numeric_values.len() as f64;
658 #[allow(clippy::cast_precision_loss)]
659 let variance = numeric_values
660 .iter()
661 .map(|x| (x - mean).powi(2))
662 .sum::<f64>()
663 / (numeric_values.len() - 1) as f64;
664
665 Ok(Value::Float(variance.sqrt()))
666 }
667 AggregationFunction::Var(col_name) => {
668 let mut numeric_values = Vec::with_capacity(group_items.len());
669
670 for item in group_items {
671 if let Value::Object(obj) = item {
672 if let Some(val) = obj.get(col_name) {
673 match val {
674 Value::Int(i) => {
675 #[allow(clippy::cast_precision_loss)]
676 {
677 numeric_values.push(*i as f64);
678 }
679 }
680 Value::Float(f) => numeric_values.push(*f),
681 Value::Null => {} _ => {
683 return Err(TypeError::UnsupportedOperation {
684 operation: "var".to_string(),
685 typ: val.type_name().to_string(),
686 }
687 .into());
688 }
689 }
690 }
691 }
692 }
693
694 if numeric_values.len() <= 1 {
695 return Ok(Value::Null);
696 }
697
698 #[allow(clippy::cast_precision_loss)]
699 let mean = numeric_values.iter().sum::<f64>() / numeric_values.len() as f64;
700 #[allow(clippy::cast_precision_loss)]
701 let variance = numeric_values
702 .iter()
703 .map(|x| (x - mean).powi(2))
704 .sum::<f64>()
705 / (numeric_values.len() - 1) as f64;
706
707 Ok(Value::Float(variance))
708 }
709 }
710}
711
712fn compare_values_for_ordering(a: &Value, b: &Value) -> std::cmp::Ordering {
714 use std::cmp::Ordering;
715
716 match (a, b) {
717 (Value::Null, Value::Null) => Ordering::Equal,
718 (Value::Null, _) => Ordering::Less,
719 (_, Value::Null) => Ordering::Greater,
720
721 (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
722 (Value::Int(a), Value::Int(b)) => a.cmp(b),
723 (Value::Float(a), Value::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
724 (Value::String(a), Value::String(b)) => a.cmp(b),
725
726 #[allow(clippy::cast_precision_loss)]
728 (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal),
729 #[allow(clippy::cast_precision_loss)]
730 (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal),
731
732 _ => a.to_string().cmp(&b.to_string()),
734 }
735}
736
737pub fn pivot(
756 value: &Value,
757 index_columns: &[String],
758 _pivot_column: &str,
759 value_column: &str,
760 agg_function: Option<&str>,
761) -> Result<Value> {
762 match value {
763 Value::DataFrame(df) => {
764 let agg_expr = match agg_function {
765 Some("sum") => col(value_column).sum().alias("value_sum"),
766 Some("mean") => col(value_column).mean().alias("value_mean"),
767 Some("count") => col(value_column).count().alias("value_count"),
768 Some("min") => col(value_column).min().alias("value_min"),
769 Some("max") => col(value_column).max().alias("value_max"),
770 Some("first") | None => col(value_column).first().alias("value_first"),
771 Some("last") => col(value_column).last().alias("value_last"),
772 _ => {
773 return Err(Error::operation(format!(
774 "Unsupported aggregation function: {}",
775 agg_function.unwrap_or("")
776 )));
777 }
778 };
779
780 let pivoted = df
783 .clone()
784 .lazy()
785 .group_by(index_columns.iter().map(col).collect::<Vec<_>>())
786 .agg([agg_expr])
787 .collect()
788 .map_err(Error::from)?;
789
790 Ok(Value::DataFrame(pivoted))
791 }
792 Value::LazyFrame(lf) => {
793 let agg_expr = match agg_function {
794 Some("sum") => col(value_column).sum().alias("value_sum"),
795 Some("mean") => col(value_column).mean(),
796 Some("count") => col(value_column).count(),
797 Some("min") => col(value_column).min(),
798 Some("max") => col(value_column).max(),
799 Some("first") | None => col(value_column).first(),
800 Some("last") => col(value_column).last(),
801 _ => {
802 return Err(Error::operation(format!(
803 "Unsupported aggregation function: {}",
804 agg_function.unwrap_or("")
805 )));
806 }
807 };
808
809 let pivoted = lf
812 .clone()
813 .group_by(index_columns.iter().map(col).collect::<Vec<_>>())
814 .agg([agg_expr]);
815
816 Ok(Value::LazyFrame(Box::new(pivoted)))
817 }
818 _ => Err(TypeError::UnsupportedOperation {
819 operation: "pivot".to_string(),
820 typ: value.type_name().to_string(),
821 }
822 .into()),
823 }
824}
825
826pub fn unpivot(
845 value: &Value,
846 id_columns: &[String],
847 value_columns: &[String],
848 variable_name: &str,
849 value_name: &str,
850) -> Result<Value> {
851 match value {
852 Value::DataFrame(df) => {
853 let mut unpivoted = if id_columns.is_empty() {
855 df.clone()
856 .unpivot([] as [&str; 0], value_columns)
857 .map_err(Error::from)?
858 } else {
859 df.clone()
860 .unpivot(id_columns, value_columns)
861 .map_err(Error::from)?
862 };
863 unpivoted
864 .rename("variable", variable_name.into())
865 .map_err(Error::from)?;
866 unpivoted
867 .rename("value", value_name.into())
868 .map_err(Error::from)?;
869
870 Ok(Value::DataFrame(unpivoted))
871 }
872 Value::LazyFrame(lf) => {
873 let df = lf.clone().collect().map_err(Error::from)?;
874 unpivot(
875 &Value::DataFrame(df),
876 id_columns,
877 value_columns,
878 variable_name,
879 value_name,
880 )
881 }
882 _ => Err(TypeError::UnsupportedOperation {
883 operation: "unpivot".to_string(),
884 typ: value.type_name().to_string(),
885 }
886 .into()),
887 }
888}
889
890pub fn rolling_agg(
909 value: &Value,
910 _column: &str,
911 _function: WindowFunction,
912 window_size: usize,
913 min_periods: Option<usize>,
914) -> Result<Value> {
915 let _min_periods = min_periods.unwrap_or(window_size);
916
917 match value {
918 Value::DataFrame(_df) => {
919 Err(Error::operation(
922 "Rolling window functions not yet implemented",
923 ))
924 }
925 Value::LazyFrame(_lf) => {
926 Err(Error::operation(
929 "Rolling window functions not yet implemented",
930 ))
931 }
932 _ => Err(TypeError::UnsupportedOperation {
933 operation: "rolling_agg".to_string(),
934 typ: value.type_name().to_string(),
935 }
936 .into()),
937 }
938}
939
/// Window aggregation kinds accepted by `rolling_agg` and `cumulative_agg`.
#[derive(Debug, Clone)]
pub enum WindowFunction {
    /// Running or windowed sum.
    Sum,
    /// Running or windowed mean.
    Mean,
    /// Running or windowed minimum.
    Min,
    /// Running or windowed maximum.
    Max,
    /// Running or windowed count.
    Count,
    /// Running or windowed standard deviation.
    Std,
    /// Running or windowed variance.
    Var,
}

impl WindowFunction {
    /// Lower-case name of the function, e.g. for `cumulative_agg`'s error message.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Sum => "sum",
            Self::Mean => "mean",
            Self::Min => "min",
            Self::Max => "max",
            Self::Count => "count",
            Self::Std => "std",
            Self::Var => "var",
        }
    }
}
974
975#[allow(clippy::needless_pass_by_value)]
992pub fn cumulative_agg(value: &Value, _column: &str, function: WindowFunction) -> Result<Value> {
993 match value {
994 Value::DataFrame(_df) => {
995 Err(Error::operation(format!(
998 "Cumulative {} not yet implemented",
999 function.name()
1000 )))
1001 }
1002 Value::LazyFrame(_lf) => {
1003 Err(Error::operation(format!(
1006 "Cumulative {} not yet implemented",
1007 function.name()
1008 )))
1009 }
1010 _ => Err(TypeError::UnsupportedOperation {
1011 operation: "cumulative_agg".to_string(),
1012 typ: value.type_name().to_string(),
1013 }
1014 .into()),
1015 }
1016}
1017
1018#[cfg(test)]
1019mod tests {
1020 use super::*;
1021 use std::collections::HashMap;
1022
1023 fn create_test_dataframe() -> DataFrame {
1024 df! {
1025 "department" => ["Sales", "Sales", "Marketing", "Marketing", "Engineering"],
1026 "employee" => ["Alice", "Bob", "Charlie", "Dave", "Eve"],
1027 "salary" => [50000, 55000, 60000, 65000, 80000],
1028 "age" => [25, 30, 35, 28, 32]
1029 }
1030 .unwrap()
1031 }
1032
1033 fn create_test_object(key: &str, value: Value) -> Value {
1034 Value::Object(HashMap::from([(key.to_string(), value)]))
1035 }
1036
1037 #[test]
1038 fn test_aggregation_functions() {
1039 let test_values = vec![
1041 &Value::Int(10),
1042 &Value::Int(5),
1043 &Value::Int(20),
1044 &Value::Int(15),
1045 ];
1046
1047 let mut min_val: Option<&Value> = None;
1049 for val in &test_values {
1050 match min_val {
1051 None => min_val = Some(val),
1052 Some(current_min) => {
1053 if compare_values_for_ordering(val, current_min) == std::cmp::Ordering::Less {
1054 min_val = Some(val);
1055 }
1056 }
1057 }
1058 }
1059
1060 assert_eq!(min_val, Some(&Value::Int(5)));
1061 }
1062
1063 #[test]
1064 fn test_pivot_unpivot() {
1065 let df = df! {
1066 "id" => [1, 2, 3],
1067 "category" => ["A", "B", "A"],
1068 "value" => [10, 20, 30]
1069 }
1070 .unwrap();
1071
1072 let value = Value::DataFrame(df);
1073
1074 let pivoted = pivot(
1076 &value,
1077 &["id".to_string()],
1078 "category",
1079 "value",
1080 Some("sum"),
1081 )
1082 .unwrap();
1083
1084 match pivoted {
1085 Value::DataFrame(df) => {
1086 assert!(df.width() >= 2); }
1088 _ => panic!("Expected DataFrame"),
1089 }
1090 }
1091
1092 #[test]
1093 fn test_aggregation_function_names() {
1094 let agg = AggregationFunction::Sum("salary".to_string());
1095 assert_eq!(agg.output_column_name(), "salary_sum");
1096
1097 let agg = AggregationFunction::Mean("age".to_string());
1098 assert_eq!(agg.output_column_name(), "age_mean");
1099
1100 let agg = AggregationFunction::Count;
1101 assert_eq!(agg.output_column_name(), "count");
1102 }
1103
1104 #[test]
1183 fn test_string_concatenation() {
1184 let alice = Value::Object(HashMap::from([(
1185 "name".to_string(),
1186 Value::String("Alice".to_string()),
1187 )]));
1188 let bob = Value::Object(HashMap::from([(
1189 "name".to_string(),
1190 Value::String("Bob".to_string()),
1191 )]));
1192 let charlie = Value::Object(HashMap::from([(
1193 "name".to_string(),
1194 Value::String("Charlie".to_string()),
1195 )]));
1196
1197 let group_items = vec![&alice, &bob, &charlie];
1198
1199 let agg = AggregationFunction::StringConcat("name".to_string(), Some(", ".to_string()));
1200 let result = apply_aggregation_to_group(&agg, &group_items).unwrap();
1201
1202 assert_eq!(result, Value::String("Alice, Bob, Charlie".to_string()));
1203 }
1204
1205 #[test]
1206 fn test_median_aggregation() {
1207 let obj1 = Value::Object(HashMap::from([("value".to_string(), Value::Int(1))]));
1208 let obj2 = Value::Object(HashMap::from([("value".to_string(), Value::Int(3))]));
1209 let obj3 = Value::Object(HashMap::from([("value".to_string(), Value::Int(2))]));
1210 let items = vec![&obj1, &obj2, &obj3];
1211
1212 let agg = AggregationFunction::Median("value".to_string());
1213 let result = apply_aggregation_to_group(&agg, &items).unwrap();
1214 assert_eq!(result, Value::Float(2.0));
1215
1216 let obj4 = create_test_object("value", Value::Int(1));
1218 let obj5 = create_test_object("value", Value::Int(2));
1219 let obj6 = create_test_object("value", Value::Int(3));
1220 let obj7 = create_test_object("value", Value::Int(4));
1221 let items_even = vec![&obj4, &obj5, &obj6, &obj7];
1222
1223 let agg_even = AggregationFunction::Median("value".to_string());
1224 let result_even = apply_aggregation_to_group(&agg_even, &items_even).unwrap();
1225 assert_eq!(result_even, Value::Float(2.5));
1226
1227 let first_agg = AggregationFunction::First("value".to_string());
1228 let first_result = apply_aggregation_to_group(&first_agg, &items).unwrap();
1229 assert_eq!(first_result, Value::Int(1));
1230
1231 let last_agg = AggregationFunction::Last("value".to_string());
1232 let last_result = apply_aggregation_to_group(&last_agg, &items).unwrap();
1233 assert_eq!(last_result, Value::Int(2)); let empty_items: Vec<&Value> = vec![];
1237 let first_empty = apply_aggregation_to_group(&first_agg, &empty_items).unwrap();
1238 assert_eq!(first_empty, Value::Null);
1239
1240 let last_empty = apply_aggregation_to_group(&last_agg, &empty_items).unwrap();
1241 assert_eq!(last_empty, Value::Null);
1242 }
1243
1244 #[test]
1245 fn test_list_aggregation() {
1246 let obj1 = create_test_object("value", Value::Int(1));
1247 let obj2 = create_test_object("value", Value::Int(2));
1248 let obj3 = create_test_object("value", Value::Null);
1249 let items = vec![&obj1, &obj2, &obj3];
1250
1251 let list_agg = AggregationFunction::List("value".to_string());
1252 let result = apply_aggregation_to_group(&list_agg, &items).unwrap();
1253
1254 match result {
1255 Value::Array(arr) => {
1256 assert_eq!(arr.len(), 3);
1257 assert_eq!(arr[0], Value::Int(1));
1258 assert_eq!(arr[1], Value::Int(2));
1259 assert_eq!(arr[2], Value::Null);
1260 }
1261 _ => panic!("Expected Array"),
1262 }
1263
1264 let missing_obj = Value::Object(HashMap::from([("other".to_string(), Value::Int(1))]));
1266 let items_missing = vec![&missing_obj];
1267 let result_missing = apply_aggregation_to_group(&list_agg, &items_missing).unwrap();
1268 match result_missing {
1269 Value::Array(arr) => {
1270 assert_eq!(arr.len(), 1);
1271 assert_eq!(arr[0], Value::Null);
1272 }
1273 _ => panic!("Expected Array"),
1274 }
1275 }
1276
1277 #[test]
1278 fn test_count_unique_aggregation() {
1279 let obj1 = Value::Object(HashMap::from([("value".to_string(), Value::Int(1))]));
1280 let obj2 = Value::Object(HashMap::from([("value".to_string(), Value::Int(2))]));
1281 let obj3 = Value::Object(HashMap::from([("value".to_string(), Value::Int(1))]));
1282 let obj4 = Value::Object(HashMap::from([(
1283 "value".to_string(),
1284 Value::String("test".to_string()),
1285 )]));
1286 let items = vec![&obj1, &obj2, &obj3, &obj4];
1287
1288 let count_unique_agg = AggregationFunction::CountUnique("value".to_string());
1289 let result = apply_aggregation_to_group(&count_unique_agg, &items).unwrap();
1290 assert_eq!(result, Value::Int(3)); let empty_items: Vec<&Value> = vec![];
1294 let result_empty = apply_aggregation_to_group(&count_unique_agg, &empty_items).unwrap();
1295 assert_eq!(result_empty, Value::Int(0));
1296 }
1297
1298 #[test]
1299 fn test_sum_mean_with_nulls_and_mixed_types() {
1300 let v1 = Value::Object(HashMap::from([("value".to_string(), Value::Int(10))]));
1301 let v2 = Value::Object(HashMap::from([("value".to_string(), Value::Null)]));
1302 let v3 = Value::Object(HashMap::from([("value".to_string(), Value::Float(20.5))]));
1303 let v4 = Value::Object(HashMap::from([("value".to_string(), Value::Int(5))]));
1304 let items = vec![&v1, &v2, &v3, &v4];
1305
1306 let sum_agg = AggregationFunction::Sum("value".to_string());
1307 let sum_result = apply_aggregation_to_group(&sum_agg, &items).unwrap();
1308 assert_eq!(sum_result, Value::Float(35.5)); let mean_agg = AggregationFunction::Mean("value".to_string());
1311 let mean_result = apply_aggregation_to_group(&mean_agg, &items).unwrap();
1312 assert_eq!(mean_result, Value::Float(11.833333333333334)); let null1 = Value::Object(HashMap::from([("value".to_string(), Value::Null)]));
1316 let null2 = Value::Object(HashMap::from([("value".to_string(), Value::Null)]));
1317 let null_items = vec![&null1, &null2];
1318 let sum_null = apply_aggregation_to_group(&sum_agg, &null_items).unwrap();
1319 assert_eq!(sum_null, Value::Null);
1320
1321 let mean_null = apply_aggregation_to_group(&mean_agg, &null_items).unwrap();
1322 assert_eq!(mean_null, Value::Null);
1323 }
1324
1325 #[test]
1326 fn test_min_max_with_different_types() {
1327 let v1 = Value::Object(HashMap::from([("int_val".to_string(), Value::Int(10))]));
1328 let v2 = Value::Object(HashMap::from([("int_val".to_string(), Value::Int(5))]));
1329 let v3 = Value::Object(HashMap::from([(
1330 "float_val".to_string(),
1331 Value::Float(7.5),
1332 )]));
1333 let v4 = Value::Object(HashMap::from([(
1334 "float_val".to_string(),
1335 Value::Float(12.3),
1336 )]));
1337 let v5 = Value::Object(HashMap::from([(
1338 "str_val".to_string(),
1339 Value::String("apple".to_string()),
1340 )]));
1341 let v6 = Value::Object(HashMap::from([(
1342 "str_val".to_string(),
1343 Value::String("banana".to_string()),
1344 )]));
1345 let items = vec![&v1, &v2, &v3, &v4, &v5, &v6];
1346
1347 let min_int = AggregationFunction::Min("int_val".to_string());
1348 let min_int_result = apply_aggregation_to_group(&min_int, &items).unwrap();
1349 assert_eq!(min_int_result, Value::Int(5));
1350
1351 let max_float = AggregationFunction::Max("float_val".to_string());
1352 let max_float_result = apply_aggregation_to_group(&max_float, &items).unwrap();
1353 assert_eq!(max_float_result, Value::Float(12.3));
1354
1355 let min_str = AggregationFunction::Min("str_val".to_string());
1356 let min_str_result = apply_aggregation_to_group(&min_str, &items).unwrap();
1357 assert_eq!(min_str_result, Value::String("apple".to_string()));
1358
1359 let max_str = AggregationFunction::Max("str_val".to_string());
1360 let max_str_result = apply_aggregation_to_group(&max_str, &items).unwrap();
1361 assert_eq!(max_str_result, Value::String("banana".to_string()));
1362 }
1363
1364 #[test]
1365 fn test_group_by_multiple_columns() {
1366 let array_value = Value::Array(vec![
1367 Value::Object(HashMap::from([
1368 ("dept".to_string(), Value::String("Sales".to_string())),
1369 ("region".to_string(), Value::String("North".to_string())),
1370 ("salary".to_string(), Value::Int(50000)),
1371 ])),
1372 Value::Object(HashMap::from([
1373 ("dept".to_string(), Value::String("Sales".to_string())),
1374 ("region".to_string(), Value::String("South".to_string())),
1375 ("salary".to_string(), Value::Int(55000)),
1376 ])),
1377 Value::Object(HashMap::from([
1378 ("dept".to_string(), Value::String("Sales".to_string())),
1379 ("region".to_string(), Value::String("North".to_string())),
1380 ("salary".to_string(), Value::Int(60000)),
1381 ])),
1382 ]);
1383
1384 let group_cols = vec!["dept".to_string(), "region".to_string()];
1385 let agg_funcs = vec![AggregationFunction::Sum("salary".to_string())];
1386
1387 let result = group_by_agg(&array_value, &group_cols, &agg_funcs).unwrap();
1388
1389 match result {
1390 Value::Array(arr) => {
1391 assert_eq!(arr.len(), 2); let mut found_north = false;
1394 let mut found_south = false;
1395
1396 for item in &arr {
1397 if let Value::Object(obj) = item {
1398 if let Some(Value::String(dept)) = obj.get("dept") {
1399 if let Some(Value::String(region)) = obj.get("region") {
1400 if let Some(Value::Int(sum)) = obj.get("salary_sum") {
1401 if *dept == "Sales" && *region == "North" && *sum == 110000 {
1402 found_north = true;
1403 } else if *dept == "Sales"
1404 && *region == "South"
1405 && *sum == 55000
1406 {
1407 found_south = true;
1408 }
1409 }
1410 }
1411 }
1412 }
1413 }
1414
1415 assert!(found_north, "North group not found or incorrect");
1416 assert!(found_south, "South group not found or incorrect");
1417 }
1418 _ => panic!("Expected Array"),
1419 }
1420 }
1421
1422 #[test]
1423 fn test_error_conditions() {
1424 let array_value = Value::Array(vec![Value::Object(HashMap::from([(
1426 "value".to_string(),
1427 Value::Int(1),
1428 )]))]);
1429
1430 let result = group_by(&array_value, &[]);
1431 assert!(result.is_err());
1432 assert!(result
1433 .unwrap_err()
1434 .to_string()
1435 .contains("at least one column"));
1436
1437 let result_agg = group_by_agg(&array_value, &[], &[]);
1438 assert!(result_agg.is_err());
1439
1440 let result_agg_empty = group_by_agg(&array_value, &["value".to_string()], &[]);
1442 assert!(result_agg_empty.is_err());
1443
1444 let int_value = Value::Int(42);
1446 let result_unsupported = group_by(&int_value, &["test".to_string()]);
1447 assert!(result_unsupported.is_err());
1448
1449 let bool_val = Value::Object(HashMap::from([("value".to_string(), Value::Bool(true))]));
1451 let items = vec![&bool_val];
1452 let sum_agg = AggregationFunction::Sum("value".to_string());
1453 let result_type_error = apply_aggregation_to_group(&sum_agg, &items);
1454 assert!(result_type_error.is_err());
1455 }
1456
1457 #[test]
1458 fn test_pivot_current_behavior() {
1459 let df = df! {
1461 "id" => [1, 2, 3],
1462 "category" => ["A", "B", "A"],
1463 "value" => [10, 20, 30]
1464 }
1465 .unwrap();
1466
1467 let value = Value::DataFrame(df);
1468
1469 let pivoted = pivot(
1470 &value,
1471 &["id".to_string()],
1472 "category",
1473 "value",
1474 Some("sum"),
1475 )
1476 .unwrap();
1477
1478 match pivoted {
1480 Value::DataFrame(df) => {
1481 assert!(df
1483 .get_column_names()
1484 .iter()
1485 .any(|name| name.as_str() == "id"));
1486 assert!(df
1487 .get_column_names()
1488 .iter()
1489 .any(|name| name.as_str() == "value_sum"));
1490 }
1491 _ => panic!("Expected DataFrame"),
1492 }
1493 }
1494
1495 #[test]
1496 fn test_unpivot() {
1497 let df = df! {
1498 "id" => [1, 2],
1499 "A" => [10, 20],
1500 "B" => [30, 40]
1501 }
1502 .unwrap();
1503
1504 let value = Value::DataFrame(df);
1505
1506 let unpivoted = unpivot(
1507 &value,
1508 &["id".to_string()],
1509 &["A".to_string(), "B".to_string()],
1510 "category",
1511 "value",
1512 )
1513 .unwrap();
1514
1515 match unpivoted {
1516 Value::DataFrame(df) => {
1517 assert_eq!(df.height(), 2); assert!(df
1519 .get_column_names()
1520 .contains(&&PlSmallStr::from("category")));
1521 assert!(df.get_column_names().contains(&&PlSmallStr::from("value")));
1522 }
1523 _ => panic!("Expected DataFrame"),
1524 }
1525 }
1526
1527 #[test]
1528 fn test_rolling_agg_not_implemented() {
1529 let df = create_test_dataframe();
1530 let value = Value::DataFrame(df);
1531
1532 let result = rolling_agg(&value, "salary", WindowFunction::Sum, 3, None);
1533
1534 assert!(result.is_err());
1535 assert!(result
1536 .unwrap_err()
1537 .to_string()
1538 .contains("not yet implemented"));
1539 }
1540
1541 #[test]
1542 fn test_cumulative_agg_not_implemented() {
1543 let df = create_test_dataframe();
1544 let value = Value::DataFrame(df);
1545
1546 let result = cumulative_agg(&value, "salary", WindowFunction::Sum);
1547
1548 assert!(result.is_err());
1549 assert!(result
1550 .unwrap_err()
1551 .to_string()
1552 .contains("not yet implemented"));
1553 }
1554
1555 #[test]
1556 fn test_aggregation_function_to_polars_expr() {
1557 let sum_agg = AggregationFunction::Sum("salary".to_string());
1558 let _expr = sum_agg.to_polars_expr().unwrap();
1559 let count_agg = AggregationFunction::Count;
1562 let _expr_count = count_agg.to_polars_expr().unwrap();
1563
1564 let string_concat_agg =
1565 AggregationFunction::StringConcat("name".to_string(), Some(",".to_string()));
1566 let _expr_concat = string_concat_agg.to_polars_expr().unwrap();
1567 }
1568
1569 #[test]
1570 fn test_compare_values_for_ordering() {
1571 assert_eq!(
1572 compare_values_for_ordering(&Value::Int(1), &Value::Int(2)),
1573 std::cmp::Ordering::Less
1574 );
1575 assert_eq!(
1576 compare_values_for_ordering(&Value::Float(1.0), &Value::Float(2.0)),
1577 std::cmp::Ordering::Less
1578 );
1579 assert_eq!(
1580 compare_values_for_ordering(
1581 &Value::String("a".to_string()),
1582 &Value::String("b".to_string())
1583 ),
1584 std::cmp::Ordering::Less
1585 );
1586 assert_eq!(
1587 compare_values_for_ordering(&Value::Bool(false), &Value::Bool(true)),
1588 std::cmp::Ordering::Less
1589 );
1590 assert_eq!(
1591 compare_values_for_ordering(&Value::Null, &Value::Int(1)),
1592 std::cmp::Ordering::Less
1593 );
1594 assert_eq!(
1595 compare_values_for_ordering(&Value::Int(1), &Value::Null),
1596 std::cmp::Ordering::Greater
1597 );
1598 assert_eq!(
1599 compare_values_for_ordering(&Value::Null, &Value::Null),
1600 std::cmp::Ordering::Equal
1601 );
1602 assert_eq!(
1603 compare_values_for_ordering(&Value::Int(1), &Value::Float(1.0)),
1604 std::cmp::Ordering::Equal
1605 );
1606 }
1607}