nodedb_cluster/distributed_document/
partial_group.rs1use std::collections::HashMap;
11
12use serde::{Deserialize, Serialize};
13
/// A per-group partial aggregate produced by a single shard for one
/// GROUP BY bucket; merged across shards by `PartialGroupByMerger`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialGroup {
    /// The grouping key identifying this bucket.
    pub group_key: String,
    /// Number of rows this partial contributes to the group.
    pub count: u64,
    /// Per-column partial aggregates, keyed by column name.
    pub columns: HashMap<String, PartialColumnAgg>,
}
24
/// Partial numeric aggregates (sum/count/min/max) for one column within
/// one group. Designed to be mergeable across shards via `merge`, after
/// which `avg` yields the global mean.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialColumnAgg {
    /// Running sum of the column's values.
    pub sum: f64,
    /// Number of values folded into `sum` (denominator for `avg`).
    pub count: u64,
    /// Smallest value seen so far.
    pub min: f64,
    /// Largest value seen so far.
    pub max: f64,
}
33
34impl PartialColumnAgg {
35 pub fn merge(&mut self, other: &PartialColumnAgg) {
36 self.sum += other.sum;
37 self.count += other.count;
38 if other.min < self.min {
39 self.min = other.min;
40 }
41 if other.max > self.max {
42 self.max = other.max;
43 }
44 }
45
46 pub fn avg(&self) -> f64 {
47 if self.count == 0 {
48 f64::NAN
49 } else {
50 self.sum / self.count as f64
51 }
52 }
53}
54
/// Accumulates per-shard `PartialGroup` results into a single, fully
/// merged set of groups — the coordinator-side half of a distributed
/// GROUP BY.
pub struct PartialGroupByMerger {
    // Merged state so far, keyed by group_key.
    groups: HashMap<String, PartialGroup>,
}
60
61impl PartialGroupByMerger {
62 pub fn new() -> Self {
63 Self {
64 groups: HashMap::new(),
65 }
66 }
67
68 pub fn add_shard_results(&mut self, partials: &[PartialGroup]) {
70 for partial in partials {
71 let entry = self
72 .groups
73 .entry(partial.group_key.clone())
74 .or_insert_with(|| PartialGroup {
75 group_key: partial.group_key.clone(),
76 count: 0,
77 columns: HashMap::new(),
78 });
79
80 entry.count += partial.count;
81
82 for (col_name, col_agg) in &partial.columns {
83 entry
84 .columns
85 .entry(col_name.clone())
86 .and_modify(|existing| existing.merge(col_agg))
87 .or_insert_with(|| col_agg.clone());
88 }
89 }
90 }
91
92 pub fn finalize(&self) -> Vec<&PartialGroup> {
94 self.groups.values().collect()
95 }
96
97 pub fn group_count(&self) -> usize {
99 self.groups.len()
100 }
101}
102
103impl Default for PartialGroupByMerger {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience constructor for a shard-level partial group.
    fn shard_group(
        key: &str,
        count: u64,
        columns: HashMap<String, PartialColumnAgg>,
    ) -> PartialGroup {
        PartialGroup {
            group_key: key.into(),
            count,
            columns,
        }
    }

    #[test]
    fn merge_two_shards() {
        let mut merger = PartialGroupByMerger::new();

        let first_age = PartialColumnAgg {
            sum: 1500.0,
            count: 50,
            min: 18.0,
            max: 65.0,
        };
        merger.add_shard_results(&[shard_group(
            "active",
            50,
            HashMap::from([("age".to_string(), first_age)]),
        )]);

        let second_age = PartialColumnAgg {
            sum: 2800.0,
            count: 80,
            min: 20.0,
            max: 70.0,
        };
        merger.add_shard_results(&[shard_group(
            "active",
            80,
            HashMap::from([("age".to_string(), second_age)]),
        )]);

        let results = merger.finalize();
        assert_eq!(results.len(), 1);

        let active = &results[0];
        assert_eq!(active.count, 130);

        let age = &active.columns["age"];
        assert_eq!(age.count, 130);
        assert_eq!(age.sum, 4300.0);
        assert!((age.avg() - 33.08).abs() < 0.1);
        assert_eq!(age.min, 18.0);
        assert_eq!(age.max, 70.0);
    }

    #[test]
    fn merge_multiple_groups() {
        let mut merger = PartialGroupByMerger::new();

        merger.add_shard_results(&[
            shard_group("active", 50, HashMap::new()),
            shard_group("inactive", 10, HashMap::new()),
        ]);
        merger.add_shard_results(&[shard_group("active", 30, HashMap::new())]);

        assert_eq!(merger.group_count(), 2);
        assert_eq!(merger.groups.get("active").unwrap().count, 80);
        assert_eq!(merger.groups.get("inactive").unwrap().count, 10);
    }
}