Skip to main content

nodedb_cluster/distributed_document/
partial_group.rs

//! Distributed GROUP BY for document queries.
//!
//! Each shard computes local partial aggregates per group key. The
//! coordinator merges partials across shards to produce the global result.
//! Averages travel as separate `sum` and `count` partials, so the coordinator
//! computes the exact global mean rather than averaging per-shard averages
//! (which would be wrong whenever shards hold different numbers of rows).
//!
//! Example: `SELECT status, COUNT(*), AVG(age) FROM users GROUP BY status`
//! - Each shard returns: `[("active", count=50, sum_age=1500, count_age=50), ...]`
//! - Coordinator merges: `("active", count=150, avg_age=sum_ages/count_ages)`
9
10use std::collections::HashMap;
11
12use serde::{Deserialize, Serialize};
13
14/// A partial aggregate for one group key from one shard.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PartialGroup {
17    /// The group key as a JSON string (e.g., `"active"` or `["us-east","web"]`).
18    pub group_key: String,
19    /// COUNT(*) partial.
20    pub count: u64,
21    /// Per-column partial aggregates: column_name → PartialColumnAgg.
22    pub columns: HashMap<String, PartialColumnAgg>,
23}
24
25/// Partial aggregate state for a single column within a group.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct PartialColumnAgg {
28    pub sum: f64,
29    pub count: u64,
30    pub min: f64,
31    pub max: f64,
32}
33
34impl PartialColumnAgg {
35    pub fn merge(&mut self, other: &PartialColumnAgg) {
36        self.sum += other.sum;
37        self.count += other.count;
38        if other.min < self.min {
39            self.min = other.min;
40        }
41        if other.max > self.max {
42            self.max = other.max;
43        }
44    }
45
46    pub fn avg(&self) -> f64 {
47        if self.count == 0 {
48            f64::NAN
49        } else {
50            self.sum / self.count as f64
51        }
52    }
53}
54
55/// Merger for distributed GROUP BY results.
56pub struct PartialGroupByMerger {
57    /// group_key → merged partial.
58    groups: HashMap<String, PartialGroup>,
59}
60
61impl PartialGroupByMerger {
62    pub fn new() -> Self {
63        Self {
64            groups: HashMap::new(),
65        }
66    }
67
68    /// Add a shard's partial GROUP BY results.
69    pub fn add_shard_results(&mut self, partials: &[PartialGroup]) {
70        for partial in partials {
71            let entry = self
72                .groups
73                .entry(partial.group_key.clone())
74                .or_insert_with(|| PartialGroup {
75                    group_key: partial.group_key.clone(),
76                    count: 0,
77                    columns: HashMap::new(),
78                });
79
80            entry.count += partial.count;
81
82            for (col_name, col_agg) in &partial.columns {
83                entry
84                    .columns
85                    .entry(col_name.clone())
86                    .and_modify(|existing| existing.merge(col_agg))
87                    .or_insert_with(|| col_agg.clone());
88            }
89        }
90    }
91
92    /// Get the merged results.
93    pub fn finalize(&self) -> Vec<&PartialGroup> {
94        self.groups.values().collect()
95    }
96
97    /// Number of distinct groups.
98    pub fn group_count(&self) -> usize {
99        self.groups.len()
100    }
101}
102
103impl Default for PartialGroupByMerger {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single column's partial aggregate.
    fn col(sum: f64, count: u64, min: f64, max: f64) -> PartialColumnAgg {
        PartialColumnAgg {
            sum,
            count,
            min,
            max,
        }
    }

    /// Builds one group's partial from a key, a row count and named columns.
    fn group(key: &str, count: u64, columns: Vec<(&str, PartialColumnAgg)>) -> PartialGroup {
        PartialGroup {
            group_key: key.to_string(),
            count,
            columns: columns
                .into_iter()
                .map(|(name, agg)| (name.to_string(), agg))
                .collect(),
        }
    }

    #[test]
    fn merge_two_shards() {
        let mut merger = PartialGroupByMerger::new();

        // Shard 0: 50 "active" users, ages 18..=65, mean age 30.
        merger.add_shard_results(&[group(
            "active",
            50,
            vec![("age", col(1500.0, 50, 18.0, 65.0))],
        )]);
        // Shard 1: 80 "active" users, ages 20..=70, mean age 35.
        merger.add_shard_results(&[group(
            "active",
            80,
            vec![("age", col(2800.0, 80, 20.0, 70.0))],
        )]);

        let merged = merger.finalize();
        assert_eq!(merged.len(), 1);

        let active = merged[0];
        assert_eq!(active.count, 130);

        let age = &active.columns["age"];
        assert_eq!(age.count, 130);
        assert_eq!(age.sum, 4300.0);
        // Weighted mean 4300 / 130 ≈ 33.08, not the mean of the shard means.
        assert!((age.avg() - 33.08).abs() < 0.1);
        assert_eq!(age.min, 18.0);
        assert_eq!(age.max, 70.0);
    }

    #[test]
    fn merge_multiple_groups() {
        let mut merger = PartialGroupByMerger::new();

        merger.add_shard_results(&[group("active", 50, vec![]), group("inactive", 10, vec![])]);
        merger.add_shard_results(&[group("active", 30, vec![])]);

        assert_eq!(merger.group_count(), 2);
        assert_eq!(merger.groups["active"].count, 80);
        assert_eq!(merger.groups["inactive"].count, 10);
    }
}