Skip to main content

nodedb_cluster/distributed_array/
merge.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Merge functions for distributed array query results.
4//!
5//! Slice merges concatenate row sets from shards in arrival order and
6//! apply the coordinator-side limit. The output preserves each shard's
7//! intra-shard order; cross-shard ordering reflects arrival, since the
8//! wire response (`ArrayShardSliceResp::rows_msgpack`) is a flat opaque
9//! `Vec<Vec<u8>>` with no per-row sort key. A globally Hilbert-ordered
10//! merge would require carrying a parallel prefix column on the wire and
11//! a k-way merge here — that is a wire-format change, not a merger
12//! change, and lives outside this module.
13//!
14//! Aggregate merges combine per-shard partial aggregates using
15//! reducer-specific arithmetic (SUM/COUNT/MIN/MAX — same Welford
16//! technique as the timeseries merger).
17
18use serde::{Deserialize, Serialize};
19
20use super::wire::{ArrayShardAggResp, ArrayShardSliceResp};
21
/// Partial aggregate contributed by a single shard for one group-by bucket.
///
/// Carries enough state for all supported reducers (SUM, COUNT, MIN, MAX,
/// MEAN). Welford fields enable variance/stddev if a future reducer needs it.
#[derive(
    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct ArrayAggPartial {
    /// Group-by dimension value, or 0 when group_by_dim < 0 (scalar aggregate).
    pub group_key: i64,
    /// Number of cells folded into this partial (COUNT reducer). A value of
    /// 0 marks an empty partial, which `merge` treats as a no-op side.
    pub count: u64,
    /// Running sum of cell values (SUM reducer).
    pub sum: f64,
    /// Smallest cell value observed so far (MIN reducer).
    pub min: f64,
    /// Largest cell value observed so far (MAX reducer).
    pub max: f64,
    /// Welford mean — enables MEAN without a second pass.
    pub welford_mean: f64,
    /// Welford M2 (sum of squared deviations from the mean) — combined with
    /// `count`, yields variance/stddev if a future reducer needs them.
    pub welford_m2: f64,
}
40
41impl ArrayAggPartial {
42    /// Create from a single cell value.
43    pub fn from_single(group_key: i64, val: f64) -> Self {
44        Self {
45            group_key,
46            count: 1,
47            sum: val,
48            min: val,
49            max: val,
50            welford_mean: val,
51            welford_m2: 0.0,
52        }
53    }
54
55    /// Merge another partial into this one using parallel Welford.
56    pub fn merge(&mut self, other: &ArrayAggPartial) {
57        if other.count == 0 {
58            return;
59        }
60        if self.count == 0 {
61            *self = other.clone();
62            return;
63        }
64        self.sum += other.sum;
65        if other.min < self.min {
66            self.min = other.min;
67        }
68        if other.max > self.max {
69            self.max = other.max;
70        }
71        let new_count = self.count + other.count;
72        let delta = other.welford_mean - self.welford_mean;
73        let combined_mean = (self.welford_mean * self.count as f64
74            + other.welford_mean * other.count as f64)
75            / new_count as f64;
76        let combined_m2 = self.welford_m2
77            + other.welford_m2
78            + delta * delta * (self.count as f64 * other.count as f64) / new_count as f64;
79        self.welford_mean = combined_mean;
80        self.welford_m2 = combined_m2;
81        self.count = new_count;
82    }
83}
84
85/// Returns `true` if any shard reported that `system_as_of` fell below its
86/// oldest tile version and it produced zero rows as a result.
87///
88/// Callers combine this flag via logical OR across shards to propagate the
89/// below-horizon signal to the upstream coordinator response.
90pub fn any_truncated_before_horizon_slice(shard_resps: &[ArrayShardSliceResp]) -> bool {
91    shard_resps.iter().any(|r| r.truncated_before_horizon)
92}
93
94/// Returns `true` if any shard reported that `system_as_of` fell below its
95/// oldest tile version, causing the shard to contribute zero partials.
96pub fn any_truncated_before_horizon_agg(shard_resps: &[ArrayShardAggResp]) -> bool {
97    shard_resps.iter().any(|r| r.truncated_before_horizon)
98}
99
100/// Merge row batches from multiple shards into one result set.
101///
102/// Rows are concatenated in shard-arrival order (order-independent for
103/// an unsorted slice). If `coordinator_limit > 0` the merged list is
104/// truncated to at most `coordinator_limit` rows after concatenation —
105/// this is the final cut-off after shards have already applied their own
106/// per-shard limit via `ArrayShardSliceReq::limit`.
107///
108/// Pass `coordinator_limit = 0` to return all rows without truncation.
109pub fn merge_slice_rows(
110    shard_resps: &[ArrayShardSliceResp],
111    coordinator_limit: u32,
112) -> Vec<Vec<u8>> {
113    let total: usize = shard_resps.iter().map(|r| r.rows_msgpack.len()).sum();
114    let cap = if coordinator_limit > 0 {
115        total.min(coordinator_limit as usize)
116    } else {
117        total
118    };
119    let mut merged = Vec::with_capacity(cap);
120    'outer: for resp in shard_resps {
121        for row in &resp.rows_msgpack {
122            if coordinator_limit > 0 && merged.len() >= coordinator_limit as usize {
123                break 'outer;
124            }
125            merged.push(row.clone());
126        }
127    }
128    merged
129}
130
131/// Merge per-shard partial aggregates into one result per group-by key.
132///
133/// Groups by `group_key`; uses `ArrayAggPartial::merge` for each group.
134pub fn reduce_agg_partials(shard_resps: &[ArrayShardAggResp]) -> Vec<ArrayAggPartial> {
135    use std::collections::BTreeMap;
136    let mut buckets: BTreeMap<i64, ArrayAggPartial> = BTreeMap::new();
137    for resp in shard_resps {
138        for partial in &resp.partials {
139            buckets
140                .entry(partial.group_key)
141                .and_modify(|existing| existing.merge(partial))
142                .or_insert_with(|| partial.clone());
143        }
144    }
145    buckets.into_values().collect()
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// SUM and COUNT combine when two shards contribute to the same group.
    #[test]
    fn reduce_sum_across_shards() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 10.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(0, 20.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].count, 2);
        assert!((merged[0].sum - 30.0).abs() < f64::EPSILON);
    }

    /// Distinct group keys from one shard stay in distinct output buckets.
    #[test]
    fn reduce_separate_group_keys() {
        let resp = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![
                ArrayAggPartial::from_single(0, 5.0),
                ArrayAggPartial::from_single(1, 15.0),
            ],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp]);
        assert_eq!(merged.len(), 2);
    }

    /// Merging a zero-count partial must leave the accumulator untouched,
    /// even though the empty partial carries nonzero-looking float fields.
    #[test]
    fn merge_empty_partial_is_noop() {
        let mut a = ArrayAggPartial::from_single(0, 42.0);
        let empty = ArrayAggPartial {
            count: 0,
            ..ArrayAggPartial::from_single(0, 0.0)
        };
        a.merge(&empty);
        assert_eq!(a.count, 1);
        assert!((a.sum - 42.0).abs() < f64::EPSILON);
    }

    /// With no limit, rows from all shards are concatenated in full.
    #[test]
    fn merge_slice_rows_concatenates() {
        let r0 = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![vec![1u8], vec![2u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let r1 = ArrayShardSliceResp {
            shard_id: 1,
            rows_msgpack: vec![vec![3u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let rows = merge_slice_rows(&[r0, r1], 0);
        assert_eq!(rows.len(), 3);
    }

    /// The coordinator limit truncates after concatenation, keeping the
    /// earliest rows in shard-arrival order.
    #[test]
    fn merge_slice_rows_applies_coordinator_limit() {
        let resp = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![vec![1u8], vec![2u8], vec![3u8], vec![4u8], vec![5u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let rows = merge_slice_rows(&[resp], 3);
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1u8]);
        assert_eq!(rows[2], vec![3u8]);
    }

    /// MIN picks the smaller of the two shard minima.
    #[test]
    fn reduce_min_across_shards() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 5.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(0, 3.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 1);
        assert!((merged[0].min - 3.0).abs() < f64::EPSILON);
    }

    /// MAX picks the larger of the two shard maxima.
    #[test]
    fn reduce_max_across_shards() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 5.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(0, 99.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 1);
        assert!((merged[0].max - 99.0).abs() < f64::EPSILON);
    }

    /// Parallel Welford merge of two single-value partials yields the exact
    /// combined mean (no second pass needed).
    #[test]
    fn reduce_avg_welford_merge_exact() {
        // Two shards: shard A has one value of 10, shard B has one value of 20.
        // Combined mean should be exactly 15.
        let mut a = ArrayAggPartial::from_single(0, 10.0);
        let b = ArrayAggPartial::from_single(0, 20.0);
        a.merge(&b);
        // welford_mean after merge = 15.0
        assert!((a.welford_mean - 15.0).abs() < 1e-9);
        assert_eq!(a.count, 2);
        assert!((a.sum - 30.0).abs() < f64::EPSILON);
    }

    /// Mixed case: overlapping keys merge, shard-unique keys pass through.
    #[test]
    fn reduce_grouped_overlapping_keys() {
        // Shard A: groups 0→5, 1→10. Shard B: groups 1→20, 2→30.
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![
                ArrayAggPartial::from_single(0, 5.0),
                ArrayAggPartial::from_single(1, 10.0),
            ],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![
                ArrayAggPartial::from_single(1, 20.0),
                ArrayAggPartial::from_single(2, 30.0),
            ],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 3);
        let g0 = merged.iter().find(|p| p.group_key == 0).unwrap();
        let g1 = merged.iter().find(|p| p.group_key == 1).unwrap();
        let g2 = merged.iter().find(|p| p.group_key == 2).unwrap();
        assert!((g0.sum - 5.0).abs() < f64::EPSILON);
        assert!((g1.sum - 30.0).abs() < f64::EPSILON);
        assert!((g2.sum - 30.0).abs() < f64::EPSILON);
    }

    /// The below-horizon flag OR-combines across shards for both the slice
    /// and the aggregate response forms; all-false stays false.
    #[test]
    fn truncated_before_horizon_or_combines_across_shards() {
        let r0 = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![],
            truncated: false,
            truncated_before_horizon: true,
        };
        let r1 = ArrayShardSliceResp {
            shard_id: 1,
            rows_msgpack: vec![vec![1u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        assert!(any_truncated_before_horizon_slice(&[r0, r1]));

        let a0 = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![],
            truncated_before_horizon: false,
        };
        let a1 = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![],
            truncated_before_horizon: true,
        };
        assert!(any_truncated_before_horizon_agg(&[a0, a1]));

        let a_none = ArrayShardAggResp {
            shard_id: 2,
            partials: vec![],
            truncated_before_horizon: false,
        };
        assert!(!any_truncated_before_horizon_agg(&[a_none]));
    }

    /// Shards with completely disjoint keys each keep their own bucket.
    #[test]
    fn reduce_grouped_disjoint_keys() {
        // Shard A has only group 0; shard B has only group 1 — no overlap.
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 7.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(1, 13.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 2);
        let g0 = merged.iter().find(|p| p.group_key == 0).unwrap();
        let g1 = merged.iter().find(|p| p.group_key == 1).unwrap();
        assert_eq!(g0.count, 1);
        assert_eq!(g1.count, 1);
    }

    /// The limit can cut mid-way through a later shard's rows.
    #[test]
    fn merge_slice_rows_limit_across_shards() {
        let r0 = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![vec![1u8], vec![2u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let r1 = ArrayShardSliceResp {
            shard_id: 1,
            rows_msgpack: vec![vec![3u8], vec![4u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        // Total 4 rows, limit 3 → first 3.
        let rows = merge_slice_rows(&[r0, r1], 3);
        assert_eq!(rows.len(), 3);
    }

    /// A coordinator limit of 0 means "no truncation at all".
    #[test]
    fn merge_slice_rows_zero_limit_is_unlimited() {
        let resp = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: (0u8..20).map(|i| vec![i]).collect(),
            truncated: false,
            truncated_before_horizon: false,
        };
        let rows = merge_slice_rows(&[resp], 0);
        assert_eq!(rows.len(), 20);
    }
}