// nodedb_cluster/distributed_document/partial_group.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct PartialGroup {
19 pub group_key: String,
21 pub count: u64,
23 pub columns: HashMap<String, PartialColumnAgg>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PartialColumnAgg {
30 pub sum: f64,
31 pub count: u64,
32 pub min: f64,
33 pub max: f64,
34}
35
36impl PartialColumnAgg {
37 pub fn merge(&mut self, other: &PartialColumnAgg) {
38 self.sum += other.sum;
39 self.count += other.count;
40 if other.min < self.min {
41 self.min = other.min;
42 }
43 if other.max > self.max {
44 self.max = other.max;
45 }
46 }
47
48 pub fn avg(&self) -> f64 {
49 if self.count == 0 {
50 f64::NAN
51 } else {
52 self.sum / self.count as f64
53 }
54 }
55}
56
57pub struct PartialGroupByMerger {
59 groups: HashMap<String, PartialGroup>,
61}
62
63impl PartialGroupByMerger {
64 pub fn new() -> Self {
65 Self {
66 groups: HashMap::new(),
67 }
68 }
69
70 pub fn add_shard_results(&mut self, partials: &[PartialGroup]) {
72 for partial in partials {
73 let entry = self
74 .groups
75 .entry(partial.group_key.clone())
76 .or_insert_with(|| PartialGroup {
77 group_key: partial.group_key.clone(),
78 count: 0,
79 columns: HashMap::new(),
80 });
81
82 entry.count += partial.count;
83
84 for (col_name, col_agg) in &partial.columns {
85 entry
86 .columns
87 .entry(col_name.clone())
88 .and_modify(|existing| existing.merge(col_agg))
89 .or_insert_with(|| col_agg.clone());
90 }
91 }
92 }
93
94 pub fn finalize(&self) -> Vec<&PartialGroup> {
96 self.groups.values().collect()
97 }
98
99 pub fn group_count(&self) -> usize {
101 self.groups.len()
102 }
103}
104
105impl Default for PartialGroupByMerger {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Two shards reporting the same group must yield one merged group
    /// with additive count/sum and the widened min/max range.
    #[test]
    fn merge_two_shards() {
        let mut merger = PartialGroupByMerger::new();

        merger.add_shard_results(&[PartialGroup {
            group_key: "active".into(),
            count: 50,
            columns: HashMap::from([(
                "age".into(),
                PartialColumnAgg {
                    sum: 1500.0,
                    count: 50,
                    min: 18.0,
                    max: 65.0,
                },
            )]),
        }]);

        merger.add_shard_results(&[PartialGroup {
            group_key: "active".into(),
            count: 80,
            columns: HashMap::from([(
                "age".into(),
                PartialColumnAgg {
                    sum: 2800.0,
                    count: 80,
                    min: 20.0,
                    max: 70.0,
                },
            )]),
        }]);

        let results = merger.finalize();
        assert_eq!(results.len(), 1);
        let active = &results[0];
        assert_eq!(active.count, 130);
        let age = &active.columns["age"];
        assert_eq!(age.count, 130);
        assert_eq!(age.sum, 4300.0);
        // 4300 / 130 ≈ 33.0769…; tolerance covers the rounded expectation.
        assert!((age.avg() - 33.08).abs() < 0.1);
        assert_eq!(age.min, 18.0);
        assert_eq!(age.max, 70.0);
    }

    /// Distinct group keys must stay separate while same-key partials
    /// from different shards accumulate.
    #[test]
    fn merge_multiple_groups() {
        let mut merger = PartialGroupByMerger::new();

        merger.add_shard_results(&[
            PartialGroup {
                group_key: "active".into(),
                count: 50,
                columns: HashMap::new(),
            },
            PartialGroup {
                group_key: "inactive".into(),
                count: 10,
                columns: HashMap::new(),
            },
        ]);
        merger.add_shard_results(&[PartialGroup {
            group_key: "active".into(),
            count: 30,
            columns: HashMap::new(),
        }]);

        assert_eq!(merger.group_count(), 2);
        let active = merger.groups.get("active").unwrap();
        assert_eq!(active.count, 80);
        let inactive = merger.groups.get("inactive").unwrap();
        assert_eq!(inactive.count, 10);
    }
}