1use crate::error::OxirsError;
11use crate::model::{Literal, Term};
12use crate::rdf_store::VariableBinding;
13use crate::sparql::modifiers::compare_terms;
14use crate::Result;
15use ahash::{AHashMap, AHashSet};
16use std::collections::hash_map::Entry;
17
18#[cfg(feature = "parallel")]
19use rayon::prelude::*;
20
/// Aggregate operations that can be evaluated over a set of solution bindings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// Number of values fed to the aggregate.
    Count,
    /// Sum of the values that parse as numbers.
    Sum,
    /// Arithmetic mean.
    Avg,
    /// Smallest term according to the crate's term ordering.
    Min,
    /// Largest term according to the crate's term ordering.
    Max,
    /// Values joined into a single string.
    GroupConcat {
        /// Text inserted between consecutive values.
        separator: String,
    },
    /// The first value encountered.
    Sample,
    /// Middle value of the sorted numeric values.
    Median,
    /// Sample variance (divides by `n - 1`) of the numeric values.
    Variance,
    /// Square root of the sample variance.
    StdDev,
    /// Linearly interpolated percentile of the numeric values.
    Percentile {
        /// Percentile rank, expected in `0..=100`.
        percentile: u8,
    },
}
41
42#[derive(Debug, Clone)]
44pub struct AggregateExpression {
45 pub function: AggregateFunction,
46 pub variable: Option<String>, pub alias: String,
48 pub distinct: bool, }
50
/// Grouping specification for aggregate evaluation.
#[derive(Debug, Clone)]
pub struct GroupBySpec {
    /// Variables whose values together identify a group; they are also the
    /// leading columns of the grouped result.
    pub variables: Vec<String>,
}
56
57#[derive(Debug, Clone, PartialEq, Eq, Hash)]
59struct GroupKey(Vec<TermHash>);
60
/// Owned, hashable stand-in for a term, used for group keys and for
/// remembering which values a DISTINCT aggregate has already seen.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum TermHash {
    /// A named node, identified by its string form.
    NamedNode(String),
    /// A blank node, identified by its label.
    BlankNode(String),
    /// A literal, identified by lexical value, datatype and language tag.
    Literal {
        /// Lexical form of the literal.
        value: String,
        /// Datatype identifier, if any.
        datatype: Option<String>,
        /// Language tag, if any.
        language: Option<String>,
    },
    /// Placeholder for a grouping variable that has no binding.
    Unbound,
}
73
74impl From<&Term> for TermHash {
75 fn from(term: &Term) -> Self {
76 match term {
77 Term::NamedNode(n) => TermHash::NamedNode(n.as_str().to_string()),
78 Term::BlankNode(b) => TermHash::BlankNode(b.as_str().to_string()),
79 Term::Literal(l) => TermHash::Literal {
80 value: l.value().to_string(),
81 datatype: Some(l.datatype().as_str().to_string()),
82 language: l.language().map(|lang| lang.to_string()),
83 },
84 Term::Variable(v) => TermHash::NamedNode(format!("?{}", v.as_str())),
85 Term::QuotedTriple(qt) => TermHash::NamedNode(format!("<<{}>>", qt)),
86 }
87 }
88}
89
90#[derive(Debug, Clone)]
92struct AggregateAccumulator {
93 function: AggregateFunction,
94 count: usize,
95 sum: f64,
96 values: Vec<Term>,
97 seen_values: AHashSet<TermHash>, min_value: Option<Term>,
99 max_value: Option<Term>,
100 concat_values: Vec<String>, sample_value: Option<Term>, distinct: bool,
103}
104
105impl AggregateAccumulator {
106 fn new(function: AggregateFunction, distinct: bool) -> Self {
108 Self {
109 function,
110 count: 0,
111 sum: 0.0,
112 values: Vec::new(),
113 seen_values: AHashSet::new(),
114 min_value: None,
115 max_value: None,
116 concat_values: Vec::new(),
117 sample_value: None,
118 distinct,
119 }
120 }
121
122 fn add_value(&mut self, term: Option<&Term>) {
124 let Some(term) = term else {
125 return;
126 };
127
128 if self.distinct {
130 let term_hash = TermHash::from(term);
131 if !self.seen_values.insert(term_hash) {
132 return; }
134 }
135
136 self.count += 1;
137
138 match &self.function {
139 AggregateFunction::Count => {
140 }
142 AggregateFunction::Sum | AggregateFunction::Avg => {
143 if let Term::Literal(lit) = term {
144 if let Ok(val) = lit.value().parse::<f64>() {
145 self.sum += val;
146 if matches!(self.function, AggregateFunction::Avg) {
147 self.values.push(term.clone());
148 }
149 }
150 }
151 }
152 AggregateFunction::Min => {
153 if let Some(ref current_min) = self.min_value {
154 if compare_terms(term, current_min).is_lt() {
155 self.min_value = Some(term.clone());
156 }
157 } else {
158 self.min_value = Some(term.clone());
159 }
160 }
161 AggregateFunction::Max => {
162 if let Some(ref current_max) = self.max_value {
163 if compare_terms(term, current_max).is_gt() {
164 self.max_value = Some(term.clone());
165 }
166 } else {
167 self.max_value = Some(term.clone());
168 }
169 }
170 AggregateFunction::GroupConcat { .. } => {
171 if let Term::Literal(lit) = term {
172 self.concat_values.push(lit.value().to_string());
173 } else {
174 self.concat_values.push(term.to_string());
175 }
176 }
177 AggregateFunction::Sample => {
178 if self.sample_value.is_none() {
179 self.sample_value = Some(term.clone());
180 }
181 }
182 AggregateFunction::Median
184 | AggregateFunction::Variance
185 | AggregateFunction::StdDev
186 | AggregateFunction::Percentile { .. } => {
187 if let Term::Literal(lit) = term {
188 if lit.value().parse::<f64>().is_ok() {
189 self.values.push(term.clone());
190 }
191 }
192 }
193 }
194 }
195
196 fn finalize(&self) -> Term {
198 match &self.function {
199 AggregateFunction::Count => Term::from(Literal::new(self.count.to_string())),
200 AggregateFunction::Sum => Term::from(Literal::new(self.sum.to_string())),
201 AggregateFunction::Avg => {
202 let avg = if self.count > 0 {
203 self.sum / self.count as f64
204 } else {
205 0.0
206 };
207 Term::from(Literal::new(avg.to_string()))
208 }
209 AggregateFunction::Min => self
210 .min_value
211 .clone()
212 .unwrap_or_else(|| Term::from(Literal::new(""))),
213 AggregateFunction::Max => self
214 .max_value
215 .clone()
216 .unwrap_or_else(|| Term::from(Literal::new(""))),
217 AggregateFunction::GroupConcat { separator } => {
218 let concatenated = self.concat_values.join(separator);
219 Term::from(Literal::new(concatenated))
220 }
221 AggregateFunction::Sample => self
222 .sample_value
223 .clone()
224 .unwrap_or_else(|| Term::from(Literal::new(""))),
225 AggregateFunction::Median => {
227 let result = compute_median(&self.values);
228 Term::from(Literal::new(result.to_string()))
229 }
230 AggregateFunction::Variance => {
231 let result = compute_variance(&self.values);
232 Term::from(Literal::new(result.to_string()))
233 }
234 AggregateFunction::StdDev => {
235 let variance = compute_variance(&self.values);
236 let stddev = variance.sqrt();
237 Term::from(Literal::new(stddev.to_string()))
238 }
239 AggregateFunction::Percentile { percentile } => {
240 let result = compute_percentile(&self.values, *percentile);
241 Term::from(Literal::new(result.to_string()))
242 }
243 }
244 }
245}
246
247pub fn extract_aggregates(sparql: &str) -> Result<Vec<AggregateExpression>> {
249 let mut aggregates = Vec::new();
250
251 if let Some(select_start) = sparql.to_uppercase().find("SELECT") {
252 if let Some(where_start) = sparql.to_uppercase().find("WHERE") {
253 let select_clause = &sparql[select_start + 6..where_start];
254
255 let mut pos = 0;
257 while pos < select_clause.len() {
258 if let Some(paren_start) = select_clause[pos..].find('(') {
259 let abs_pos = pos + paren_start;
260
261 if let Some(paren_end) = find_matching_paren(&select_clause[abs_pos..]) {
263 let expr = &select_clause[abs_pos..abs_pos + paren_end + 1];
264
265 let expr_upper = expr.to_uppercase();
267 let function = if expr_upper.starts_with("(COUNT") {
268 Some(AggregateFunction::Count)
269 } else if expr_upper.starts_with("(SUM") {
270 Some(AggregateFunction::Sum)
271 } else if expr_upper.starts_with("(AVG") {
272 Some(AggregateFunction::Avg)
273 } else if expr_upper.starts_with("(MIN") {
274 Some(AggregateFunction::Min)
275 } else if expr_upper.starts_with("(MAX") {
276 Some(AggregateFunction::Max)
277 } else {
278 None
279 };
280
281 if let Some(func) = function {
282 let inner = &expr[1..expr.len() - 1]; let func_name_end = if let Some(inner_paren) = inner.find('(') {
287 inner_paren
288 } else {
289 continue;
290 };
291
292 let after_func = &inner[func_name_end..];
294 let after_func_upper = after_func.to_uppercase();
295 let (var_part, alias_part) =
296 if let Some(as_pos) = after_func_upper.find(" AS ") {
297 (&after_func[1..as_pos], &after_func[as_pos + 4..])
298 } else {
299 (&after_func[1..], "")
300 };
301
302 let args_trimmed = var_part.trim_end_matches(')').trim();
303
304 let variable = if args_trimmed == "*" {
306 None
307 } else if let Some(var_name) = args_trimmed.strip_prefix('?') {
308 Some(var_name.to_string())
309 } else {
310 Some(args_trimmed.to_string())
311 };
312
313 let mut alias = String::from("aggregate");
315 if !alias_part.is_empty() {
316 for token in alias_part.split_whitespace() {
317 if let Some(var_name) = token.strip_prefix('?') {
318 alias = var_name.trim_end_matches(')').to_string();
319 break;
320 }
321 }
322 }
323
324 let distinct = expr_upper.contains("DISTINCT");
326
327 aggregates.push(AggregateExpression {
328 function: func,
329 variable,
330 alias,
331 distinct,
332 });
333 }
334
335 pos = abs_pos + paren_end + 1;
336 } else {
337 break;
338 }
339 } else {
340 break;
341 }
342 }
343 }
344 }
345
346 Ok(aggregates)
347}
348
/// Finds the parenthesis that closes the one at the start of `text`.
///
/// The first byte is assumed to be the opening `(` and is not inspected.
/// Nested pairs are skipped. Returns the **byte offset** of the matching `)`
/// so the result can be used directly to slice `text`, or `None` when the
/// parenthesis is never closed (including for empty input).
pub fn find_matching_paren(text: &str) -> Option<usize> {
    // Depth starts at 1 for the assumed opening parenthesis. It cannot
    // underflow: the function returns as soon as it reaches zero.
    let mut depth: usize = 1;

    // '(' and ')' are ASCII, so they never occur inside a multi-byte UTF-8
    // sequence and scanning bytes is safe.
    for (offset, byte) in text.bytes().enumerate().skip(1) {
        match byte {
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(offset);
                }
            }
            _ => {}
        }
    }

    None
}
367
368pub fn apply_aggregates(
376 results: Vec<VariableBinding>,
377 aggregates: &[AggregateExpression],
378) -> Result<(Vec<VariableBinding>, Vec<String>)> {
379 if aggregates.is_empty() {
380 return Err(OxirsError::Query("No aggregates to apply".to_string()));
381 }
382
383 apply_aggregates_no_grouping(results, aggregates)
385}
386
387pub fn apply_aggregates_with_grouping(
391 results: Vec<VariableBinding>,
392 aggregates: &[AggregateExpression],
393 group_by: &GroupBySpec,
394) -> Result<(Vec<VariableBinding>, Vec<String>)> {
395 if aggregates.is_empty() {
396 return Err(OxirsError::Query("No aggregates to apply".to_string()));
397 }
398
399 let mut groups: AHashMap<GroupKey, Vec<VariableBinding>> = AHashMap::new();
401
402 for binding in results {
404 let key = extract_group_key(&binding, &group_by.variables);
405 match groups.entry(key) {
406 Entry::Occupied(mut entry) => {
407 entry.get_mut().push(binding);
408 }
409 Entry::Vacant(entry) => {
410 entry.insert(vec![binding]);
411 }
412 }
413 }
414
415 #[cfg(feature = "parallel")]
417 let group_results: Vec<_> = {
418 let groups_vec: Vec<_> = groups.into_iter().collect();
419 if groups_vec.len() > 10 {
420 groups_vec
422 .into_par_iter()
423 .map(|(key, group_bindings)| {
424 process_group(key, group_bindings, aggregates, &group_by.variables)
425 })
426 .collect::<Result<Vec<_>>>()?
427 } else {
428 groups_vec
429 .into_iter()
430 .map(|(key, group_bindings)| {
431 process_group(key, group_bindings, aggregates, &group_by.variables)
432 })
433 .collect::<Result<Vec<_>>>()?
434 }
435 };
436
437 #[cfg(not(feature = "parallel"))]
438 let group_results: Vec<_> = groups
439 .into_iter()
440 .map(|(key, group_bindings)| {
441 process_group(key, group_bindings, aggregates, &group_by.variables)
442 })
443 .collect::<Result<Vec<_>>>()?;
444
445 let mut result_variables = group_by.variables.clone();
447 for agg_expr in aggregates {
448 result_variables.push(agg_expr.alias.clone());
449 }
450
451 Ok((group_results, result_variables))
452}
453
454fn apply_aggregates_no_grouping(
456 results: Vec<VariableBinding>,
457 aggregates: &[AggregateExpression],
458) -> Result<(Vec<VariableBinding>, Vec<String>)> {
459 let mut result_variables = Vec::new();
460 let mut aggregate_binding = VariableBinding::new();
461
462 let mut accumulators: Vec<AggregateAccumulator> = aggregates
464 .iter()
465 .map(|agg| AggregateAccumulator::new(agg.function.clone(), agg.distinct))
466 .collect();
467
468 for binding in &results {
470 for (acc, agg_expr) in accumulators.iter_mut().zip(aggregates.iter()) {
471 let value = if let Some(var) = &agg_expr.variable {
472 binding.get(var)
473 } else {
474 Some(&Term::from(Literal::new("1")))
476 };
477 acc.add_value(value);
478 }
479 }
480
481 for (acc, agg_expr) in accumulators.iter().zip(aggregates.iter()) {
483 let value = acc.finalize();
484 aggregate_binding.bind(agg_expr.alias.clone(), value);
485 result_variables.push(agg_expr.alias.clone());
486 }
487
488 Ok((vec![aggregate_binding], result_variables))
489}
490
491fn extract_group_key(binding: &VariableBinding, group_vars: &[String]) -> GroupKey {
493 let key_terms: Vec<TermHash> = group_vars
494 .iter()
495 .map(|var| {
496 binding
497 .get(var)
498 .map(TermHash::from)
499 .unwrap_or(TermHash::Unbound)
500 })
501 .collect();
502 GroupKey(key_terms)
503}
504
505fn process_group(
507 _key: GroupKey,
508 group_bindings: Vec<VariableBinding>,
509 aggregates: &[AggregateExpression],
510 group_vars: &[String],
511) -> Result<VariableBinding> {
512 let mut result_binding = VariableBinding::new();
513
514 if let Some(first_binding) = group_bindings.first() {
516 for var in group_vars {
517 if let Some(value) = first_binding.get(var) {
518 result_binding.bind(var.clone(), value.clone());
519 }
520 }
521 }
522
523 let mut accumulators: Vec<AggregateAccumulator> = aggregates
525 .iter()
526 .map(|agg| AggregateAccumulator::new(agg.function.clone(), agg.distinct))
527 .collect();
528
529 for binding in &group_bindings {
531 for (acc, agg_expr) in accumulators.iter_mut().zip(aggregates.iter()) {
532 let value = if let Some(var) = &agg_expr.variable {
533 binding.get(var)
534 } else {
535 Some(&Term::from(Literal::new("1")))
537 };
538 acc.add_value(value);
539 }
540 }
541
542 for (acc, agg_expr) in accumulators.iter().zip(aggregates.iter()) {
544 let value = acc.finalize();
545 result_binding.bind(agg_expr.alias.clone(), value);
546 }
547
548 Ok(result_binding)
549}
550
551fn compute_median(values: &[Term]) -> f64 {
555 if values.is_empty() {
556 return 0.0;
557 }
558
559 let mut nums: Vec<f64> = values
560 .iter()
561 .filter_map(|term| {
562 if let Term::Literal(lit) = term {
563 lit.value().parse::<f64>().ok()
564 } else {
565 None
566 }
567 })
568 .collect();
569
570 if nums.is_empty() {
571 return 0.0;
572 }
573
574 nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
575
576 let len = nums.len();
577 if len % 2 == 0 {
578 (nums[len / 2 - 1] + nums[len / 2]) / 2.0
580 } else {
581 nums[len / 2]
583 }
584}
585
586fn compute_variance(values: &[Term]) -> f64 {
589 if values.len() < 2 {
590 return 0.0;
591 }
592
593 let nums: Vec<f64> = values
594 .iter()
595 .filter_map(|term| {
596 if let Term::Literal(lit) = term {
597 lit.value().parse::<f64>().ok()
598 } else {
599 None
600 }
601 })
602 .collect();
603
604 if nums.len() < 2 {
605 return 0.0;
606 }
607
608 let mean = nums.iter().sum::<f64>() / nums.len() as f64;
610
611 let squared_diffs: f64 = nums.iter().map(|x| (x - mean).powi(2)).sum();
613
614 squared_diffs / (nums.len() - 1) as f64
616}
617
618fn compute_percentile(values: &[Term], percentile: u8) -> f64 {
622 if values.is_empty() || percentile > 100 {
623 return 0.0;
624 }
625
626 let mut nums: Vec<f64> = values
627 .iter()
628 .filter_map(|term| {
629 if let Term::Literal(lit) = term {
630 lit.value().parse::<f64>().ok()
631 } else {
632 None
633 }
634 })
635 .collect();
636
637 if nums.is_empty() {
638 return 0.0;
639 }
640
641 nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
642
643 if percentile == 0 {
644 return nums[0];
645 }
646 if percentile == 100 {
647 return nums[nums.len() - 1];
648 }
649
650 let rank = (percentile as f64 / 100.0) * (nums.len() - 1) as f64;
652 let lower_index = rank.floor() as usize;
653 let upper_index = rank.ceil() as usize;
654
655 if lower_index == upper_index {
656 nums[lower_index]
657 } else {
658 let lower_value = nums[lower_index];
660 let upper_value = nums[upper_index];
661 let fraction = rank - lower_index as f64;
662 lower_value + fraction * (upper_value - lower_value)
663 }
664}
665
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one row binding each variable to a plain literal that holds the
    /// value's `to_string()` rendering.
    fn create_test_binding(values: Vec<(&str, f64)>) -> VariableBinding {
        let mut binding = VariableBinding::new();
        for (var, val) in values {
            binding.bind(var.to_string(), Term::from(Literal::new(val.to_string())));
        }
        binding
    }

    /// Builds one row per value, each binding `?x`.
    fn rows_of_x(values: &[f64]) -> Vec<VariableBinding> {
        values
            .iter()
            .map(|&value| create_test_binding(vec![("x", value)]))
            .collect()
    }

    /// Builds a row with `?category` and `?value` bound to string literals.
    fn category_row(category: &str, value: &str) -> VariableBinding {
        let mut binding = VariableBinding::new();
        binding.bind(
            "category".to_string(),
            Term::from(Literal::new(category.to_string())),
        );
        binding.bind(
            "value".to_string(),
            Term::from(Literal::new(value.to_string())),
        );
        binding
    }

    /// Shorthand for a non-DISTINCT aggregate over `variable`.
    fn aggregate_over(
        function: AggregateFunction,
        variable: &str,
        alias: &str,
    ) -> AggregateExpression {
        AggregateExpression {
            function,
            variable: Some(variable.to_string()),
            alias: alias.to_string(),
            distinct: false,
        }
    }

    /// Returns the lexical value bound to `name`.
    ///
    /// Panics when the variable is unbound or not a literal, so a wrong result
    /// shape can never make a test pass silently.
    fn literal_text(binding: &VariableBinding, name: &str) -> String {
        match binding.get(name).expect("binding should exist") {
            Term::Literal(lit) => lit.value().to_string(),
            other => panic!("Expected literal for ?{name}, got {other:?}"),
        }
    }

    /// Returns the literal bound to `name` parsed as a number.
    fn literal_number(binding: &VariableBinding, name: &str) -> f64 {
        literal_text(binding, name)
            .parse()
            .expect("parse should succeed")
    }

    /// Runs a single ungrouped aggregate and returns its one result row.
    fn single_result(
        results: Vec<VariableBinding>,
        agg: AggregateExpression,
    ) -> VariableBinding {
        let (mut rows, _) =
            apply_aggregates(results, &[agg]).expect("aggregate operation should succeed");
        assert_eq!(rows.len(), 1);
        rows.remove(0)
    }

    /// Collects `(category, number bound to alias)` pairs sorted by category,
    /// so grouped tests do not depend on the order of the groups.
    fn numbers_by_category(rows: &[VariableBinding], alias: &str) -> Vec<(String, f64)> {
        let mut pairs: Vec<(String, f64)> = rows
            .iter()
            .map(|row| (literal_text(row, "category"), literal_number(row, alias)))
            .collect();
        pairs.sort_by(|left, right| left.0.cmp(&right.0));
        pairs
    }

    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {expected} (±{tolerance}), got {actual}"
        );
    }

    #[test]
    fn test_count_aggregate() {
        let agg = aggregate_over(AggregateFunction::Count, "x", "count");

        let (result, vars) = apply_aggregates(rows_of_x(&[1.0, 2.0, 3.0]), &[agg])
            .expect("aggregate operation should succeed");

        assert_eq!(result.len(), 1);
        assert_eq!(vars, vec!["count"]);
        assert_eq!(literal_text(&result[0], "count"), "3");
    }

    #[test]
    fn test_sum_aggregate() {
        let agg = aggregate_over(AggregateFunction::Sum, "x", "sum");
        let row = single_result(rows_of_x(&[10.0, 20.0, 30.0]), agg);
        assert_close(literal_number(&row, "sum"), 60.0, 0.0001);
    }

    #[test]
    fn test_avg_aggregate() {
        let agg = aggregate_over(AggregateFunction::Avg, "x", "avg");
        let row = single_result(rows_of_x(&[10.0, 20.0, 30.0]), agg);
        assert_close(literal_number(&row, "avg"), 20.0, 0.0001);
    }

    #[test]
    fn test_count_distinct() {
        let agg = AggregateExpression {
            distinct: true,
            ..aggregate_over(AggregateFunction::Count, "x", "count")
        };

        // The repeated 1.0 must be counted once.
        let row = single_result(rows_of_x(&[1.0, 2.0, 1.0, 3.0]), agg);
        assert_eq!(literal_text(&row, "count"), "3");
    }

    #[test]
    fn test_group_concat() {
        let results: Vec<VariableBinding> = ["apple", "banana", "cherry"]
            .iter()
            .map(|fruit| {
                let mut binding = VariableBinding::new();
                binding.bind(
                    "x".to_string(),
                    Term::from(Literal::new(fruit.to_string())),
                );
                binding
            })
            .collect();

        let agg = aggregate_over(
            AggregateFunction::GroupConcat {
                separator: ", ".to_string(),
            },
            "x",
            "concat",
        );

        let row = single_result(results, agg);
        assert_eq!(literal_text(&row, "concat"), "apple, banana, cherry");
    }

    #[test]
    fn test_sample_aggregate() {
        let agg = aggregate_over(AggregateFunction::Sample, "x", "sample");
        let row = single_result(rows_of_x(&[10.0, 20.0, 30.0]), agg);
        assert!(row.get("sample").is_some());
    }

    #[test]
    fn test_group_by_hash_based() {
        let results = vec![
            category_row("A", "10"),
            category_row("A", "20"),
            category_row("B", "30"),
        ];
        let agg = aggregate_over(AggregateFunction::Sum, "value", "total");
        let group_by = GroupBySpec {
            variables: vec!["category".to_string()],
        };

        let (result, vars) = apply_aggregates_with_grouping(results, &[agg], &group_by)
            .expect("aggregate operation should succeed");

        assert_eq!(result.len(), 2);
        assert_eq!(vars, vec!["category", "total"]);

        let totals = numbers_by_category(&result, "total");
        assert_eq!(totals[0].0, "A");
        assert_close(totals[0].1, 30.0, 0.0001);
        assert_eq!(totals[1].0, "B");
        assert_close(totals[1].1, 30.0, 0.0001);
    }

    #[test]
    fn test_multiple_aggregates() {
        let aggregates = vec![
            aggregate_over(AggregateFunction::Count, "x", "count"),
            aggregate_over(AggregateFunction::Sum, "x", "sum"),
            aggregate_over(AggregateFunction::Avg, "x", "avg"),
        ];

        let (result, vars) = apply_aggregates(rows_of_x(&[10.0, 20.0, 30.0]), &aggregates)
            .expect("aggregate operation should succeed");

        assert_eq!(result.len(), 1);
        assert_eq!(vars, vec!["count", "sum", "avg"]);
        assert!(result[0].get("count").is_some());
        assert!(result[0].get("sum").is_some());
        assert!(result[0].get("avg").is_some());
    }

    #[test]
    fn test_median_aggregate() {
        // Odd number of values: the middle one.
        let agg = aggregate_over(AggregateFunction::Median, "x", "median");
        let row = single_result(rows_of_x(&[1.0, 3.0, 5.0, 7.0, 9.0]), agg);
        assert_close(literal_number(&row, "median"), 5.0, 0.001);

        // Even number of values: mean of the two middle ones, (4 + 6) / 2.
        let agg = aggregate_over(AggregateFunction::Median, "x", "median");
        let row = single_result(rows_of_x(&[2.0, 4.0, 6.0, 8.0]), agg);
        assert_close(literal_number(&row, "median"), 5.0, 0.001);
    }

    #[test]
    fn test_variance_aggregate() {
        // Mean 5, squared deviations 9 + 1 + 1 + 9 = 20, sample variance 20 / 3.
        let agg = aggregate_over(AggregateFunction::Variance, "x", "variance");
        let row = single_result(rows_of_x(&[2.0, 4.0, 6.0, 8.0]), agg);
        assert_close(literal_number(&row, "variance"), 6.666666666666667, 0.001);
    }

    #[test]
    fn test_stddev_aggregate() {
        // Square root of the sample variance 20 / 3.
        let agg = aggregate_over(AggregateFunction::StdDev, "x", "stddev");
        let row = single_result(rows_of_x(&[2.0, 4.0, 6.0, 8.0]), agg);
        assert_close(literal_number(&row, "stddev"), 2.581988897471611, 0.001);
    }

    #[test]
    fn test_percentile_aggregate() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        // (percentile, alias, expected value, tolerance)
        let cases = [
            (50u8, "p50", 5.5, 0.001),
            (95u8, "p95", 9.55, 0.01),
            (25u8, "p25", 3.25, 0.01),
        ];

        for (percentile, alias, expected, tolerance) in cases.iter().copied() {
            let agg = aggregate_over(AggregateFunction::Percentile { percentile }, "x", alias);
            let row = single_result(rows_of_x(&values), agg);
            assert_close(literal_number(&row, alias), expected, tolerance);
        }
    }

    #[test]
    fn test_statistical_aggregates_with_grouping() {
        let results = vec![
            category_row("A", "10"),
            category_row("A", "20"),
            category_row("A", "30"),
            category_row("B", "5"),
            category_row("B", "15"),
        ];
        let agg = aggregate_over(AggregateFunction::Median, "value", "median");
        let group_by = GroupBySpec {
            variables: vec!["category".to_string()],
        };

        let (result, _) = apply_aggregates_with_grouping(results, &[agg], &group_by)
            .expect("aggregate operation should succeed");

        assert_eq!(result.len(), 2);

        let medians = numbers_by_category(&result, "median");
        assert_eq!(medians[0].0, "A");
        assert_close(medians[0].1, 20.0, 0.001);
        assert_eq!(medians[1].0, "B");
        assert_close(medians[1].1, 10.0, 0.001);
    }

    #[test]
    fn test_statistical_aggregate_edge_cases() {
        // No rows at all: the median falls back to 0.
        let agg = aggregate_over(AggregateFunction::Median, "x", "median");
        let row = single_result(Vec::new(), agg);
        assert_eq!(literal_number(&row, "median"), 0.0);

        // A single value has no sample variance; the result is 0.
        let agg = aggregate_over(AggregateFunction::Variance, "x", "variance");
        let row = single_result(rows_of_x(&[5.0]), agg);
        assert_eq!(literal_number(&row, "variance"), 0.0);
    }
}