1use std::collections::HashMap;
22
23use super::super::engine::binding::{Binding, Value, Var};
24use super::value_compare::total_compare_values;
25
26pub trait Aggregator: Send + Sync {
32 fn accumulate(&mut self, value: Option<&Value>);
34
35 fn finalize(&self) -> Value;
37
38 fn reset(&mut self);
40
41 fn new_instance(&self) -> Box<dyn Aggregator>;
43
44 fn name(&self) -> &'static str;
46}
47
48#[derive(Debug, Clone, Default)]
54pub struct CountAggregator {
55 count: i64,
56 count_all: bool, }
58
59impl CountAggregator {
60 pub fn count_all() -> Self {
62 Self {
63 count: 0,
64 count_all: true,
65 }
66 }
67
68 pub fn count_column() -> Self {
70 Self {
71 count: 0,
72 count_all: false,
73 }
74 }
75}
76
77impl Aggregator for CountAggregator {
78 fn accumulate(&mut self, value: Option<&Value>) {
79 if self.count_all || (value.is_some() && !matches!(value, Some(Value::Null))) {
80 self.count += 1;
81 }
82 }
83
84 fn finalize(&self) -> Value {
85 Value::Integer(self.count)
86 }
87
88 fn reset(&mut self) {
89 self.count = 0;
90 }
91
92 fn new_instance(&self) -> Box<dyn Aggregator> {
93 Box::new(Self {
94 count: 0,
95 count_all: self.count_all,
96 })
97 }
98
99 fn name(&self) -> &'static str {
100 "COUNT"
101 }
102}
103
104#[derive(Debug, Clone, Default)]
110pub struct CountDistinctAggregator {
111 seen: std::collections::HashSet<String>,
112}
113
114impl CountDistinctAggregator {
115 pub fn new() -> Self {
116 Self {
117 seen: std::collections::HashSet::new(),
118 }
119 }
120}
121
122impl Aggregator for CountDistinctAggregator {
123 fn accumulate(&mut self, value: Option<&Value>) {
124 if let Some(v) = value {
125 if !matches!(v, Value::Null) {
126 self.seen.insert(value_to_string(v));
127 }
128 }
129 }
130
131 fn finalize(&self) -> Value {
132 Value::Integer(self.seen.len() as i64)
133 }
134
135 fn reset(&mut self) {
136 self.seen.clear();
137 }
138
139 fn new_instance(&self) -> Box<dyn Aggregator> {
140 Box::new(Self::new())
141 }
142
143 fn name(&self) -> &'static str {
144 "COUNT_DISTINCT"
145 }
146}
147
148#[derive(Debug, Clone, Default)]
154pub struct SumAggregator {
155 sum: f64,
156 has_value: bool,
157 all_integers: bool,
158}
159
160impl SumAggregator {
161 pub fn new() -> Self {
162 Self {
163 sum: 0.0,
164 has_value: false,
165 all_integers: true,
166 }
167 }
168}
169
170impl Aggregator for SumAggregator {
171 fn accumulate(&mut self, value: Option<&Value>) {
172 if let Some(v) = value {
173 match v {
174 Value::Integer(i) => {
175 self.sum += *i as f64;
176 self.has_value = true;
177 }
178 Value::Float(f) => {
179 self.sum += *f;
180 self.has_value = true;
181 if f.fract() != 0.0 {
182 self.all_integers = false;
183 }
184 }
185 _ => {}
186 }
187 }
188 }
189
190 fn finalize(&self) -> Value {
191 if self.has_value {
192 if self.all_integers && self.sum.fract() == 0.0 {
193 Value::Integer(self.sum as i64)
194 } else {
195 Value::Float(self.sum)
196 }
197 } else {
198 Value::Null
199 }
200 }
201
202 fn reset(&mut self) {
203 self.sum = 0.0;
204 self.has_value = false;
205 self.all_integers = true;
206 }
207
208 fn new_instance(&self) -> Box<dyn Aggregator> {
209 Box::new(Self::new())
210 }
211
212 fn name(&self) -> &'static str {
213 "SUM"
214 }
215}
216
217#[derive(Debug, Clone, Default)]
223pub struct AvgAggregator {
224 sum: f64,
225 count: i64,
226}
227
228impl AvgAggregator {
229 pub fn new() -> Self {
230 Self { sum: 0.0, count: 0 }
231 }
232}
233
234impl Aggregator for AvgAggregator {
235 fn accumulate(&mut self, value: Option<&Value>) {
236 if let Some(v) = value {
237 if let Some(n) = value_to_number(v) {
238 self.sum += n;
239 self.count += 1;
240 }
241 }
242 }
243
244 fn finalize(&self) -> Value {
245 if self.count > 0 {
246 Value::Float(self.sum / self.count as f64)
247 } else {
248 Value::Null
249 }
250 }
251
252 fn reset(&mut self) {
253 self.sum = 0.0;
254 self.count = 0;
255 }
256
257 fn new_instance(&self) -> Box<dyn Aggregator> {
258 Box::new(Self::new())
259 }
260
261 fn name(&self) -> &'static str {
262 "AVG"
263 }
264}
265
266#[derive(Debug, Clone, Default)]
272pub struct MinAggregator {
273 min: Option<Value>,
274}
275
276impl MinAggregator {
277 pub fn new() -> Self {
278 Self { min: None }
279 }
280}
281
282impl Aggregator for MinAggregator {
283 fn accumulate(&mut self, value: Option<&Value>) {
284 if let Some(v) = value {
285 if matches!(v, Value::Null) {
286 return;
287 }
288 match &self.min {
289 None => self.min = Some(v.clone()),
290 Some(current) => {
291 if total_compare_values(v, current) == std::cmp::Ordering::Less {
292 self.min = Some(v.clone());
293 }
294 }
295 }
296 }
297 }
298
299 fn finalize(&self) -> Value {
300 self.min.clone().unwrap_or(Value::Null)
301 }
302
303 fn reset(&mut self) {
304 self.min = None;
305 }
306
307 fn new_instance(&self) -> Box<dyn Aggregator> {
308 Box::new(Self::new())
309 }
310
311 fn name(&self) -> &'static str {
312 "MIN"
313 }
314}
315
316#[derive(Debug, Clone, Default)]
322pub struct MaxAggregator {
323 max: Option<Value>,
324}
325
326impl MaxAggregator {
327 pub fn new() -> Self {
328 Self { max: None }
329 }
330}
331
332impl Aggregator for MaxAggregator {
333 fn accumulate(&mut self, value: Option<&Value>) {
334 if let Some(v) = value {
335 if matches!(v, Value::Null) {
336 return;
337 }
338 match &self.max {
339 None => self.max = Some(v.clone()),
340 Some(current) => {
341 if total_compare_values(v, current) == std::cmp::Ordering::Greater {
342 self.max = Some(v.clone());
343 }
344 }
345 }
346 }
347 }
348
349 fn finalize(&self) -> Value {
350 self.max.clone().unwrap_or(Value::Null)
351 }
352
353 fn reset(&mut self) {
354 self.max = None;
355 }
356
357 fn new_instance(&self) -> Box<dyn Aggregator> {
358 Box::new(Self::new())
359 }
360
361 fn name(&self) -> &'static str {
362 "MAX"
363 }
364}
365
366#[derive(Debug, Clone, Default)]
372pub struct SampleAggregator {
373 value: Option<Value>,
374}
375
376impl SampleAggregator {
377 pub fn new() -> Self {
378 Self { value: None }
379 }
380}
381
382impl Aggregator for SampleAggregator {
383 fn accumulate(&mut self, value: Option<&Value>) {
384 if self.value.is_some() {
385 return;
386 }
387 if let Some(v) = value {
388 if !matches!(v, Value::Null) {
389 self.value = Some(v.clone());
390 }
391 }
392 }
393
394 fn finalize(&self) -> Value {
395 self.value.clone().unwrap_or(Value::Null)
396 }
397
398 fn reset(&mut self) {
399 self.value = None;
400 }
401
402 fn new_instance(&self) -> Box<dyn Aggregator> {
403 Box::new(Self::new())
404 }
405
406 fn name(&self) -> &'static str {
407 "SAMPLE"
408 }
409}
410
411#[derive(Debug, Clone)]
417pub struct GroupConcatAggregator {
418 separator: String,
419 values: Vec<String>,
420}
421
422impl GroupConcatAggregator {
423 pub fn new(separator: Option<String>) -> Self {
424 Self {
425 separator: separator.unwrap_or_else(|| " ".to_string()),
426 values: Vec::new(),
427 }
428 }
429}
430
431impl Aggregator for GroupConcatAggregator {
432 fn accumulate(&mut self, value: Option<&Value>) {
433 if let Some(v) = value {
434 if !matches!(v, Value::Null) {
435 self.values.push(value_to_string(v));
436 }
437 }
438 }
439
440 fn finalize(&self) -> Value {
441 if self.values.is_empty() {
442 Value::Null
443 } else {
444 Value::String(self.values.join(&self.separator))
445 }
446 }
447
448 fn reset(&mut self) {
449 self.values.clear();
450 }
451
452 fn new_instance(&self) -> Box<dyn Aggregator> {
453 Box::new(Self::new(Some(self.separator.clone())))
454 }
455
456 fn name(&self) -> &'static str {
457 "GROUP_CONCAT"
458 }
459}
460
461#[derive(Debug, Clone, Default)]
467pub struct StdDevAggregator {
468 values: Vec<f64>,
469}
470
471impl StdDevAggregator {
472 pub fn new() -> Self {
473 Self { values: Vec::new() }
474 }
475}
476
477impl Aggregator for StdDevAggregator {
478 fn accumulate(&mut self, value: Option<&Value>) {
479 if let Some(v) = value {
480 if let Some(n) = value_to_number(v) {
481 self.values.push(n);
482 }
483 }
484 }
485
486 fn finalize(&self) -> Value {
487 if self.values.is_empty() {
488 return Value::Null;
489 }
490
491 let n = self.values.len() as f64;
492 let mean = self.values.iter().sum::<f64>() / n;
493 let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
494
495 Value::Float(variance.sqrt())
496 }
497
498 fn reset(&mut self) {
499 self.values.clear();
500 }
501
502 fn new_instance(&self) -> Box<dyn Aggregator> {
503 Box::new(Self::new())
504 }
505
506 fn name(&self) -> &'static str {
507 "STDDEV"
508 }
509}
510
511#[derive(Debug, Clone, Default)]
517pub struct VarianceAggregator {
518 values: Vec<f64>,
519}
520
521impl VarianceAggregator {
522 pub fn new() -> Self {
523 Self { values: Vec::new() }
524 }
525}
526
527impl Aggregator for VarianceAggregator {
528 fn accumulate(&mut self, value: Option<&Value>) {
529 if let Some(v) = value {
530 if let Some(n) = value_to_number(v) {
531 self.values.push(n);
532 }
533 }
534 }
535
536 fn finalize(&self) -> Value {
537 if self.values.is_empty() {
538 return Value::Null;
539 }
540
541 let n = self.values.len() as f64;
542 let mean = self.values.iter().sum::<f64>() / n;
543 let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
544
545 Value::Float(variance)
546 }
547
548 fn reset(&mut self) {
549 self.values.clear();
550 }
551
552 fn new_instance(&self) -> Box<dyn Aggregator> {
553 Box::new(Self::new())
554 }
555
556 fn name(&self) -> &'static str {
557 "VARIANCE"
558 }
559}
560
561#[derive(Debug, Clone)]
567pub struct PercentileAggregator {
568 values: Vec<f64>,
569 percentile: f64, }
571
572impl PercentileAggregator {
573 pub fn new(percentile: f64) -> Self {
574 Self {
575 values: Vec::new(),
576 percentile: percentile.clamp(0.0, 1.0),
577 }
578 }
579
580 pub fn median() -> Self {
582 Self::new(0.5)
583 }
584}
585
586impl Aggregator for PercentileAggregator {
587 fn accumulate(&mut self, value: Option<&Value>) {
588 if let Some(v) = value {
589 if let Some(n) = value_to_number(v) {
590 self.values.push(n);
591 }
592 }
593 }
594
595 fn finalize(&self) -> Value {
596 if self.values.is_empty() {
597 return Value::Null;
598 }
599
600 let mut sorted = self.values.clone();
601 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
602
603 let index = (self.percentile * (sorted.len() - 1) as f64).round() as usize;
604 Value::Float(sorted[index])
605 }
606
607 fn reset(&mut self) {
608 self.values.clear();
609 }
610
611 fn new_instance(&self) -> Box<dyn Aggregator> {
612 Box::new(Self::new(self.percentile))
613 }
614
615 fn name(&self) -> &'static str {
616 "PERCENTILE"
617 }
618}
619
620pub struct AggregationDef {
626 pub source_var: Var,
628 pub result_var: Var,
630 pub aggregator: Box<dyn Aggregator>,
632}
633
634impl std::fmt::Debug for AggregationDef {
635 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636 f.debug_struct("AggregationDef")
637 .field("source_var", &self.source_var)
638 .field("result_var", &self.result_var)
639 .field("aggregator", &self.aggregator.name())
640 .finish()
641 }
642}
643
/// Memory budget (64 MiB) for in-memory hash aggregation. Exceeding it only
/// triggers a debug-build warning in `execute_group_by`; nothing spills to disk.
const WORK_MEM_BYTES: usize = 64 << 20;
/// Rough per-group size estimate used for the `WORK_MEM_BYTES` check.
const AVG_GROUP_ENTRY_BYTES: usize = 512;
661
662pub fn execute_group_by(
673 bindings: Vec<Binding>,
674 group_vars: &[Var],
675 aggregations: &[AggregationDef],
676) -> Vec<Binding> {
677 let mut groups: HashMap<String, (Binding, Vec<Box<dyn Aggregator>>)> = HashMap::new();
680
681 for binding in &bindings {
682 let key = make_group_key(binding, group_vars);
683 let entry = groups.entry(key).or_insert_with(|| {
684 let mut key_binding = Binding::empty();
686 for var in group_vars {
687 if let Some(value) = binding.get(var) {
688 let partial = Binding::one(var.clone(), value.clone());
689 key_binding = key_binding.merge(&partial).unwrap_or(key_binding);
690 }
691 }
692 let agg_instances = aggregations
694 .iter()
695 .map(|a| a.aggregator.new_instance())
696 .collect();
697 (key_binding, agg_instances)
698 });
699
700 for (i, agg_def) in aggregations.iter().enumerate() {
702 entry.1[i].accumulate(binding.get(&agg_def.source_var));
703 }
704
705 #[cfg(debug_assertions)]
711 if groups.len() * AVG_GROUP_ENTRY_BYTES > WORK_MEM_BYTES {
712 static WARNED: std::sync::atomic::AtomicBool =
715 std::sync::atomic::AtomicBool::new(false);
716 if !WARNED.swap(true, std::sync::atomic::Ordering::Relaxed) {
717 eprintln!(
718 "[reddb] hash-agg: {} distinct groups × {} B ≈ {} MiB exceeds WORK_MEM {} MiB; \
719 disk spill not yet wired — upgrade calling convention to streaming for OOM safety",
720 groups.len(),
721 AVG_GROUP_ENTRY_BYTES,
722 (groups.len() * AVG_GROUP_ENTRY_BYTES) / (1024 * 1024),
723 WORK_MEM_BYTES / (1024 * 1024),
724 );
725 }
726 }
727 }
728
729 let mut results = Vec::with_capacity(groups.len());
731 for (_, (key_binding, agg_instances)) in groups {
732 let mut result = key_binding;
733 for (i, agg_def) in aggregations.iter().enumerate() {
734 let agg_result = agg_instances[i].finalize();
735 let partial = Binding::one(agg_def.result_var.clone(), agg_result);
736 result = result.merge(&partial).unwrap_or(result);
737 }
738 results.push(result);
739 }
740 results
741}
742
743pub fn execute_having<F>(bindings: Vec<Binding>, predicate: F) -> Vec<Binding>
745where
746 F: Fn(&Binding) -> bool,
747{
748 bindings.into_iter().filter(|b| predicate(b)).collect()
749}
750
751fn value_to_string(value: &Value) -> String {
760 match value {
761 Value::Node(id) => format!("node:{}", id),
762 Value::Edge(id) => format!("edge:{}", id),
763 Value::String(s) => s.clone(),
764 Value::Integer(i) => i.to_string(),
765 Value::Float(f) => f.to_string(),
766 Value::Boolean(b) => b.to_string(),
767 Value::Uri(u) => u.clone(),
768 Value::Null => "null".to_string(),
769 }
770}
771
772fn make_group_key(binding: &Binding, group_vars: &[Var]) -> String {
783 use std::fmt::Write;
784 let mut key = String::with_capacity(64);
788 for (i, var) in group_vars.iter().enumerate() {
789 if i > 0 {
790 key.push('|');
791 }
792 key.push_str(var.name());
793 key.push('=');
794 match binding.get(var) {
795 None => key.push_str("NULL"),
796 Some(Value::Null) => key.push_str("null"),
797 Some(Value::String(s)) => key.push_str(s),
798 Some(Value::Integer(n)) => {
799 let _ = write!(key, "{n}");
800 }
801 Some(Value::Float(f)) => {
802 let _ = write!(key, "{f}");
803 }
804 Some(Value::Boolean(b)) => {
805 let _ = write!(key, "{b}");
806 }
807 Some(Value::Node(id)) => {
808 key.push_str("node:");
809 key.push_str(id);
810 }
811 Some(Value::Edge(id)) => {
812 key.push_str("edge:");
813 key.push_str(id);
814 }
815 Some(Value::Uri(u)) => key.push_str(u),
816 }
817 }
818 key
819}
820
821fn value_to_number(value: &Value) -> Option<f64> {
822 match value {
823 Value::Integer(i) => Some(*i as f64),
824 Value::Float(f) => Some(*f),
825 Value::String(s) => s.parse().ok(),
826 _ => None,
827 }
828}
829
830pub fn create_aggregator(name: &str) -> Option<Box<dyn Aggregator>> {
836 match name.to_uppercase().as_str() {
837 "COUNT" => Some(Box::new(CountAggregator::count_all())),
838 "COUNT_COLUMN" => Some(Box::new(CountAggregator::count_column())),
839 "COUNT_DISTINCT" => Some(Box::new(CountDistinctAggregator::new())),
840 "SUM" => Some(Box::new(SumAggregator::new())),
841 "AVG" => Some(Box::new(AvgAggregator::new())),
842 "MIN" => Some(Box::new(MinAggregator::new())),
843 "MAX" => Some(Box::new(MaxAggregator::new())),
844 "STDDEV" => Some(Box::new(StdDevAggregator::new())),
845 "VARIANCE" => Some(Box::new(VarianceAggregator::new())),
846 "MEDIAN" => Some(Box::new(PercentileAggregator::median())),
847 "SAMPLE" => Some(Box::new(SampleAggregator::new())),
848 "GROUP_CONCAT" => Some(Box::new(GroupConcatAggregator::new(None))),
849 _ => None,
850 }
851}
852
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a binding from `(name, value)` pairs by merging one
    /// single-variable binding at a time, in slice order.
    fn make_binding(pairs: &[(&str, Value)]) -> Binding {
        match pairs.split_first() {
            None => Binding::empty(),
            Some((first, rest)) => {
                let seed = Binding::one(Var::new(first.0), first.1.clone());
                rest.iter().fold(seed, |acc, (name, value)| {
                    let next = Binding::one(Var::new(name), value.clone());
                    acc.merge(&next).unwrap_or(acc)
                })
            }
        }
    }

    #[test]
    fn test_count() {
        let mut counter = CountAggregator::count_all();
        let inputs = [
            Some(Value::Integer(1)),
            Some(Value::Integer(2)),
            None,
            Some(Value::Null),
        ];
        for input in &inputs {
            counter.accumulate(input.as_ref());
        }

        assert_eq!(counter.finalize(), Value::Integer(4));
    }

    #[test]
    fn test_count_column() {
        let mut counter = CountAggregator::count_column();
        let inputs = [
            Some(Value::Integer(1)),
            None,
            Some(Value::Null),
            Some(Value::Integer(2)),
        ];
        for input in &inputs {
            counter.accumulate(input.as_ref());
        }

        // Only the two bound, non-Null values count.
        assert_eq!(counter.finalize(), Value::Integer(2));
    }

    #[test]
    fn test_sum() {
        let mut sum = SumAggregator::new();
        for value in [Value::Integer(10), Value::Float(5.5), Value::Integer(4)] {
            sum.accumulate(Some(&value));
        }

        assert_eq!(sum.finalize(), Value::Float(19.5));
    }

    #[test]
    fn test_avg() {
        let mut avg = AvgAggregator::new();
        for n in [10, 20, 30] {
            avg.accumulate(Some(&Value::Integer(n)));
        }

        assert_eq!(avg.finalize(), Value::Float(20.0));
    }

    #[test]
    fn test_min_max() {
        let mut min = MinAggregator::new();
        let mut max = MaxAggregator::new();

        for n in [5, 2, 8, 1, 9] {
            let value = Value::Integer(n);
            min.accumulate(Some(&value));
            max.accumulate(Some(&value));
        }

        assert_eq!(min.finalize(), Value::Integer(1));
        assert_eq!(max.finalize(), Value::Integer(9));
    }

    #[test]
    fn test_count_distinct() {
        let mut distinct = CountDistinctAggregator::new();
        for text in ["a", "b", "a", "c"] {
            distinct.accumulate(Some(&Value::String(text.to_string())));
        }

        assert_eq!(distinct.finalize(), Value::Integer(3));
    }

    #[test]
    fn test_group_by() {
        let rows: [(&str, i64); 4] = [
            ("Sales", 50000),
            ("Sales", 60000),
            ("Engineering", 80000),
            ("Engineering", 90000),
        ];
        let bindings: Vec<Binding> = rows
            .iter()
            .map(|&(dept, salary)| {
                make_binding(&[
                    ("dept", Value::String(dept.to_string())),
                    ("salary", Value::Integer(salary)),
                ])
            })
            .collect();

        let aggs = vec![
            AggregationDef {
                source_var: Var::new("salary"),
                result_var: Var::new("total"),
                aggregator: Box::new(SumAggregator::new()),
            },
            AggregationDef {
                source_var: Var::new("salary"),
                result_var: Var::new("count"),
                aggregator: Box::new(CountAggregator::count_all()),
            },
        ];

        let results = execute_group_by(bindings, &[Var::new("dept")], &aggs);
        assert_eq!(results.len(), 2);

        let dept = Var::new("dept");
        let sales = results
            .iter()
            .find(|b| b.get(&dept) == Some(&Value::String("Sales".to_string())))
            .expect("Sales group not found");

        assert_eq!(sales.get(&Var::new("total")), Some(&Value::Integer(110000)));
        assert_eq!(sales.get(&Var::new("count")), Some(&Value::Integer(2)));
    }

    #[test]
    fn test_percentile() {
        let mut p50 = PercentileAggregator::median();
        for n in 1..=9 {
            p50.accumulate(Some(&Value::Integer(n)));
        }

        assert_eq!(p50.finalize(), Value::Float(5.0));
    }
}