nodedb_cluster/distributed_array/
merge.rs1use serde::{Deserialize, Serialize};
19
20use super::wire::{ArrayShardAggResp, ArrayShardSliceResp};
21
/// Mergeable per-group aggregate state produced by a single shard.
///
/// Carries count/sum/min/max plus the Welford running mean and M2 so the
/// coordinator can combine partials from every shard and still recover the
/// mean and variance exactly (see `ArrayAggPartial::merge`).
#[derive(
    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct ArrayAggPartial {
    /// Group-by key this partial belongs to.
    pub group_key: i64,
    /// Number of values folded into this partial.
    pub count: u64,
    /// Sum of all values in the group so far.
    pub sum: f64,
    /// Smallest value seen.
    pub min: f64,
    /// Largest value seen.
    pub max: f64,
    /// Running mean in Welford form (numerically stable).
    pub welford_mean: f64,
    /// Sum of squared deviations from the running mean (Welford M2).
    pub welford_m2: f64,
}
40
41impl ArrayAggPartial {
42 pub fn from_single(group_key: i64, val: f64) -> Self {
44 Self {
45 group_key,
46 count: 1,
47 sum: val,
48 min: val,
49 max: val,
50 welford_mean: val,
51 welford_m2: 0.0,
52 }
53 }
54
55 pub fn merge(&mut self, other: &ArrayAggPartial) {
57 if other.count == 0 {
58 return;
59 }
60 if self.count == 0 {
61 *self = other.clone();
62 return;
63 }
64 self.sum += other.sum;
65 if other.min < self.min {
66 self.min = other.min;
67 }
68 if other.max > self.max {
69 self.max = other.max;
70 }
71 let new_count = self.count + other.count;
72 let delta = other.welford_mean - self.welford_mean;
73 let combined_mean = (self.welford_mean * self.count as f64
74 + other.welford_mean * other.count as f64)
75 / new_count as f64;
76 let combined_m2 = self.welford_m2
77 + other.welford_m2
78 + delta * delta * (self.count as f64 * other.count as f64) / new_count as f64;
79 self.welford_mean = combined_mean;
80 self.welford_m2 = combined_m2;
81 self.count = new_count;
82 }
83}
84
85pub fn any_truncated_before_horizon_slice(shard_resps: &[ArrayShardSliceResp]) -> bool {
91 shard_resps.iter().any(|r| r.truncated_before_horizon)
92}
93
94pub fn any_truncated_before_horizon_agg(shard_resps: &[ArrayShardAggResp]) -> bool {
97 shard_resps.iter().any(|r| r.truncated_before_horizon)
98}
99
100pub fn merge_slice_rows(
110 shard_resps: &[ArrayShardSliceResp],
111 coordinator_limit: u32,
112) -> Vec<Vec<u8>> {
113 let total: usize = shard_resps.iter().map(|r| r.rows_msgpack.len()).sum();
114 let cap = if coordinator_limit > 0 {
115 total.min(coordinator_limit as usize)
116 } else {
117 total
118 };
119 let mut merged = Vec::with_capacity(cap);
120 'outer: for resp in shard_resps {
121 for row in &resp.rows_msgpack {
122 if coordinator_limit > 0 && merged.len() >= coordinator_limit as usize {
123 break 'outer;
124 }
125 merged.push(row.clone());
126 }
127 }
128 merged
129}
130
131pub fn reduce_agg_partials(shard_resps: &[ArrayShardAggResp]) -> Vec<ArrayAggPartial> {
135 use std::collections::BTreeMap;
136 let mut buckets: BTreeMap<i64, ArrayAggPartial> = BTreeMap::new();
137 for resp in shard_resps {
138 for partial in &resp.partials {
139 buckets
140 .entry(partial.group_key)
141 .and_modify(|existing| existing.merge(partial))
142 .or_insert_with(|| partial.clone());
143 }
144 }
145 buckets.into_values().collect()
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    // Two shards, same group key: counts and sums combine additively.
    #[test]
    fn reduce_sum_across_shards() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 10.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(0, 20.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].count, 2);
        assert!((merged[0].sum - 30.0).abs() < f64::EPSILON);
    }

    // Distinct group keys stay separate buckets in the reduction.
    #[test]
    fn reduce_separate_group_keys() {
        let resp = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![
                ArrayAggPartial::from_single(0, 5.0),
                ArrayAggPartial::from_single(1, 15.0),
            ],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp]);
        assert_eq!(merged.len(), 2);
    }

    // Merging a count==0 partial must leave the accumulator unchanged.
    #[test]
    fn merge_empty_partial_is_noop() {
        let mut a = ArrayAggPartial::from_single(0, 42.0);
        let empty = ArrayAggPartial {
            count: 0,
            ..ArrayAggPartial::from_single(0, 0.0)
        };
        a.merge(&empty);
        assert_eq!(a.count, 1);
        assert!((a.sum - 42.0).abs() < f64::EPSILON);
    }

    // With no limit (0), all rows from every shard are concatenated.
    #[test]
    fn merge_slice_rows_concatenates() {
        let r0 = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![vec![1u8], vec![2u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let r1 = ArrayShardSliceResp {
            shard_id: 1,
            rows_msgpack: vec![vec![3u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let rows = merge_slice_rows(&[r0, r1], 0);
        assert_eq!(rows.len(), 3);
    }

    // A positive limit truncates output and preserves input (prefix) order.
    #[test]
    fn merge_slice_rows_applies_coordinator_limit() {
        let resp = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![vec![1u8], vec![2u8], vec![3u8], vec![4u8], vec![5u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let rows = merge_slice_rows(&[resp], 3);
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1u8]);
        assert_eq!(rows[2], vec![3u8]);
    }

    // min takes the smaller value regardless of which shard supplied it.
    #[test]
    fn reduce_min_across_shards() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 5.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(0, 3.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 1);
        assert!((merged[0].min - 3.0).abs() < f64::EPSILON);
    }

    // max takes the larger value regardless of which shard supplied it.
    #[test]
    fn reduce_max_across_shards() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 5.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(0, 99.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 1);
        assert!((merged[0].max - 99.0).abs() < f64::EPSILON);
    }

    // The pairwise Welford merge of single-value partials gives the exact mean.
    #[test]
    fn reduce_avg_welford_merge_exact() {
        let mut a = ArrayAggPartial::from_single(0, 10.0);
        let b = ArrayAggPartial::from_single(0, 20.0);
        a.merge(&b);
        assert!((a.welford_mean - 15.0).abs() < 1e-9);
        assert_eq!(a.count, 2);
        assert!((a.sum - 30.0).abs() < f64::EPSILON);
    }

    // Overlapping key (1) merges across shards; keys 0 and 2 pass through.
    #[test]
    fn reduce_grouped_overlapping_keys() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![
                ArrayAggPartial::from_single(0, 5.0),
                ArrayAggPartial::from_single(1, 10.0),
            ],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![
                ArrayAggPartial::from_single(1, 20.0),
                ArrayAggPartial::from_single(2, 30.0),
            ],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 3);
        let g0 = merged.iter().find(|p| p.group_key == 0).unwrap();
        let g1 = merged.iter().find(|p| p.group_key == 1).unwrap();
        let g2 = merged.iter().find(|p| p.group_key == 2).unwrap();
        assert!((g0.sum - 5.0).abs() < f64::EPSILON);
        assert!((g1.sum - 30.0).abs() < f64::EPSILON);
        assert!((g2.sum - 30.0).abs() < f64::EPSILON);
    }

    // The truncation flag ORs across shards, for both slice and agg responses.
    #[test]
    fn truncated_before_horizon_or_combines_across_shards() {
        let r0 = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![],
            truncated: false,
            truncated_before_horizon: true,
        };
        let r1 = ArrayShardSliceResp {
            shard_id: 1,
            rows_msgpack: vec![vec![1u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        assert!(any_truncated_before_horizon_slice(&[r0, r1]));

        let a0 = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![],
            truncated_before_horizon: false,
        };
        let a1 = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![],
            truncated_before_horizon: true,
        };
        assert!(any_truncated_before_horizon_agg(&[a0, a1]));

        let a_none = ArrayShardAggResp {
            shard_id: 2,
            partials: vec![],
            truncated_before_horizon: false,
        };
        assert!(!any_truncated_before_horizon_agg(&[a_none]));
    }

    // Disjoint keys from different shards each come through with count 1.
    #[test]
    fn reduce_grouped_disjoint_keys() {
        let resp_a = ArrayShardAggResp {
            shard_id: 0,
            partials: vec![ArrayAggPartial::from_single(0, 7.0)],
            truncated_before_horizon: false,
        };
        let resp_b = ArrayShardAggResp {
            shard_id: 1,
            partials: vec![ArrayAggPartial::from_single(1, 13.0)],
            truncated_before_horizon: false,
        };
        let merged = reduce_agg_partials(&[resp_a, resp_b]);
        assert_eq!(merged.len(), 2);
        let g0 = merged.iter().find(|p| p.group_key == 0).unwrap();
        let g1 = merged.iter().find(|p| p.group_key == 1).unwrap();
        assert_eq!(g0.count, 1);
        assert_eq!(g1.count, 1);
    }

    // The limit applies to the concatenated total, cutting mid-shard.
    #[test]
    fn merge_slice_rows_limit_across_shards() {
        let r0 = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: vec![vec![1u8], vec![2u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let r1 = ArrayShardSliceResp {
            shard_id: 1,
            rows_msgpack: vec![vec![3u8], vec![4u8]],
            truncated: false,
            truncated_before_horizon: false,
        };
        let rows = merge_slice_rows(&[r0, r1], 3);
        assert_eq!(rows.len(), 3);
    }

    // coordinator_limit == 0 is the "unlimited" sentinel, not "return nothing".
    #[test]
    fn merge_slice_rows_zero_limit_is_unlimited() {
        let resp = ArrayShardSliceResp {
            shard_id: 0,
            rows_msgpack: (0u8..20).map(|i| vec![i]).collect(),
            truncated: false,
            truncated_before_horizon: false,
        };
        let rows = merge_slice_rows(&[resp], 0);
        assert_eq!(rows.len(), 20);
    }
}