Skip to main content

nodedb_cluster/distributed_document/
partial_group.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Distributed GROUP BY for document queries.
4//!
5//! Each shard computes local partial aggregates per group key. The
6//! coordinator merges partials across shards to produce the global result.
7//!
8//! Example: `SELECT status, COUNT(*), AVG(age) FROM users GROUP BY status`
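//! - Each shard returns: `[("active", count=50, sum_age=1500, count_age=50), ...]`
//! - Coordinator merges: `("active", count=150, avg_age=sum_age/count_age)`, where
//!   `sum_age` and `count_age` are the totals added up across all shards (the
//!   average is derived from the totals, not by averaging per-shard averages).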
11
12use std::collections::HashMap;
13
14use serde::{Deserialize, Serialize};
15
16/// A partial aggregate for one group key from one shard.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct PartialGroup {
19    /// The group key as a JSON string (e.g., `"active"` or `["us-east","web"]`).
20    pub group_key: String,
21    /// COUNT(*) partial.
22    pub count: u64,
23    /// Per-column partial aggregates: column_name → PartialColumnAgg.
24    pub columns: HashMap<String, PartialColumnAgg>,
25}
26
27/// Partial aggregate state for a single column within a group.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PartialColumnAgg {
30    pub sum: f64,
31    pub count: u64,
32    pub min: f64,
33    pub max: f64,
34}
35
36impl PartialColumnAgg {
37    pub fn merge(&mut self, other: &PartialColumnAgg) {
38        self.sum += other.sum;
39        self.count += other.count;
40        if other.min < self.min {
41            self.min = other.min;
42        }
43        if other.max > self.max {
44            self.max = other.max;
45        }
46    }
47
48    pub fn avg(&self) -> f64 {
49        if self.count == 0 {
50            f64::NAN
51        } else {
52            self.sum / self.count as f64
53        }
54    }
55}
56
57/// Merger for distributed GROUP BY results.
58pub struct PartialGroupByMerger {
59    /// group_key → merged partial.
60    groups: HashMap<String, PartialGroup>,
61}
62
63impl PartialGroupByMerger {
64    pub fn new() -> Self {
65        Self {
66            groups: HashMap::new(),
67        }
68    }
69
70    /// Add a shard's partial GROUP BY results.
71    pub fn add_shard_results(&mut self, partials: &[PartialGroup]) {
72        for partial in partials {
73            let entry = self
74                .groups
75                .entry(partial.group_key.clone())
76                .or_insert_with(|| PartialGroup {
77                    group_key: partial.group_key.clone(),
78                    count: 0,
79                    columns: HashMap::new(),
80                });
81
82            entry.count += partial.count;
83
84            for (col_name, col_agg) in &partial.columns {
85                entry
86                    .columns
87                    .entry(col_name.clone())
88                    .and_modify(|existing| existing.merge(col_agg))
89                    .or_insert_with(|| col_agg.clone());
90            }
91        }
92    }
93
94    /// Get the merged results.
95    pub fn finalize(&self) -> Vec<&PartialGroup> {
96        self.groups.values().collect()
97    }
98
99    /// Number of distinct groups.
100    pub fn group_count(&self) -> usize {
101        self.groups.len()
102    }
103}
104
105impl Default for PartialGroupByMerger {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a partial for `key` carrying a single `age` column aggregate.
    fn with_age(key: &str, rows: u64, sum: f64, min: f64, max: f64) -> PartialGroup {
        let age = PartialColumnAgg {
            sum,
            count: rows,
            min,
            max,
        };
        PartialGroup {
            group_key: key.into(),
            count: rows,
            columns: HashMap::from([("age".into(), age)]),
        }
    }

    /// Builds a partial for `key` that only carries a row count.
    fn count_only(key: &str, rows: u64) -> PartialGroup {
        PartialGroup {
            group_key: key.into(),
            count: rows,
            columns: HashMap::new(),
        }
    }

    #[test]
    fn merge_two_shards() {
        let mut merger = PartialGroupByMerger::new();

        // Shard 0: status="active" has 50 users, avg age ~30.
        let shard0 = [with_age("active", 50, 1500.0, 18.0, 65.0)];
        // Shard 1: status="active" has 80 users, avg age ~35.
        let shard1 = [with_age("active", 80, 2800.0, 20.0, 70.0)];

        merger.add_shard_results(&shard0);
        merger.add_shard_results(&shard1);

        let results = merger.finalize();
        assert_eq!(results.len(), 1);

        let active = results[0];
        assert_eq!(active.count, 130);

        let age = &active.columns["age"];
        assert_eq!(age.count, 130);
        assert_eq!(age.sum, 4300.0);
        assert!((age.avg() - 33.08).abs() < 0.1); // 4300/130 ≈ 33.08
        assert_eq!(age.min, 18.0);
        assert_eq!(age.max, 70.0);
    }

    #[test]
    fn merge_multiple_groups() {
        let mut merger = PartialGroupByMerger::new();

        merger.add_shard_results(&[count_only("active", 50), count_only("inactive", 10)]);
        merger.add_shard_results(&[count_only("active", 30)]);

        assert_eq!(merger.group_count(), 2);

        let count_of = |key: &str| merger.groups.get(key).unwrap().count;
        assert_eq!(count_of("active"), 80);
        assert_eq!(count_of("inactive"), 10);
    }
}