nodedb_cluster/distributed_array/coordinator/mod.rs

// SPDX-License-Identifier: BUSL-1.1

pub mod read;
pub mod write;

pub use read::{ArrayCoordParams, ArrayCoordinator, CoordSliceResult};
pub use write::{ArrayWriteCoordParams, coord_delete, coord_put, coord_put_partitioned};
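
// Read-side usage sketch, reconstructed from the tests below; `dispatch`,
// `breaker`, and `req` are assumed to be built elsewhere:
//
//     let coord = ArrayCoordinator::new(
//         ArrayCoordParams {
//             source_node,
//             shard_ids,
//             timeout_ms,
//             prefix_bits,
//             slice_hilbert_ranges,
//         },
//         dispatch, // Arc<dyn ShardRpcDispatch>
//         breaker,  // Arc<CircuitBreaker>
//     );
//     let slice: CoordSliceResult = coord.coord_slice(req, coordinator_limit).await?;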

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;

    use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
    use crate::error::Result;
    use crate::wire::{VShardEnvelope, VShardMessageType};

    use super::super::merge::ArrayAggPartial;
    use super::super::rpc::ShardRpcDispatch;
    use super::super::wire::{
        ArrayShardAggReq, ArrayShardAggResp, ArrayShardSliceReq, ArrayShardSliceResp,
    };
    use super::read::{ArrayCoordParams, ArrayCoordinator};

    /// Mock dispatch that wraps pre-canned rows in an `ArrayShardSliceResp`.
    struct SliceEchoDispatch {
        /// Rows to return from each shard.
        rows: Vec<Vec<u8>>,
    }

    #[async_trait]
    impl ShardRpcDispatch for SliceEchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            let resp = ArrayShardSliceResp {
                shard_id: req.vshard_id,
                rows_msgpack: self.rows.clone(),
                truncated: false,
                truncated_before_horizon: false,
            };
            let payload = zerompk::to_msgpack_vec(&resp).unwrap();
            Ok(VShardEnvelope::new(
                VShardMessageType::ArrayShardSliceResp,
                req.target_node,
                req.source_node,
                req.vshard_id,
                payload,
            ))
        }
    }

    /// Mock dispatch that returns a pre-canned `ArrayShardAggResp`.
    struct AggEchoDispatch {
        partials: Vec<ArrayAggPartial>,
    }

    #[async_trait]
    impl ShardRpcDispatch for AggEchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            let resp = ArrayShardAggResp {
                shard_id: req.vshard_id,
                partials: self.partials.clone(),
                truncated_before_horizon: false,
            };
            let payload = zerompk::to_msgpack_vec(&resp).unwrap();
            Ok(VShardEnvelope::new(
                VShardMessageType::ArrayShardAggResp,
                req.target_node,
                req.source_node,
                req.vshard_id,
                payload,
            ))
        }
    }

    fn make_coordinator(
        shard_ids: Vec<u32>,
        dispatch: Arc<dyn ShardRpcDispatch>,
    ) -> ArrayCoordinator {
        ArrayCoordinator::new(
            ArrayCoordParams {
                source_node: 1,
                shard_ids,
                timeout_ms: 1000,
                // Tests use prefix_bits=0 so shard-side routing validation
                // is skipped — mock executors don't need to match Hilbert
                // ownership.
                prefix_bits: 0,
                slice_hilbert_ranges: vec![],
            },
            dispatch,
            Arc::new(CircuitBreaker::new(CircuitBreakerConfig::default())),
        )
    }

    #[tokio::test]
    async fn coord_slice_merges_rows_from_all_shards() {
        let row_a = zerompk::to_msgpack_vec(&"row-a").unwrap();
        let row_b = zerompk::to_msgpack_vec(&"row-b").unwrap();
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(SliceEchoDispatch {
            rows: vec![row_a.clone(), row_b.clone()],
        });
        let coord = make_coordinator(vec![0, 1, 2], dispatch);
        let req = ArrayShardSliceReq {
            array_id_msgpack: vec![],
            slice_msgpack: vec![],
            attr_projection: vec![],
            limit: 100,
            cell_filter_msgpack: vec![],
            prefix_bits: 0,
            slice_hilbert_ranges: vec![],
            shard_hilbert_range: None,
            system_as_of: None,
            valid_at_ms: None,
        };

        // 3 shards × 2 rows each = 6 merged rows.
        let result = coord
            .coord_slice(req, 0)
            .await
            .expect("coord_slice should succeed");
        assert_eq!(result.rows.len(), 6);
        assert!(!result.truncated_before_horizon);
    }

    #[tokio::test]
    async fn coord_slice_applies_coordinator_limit() {
        let row = zerompk::to_msgpack_vec(&"row").unwrap();
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(SliceEchoDispatch {
            rows: vec![row.clone(), row.clone(), row.clone()],
        });
        // 2 shards × 3 rows = 6 total, but the coordinator limit is 4.
        let coord = make_coordinator(vec![0, 1], dispatch);
        let req = ArrayShardSliceReq {
            array_id_msgpack: vec![],
            slice_msgpack: vec![],
            attr_projection: vec![],
            limit: 3,
            cell_filter_msgpack: vec![],
            prefix_bits: 0,
            slice_hilbert_ranges: vec![],
            shard_hilbert_range: None,
            system_as_of: None,
            valid_at_ms: None,
        };

        let result = coord
            .coord_slice(req, 4)
            .await
            .expect("coord_slice with limit should succeed");
        assert_eq!(result.rows.len(), 4);
    }

    fn make_agg_req() -> ArrayShardAggReq {
        // Sum reducer c_enum = 0.
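        // (msgpack encodes small non-negative integers as a single
        // positive-fixint byte equal to the value, so `vec![0x00]` below
        // is the one-byte msgpack encoding of the integer 0.)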
        ArrayShardAggReq {
            array_id_msgpack: vec![],
            attr_idx: 0,
            reducer_msgpack: vec![0x00],
            group_by_dim: -1,
            cell_filter_msgpack: vec![],
            shard_hilbert_range: None,
            system_as_of: None,
            valid_at_ms: None,
        }
    }

    #[tokio::test]
    async fn coord_agg_merges_scalar_partials_from_shards() {
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(AggEchoDispatch {
            partials: vec![ArrayAggPartial::from_single(0, 10.0)],
        });
        // 3 shards, each returning one partial with sum=10 → merged sum=30, count=3.
        let coord = make_coordinator(vec![0, 1, 2], dispatch);
        let merged = coord
            .coord_agg(make_agg_req())
            .await
            .expect("coord_agg should succeed");

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].count, 3);
        assert!((merged[0].sum - 30.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn coord_agg_with_empty_shards_returns_empty() {
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(AggEchoDispatch { partials: vec![] });
        let coord = make_coordinator(vec![0, 1], dispatch);
        let merged = coord
            .coord_agg(make_agg_req())
            .await
            .expect("coord_agg with empty shards should succeed");
        assert!(merged.is_empty());
    }

    #[tokio::test]
    async fn coord_agg_merges_grouped_partials_across_shards() {
        // Shard 0 returns a group_key=0 partial; shard 1 returns both group_key=0 and group_key=1.
        struct GroupedDispatch {
            shard0_partials: Vec<ArrayAggPartial>,
            shard1_partials: Vec<ArrayAggPartial>,
        }

        #[async_trait]
        impl ShardRpcDispatch for GroupedDispatch {
            async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
                let partials = if req.vshard_id == 0 {
                    self.shard0_partials.clone()
                } else {
                    self.shard1_partials.clone()
                };
                let resp = ArrayShardAggResp {
                    shard_id: req.vshard_id,
                    partials,
                    truncated_before_horizon: false,
                };
                let payload = zerompk::to_msgpack_vec(&resp).unwrap();
                Ok(VShardEnvelope::new(
                    VShardMessageType::ArrayShardAggResp,
                    req.target_node,
                    req.source_node,
                    req.vshard_id,
                    payload,
                ))
            }
        }

        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(GroupedDispatch {
            shard0_partials: vec![ArrayAggPartial::from_single(0, 5.0)],
            shard1_partials: vec![
                ArrayAggPartial::from_single(0, 15.0),
                ArrayAggPartial::from_single(1, 20.0),
            ],
        });
        let coord = make_coordinator(vec![0, 1], dispatch);
        let merged = coord
            .coord_agg(make_agg_req())
            .await
            .expect("grouped coord_agg should succeed");

        // group_key=0: sum=5+15=20, count=2; group_key=1: sum=20, count=1.
        assert_eq!(merged.len(), 2);
        let g0 = merged.iter().find(|p| p.group_key == 0).expect("group 0");
        let g1 = merged.iter().find(|p| p.group_key == 1).expect("group 1");
        assert!((g0.sum - 20.0).abs() < f64::EPSILON);
        assert_eq!(g0.count, 2);
        assert!((g1.sum - 20.0).abs() < f64::EPSILON);
        assert_eq!(g1.count, 1);
    }

    #[tokio::test]
    async fn coord_slice_zero_limit_returns_all() {
        let row = zerompk::to_msgpack_vec(&"r").unwrap();
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(SliceEchoDispatch {
            rows: vec![row.clone(); 10],
        });
        let coord = make_coordinator(vec![0, 1], dispatch);
        let req = ArrayShardSliceReq {
            array_id_msgpack: vec![],
            slice_msgpack: vec![],
            attr_projection: vec![],
            limit: 0,
            cell_filter_msgpack: vec![],
            prefix_bits: 0,
            slice_hilbert_ranges: vec![],
            shard_hilbert_range: None,
            system_as_of: None,
            valid_at_ms: None,
        };

        // coordinator_limit = 0 → no cutoff → 2 shards × 10 rows = 20 rows.
        let result = coord
            .coord_slice(req, 0)
            .await
            .expect("coord_slice unlimited should succeed");
        assert_eq!(result.rows.len(), 20);
    }

    // ── coord_put / coord_delete tests ────────────────────────────────────

    use super::super::wire::{ArrayShardDeleteResp, ArrayShardPutReq, ArrayShardPutResp};
    use super::write::{ArrayWriteCoordParams, coord_delete, coord_put};
    use crate::error::ClusterError;

    /// Echo dispatch for put that returns an `ArrayShardPutResp` mirroring the request's `wal_lsn`.
    struct PutEchoDispatch;

    #[async_trait]
    impl ShardRpcDispatch for PutEchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            let shard_req: ArrayShardPutReq = zerompk::from_msgpack(&req.payload).unwrap();
            let resp = ArrayShardPutResp {
                shard_id: req.vshard_id,
                applied_lsn: shard_req.wal_lsn,
            };
            let payload = zerompk::to_msgpack_vec(&resp).unwrap();
            Ok(VShardEnvelope::new(
                VShardMessageType::ArrayShardPutResp,
                req.target_node,
                req.source_node,
                req.vshard_id,
                payload,
            ))
        }
    }

    /// Dispatch that always returns a Codec error — used for failure-propagation tests.
    struct FailDispatch;

    #[async_trait]
    impl ShardRpcDispatch for FailDispatch {
        async fn call(&self, _req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            Err(ClusterError::Codec {
                detail: "injected failure".into(),
            })
        }
    }

    /// Echo dispatch for delete that returns an `ArrayShardDeleteResp`.
    struct DeleteEchoDispatch;

    #[async_trait]
    impl ShardRpcDispatch for DeleteEchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            use super::super::wire::ArrayShardDeleteReq;
            let shard_req: ArrayShardDeleteReq = zerompk::from_msgpack(&req.payload).unwrap();
            let resp = ArrayShardDeleteResp {
                shard_id: req.vshard_id,
                applied_lsn: shard_req.wal_lsn,
            };
            let payload = zerompk::to_msgpack_vec(&resp).unwrap();
            Ok(VShardEnvelope::new(
                VShardMessageType::ArrayShardDeleteResp,
                req.target_node,
                req.source_node,
                req.vshard_id,
                payload,
            ))
        }
    }

    fn write_params() -> ArrayWriteCoordParams {
        ArrayWriteCoordParams {
            source_node: 1,
            timeout_ms: 1000,
        }
    }

    fn cb() -> Arc<CircuitBreaker> {
        Arc::new(CircuitBreaker::new(CircuitBreakerConfig::default()))
    }

    #[tokio::test]
    async fn coord_put_partitions_cells_by_tile() {
        // prefix_bits=10, stride=1 → vshard == top-10-bit bucket.
        // p0 → bucket 0 → vshard 0
        // p1 → bucket 1 → vshard 1
        // p2 → bucket 2 → vshard 2
        let p0 = 0x0000_0000_0000_0000u64;
        let p1 = 0x0040_0000_0000_0000u64;
        let p2 = 0x0080_0000_0000_0000u64;
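
        // Sanity check of the bucketing assumed above: the top 10 bits of a
        // u64 key are `key >> 54`. (Hedged restatement for the reader, not
        // part of the routing code under test.)
        assert_eq!(p1 >> 54, 1);
        assert_eq!(p2 >> 54, 2);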

        let cells = vec![
            (p0, vec![0x01u8]),
            (p1, vec![0x02u8]),
            (p0, vec![0x03u8]),
            (p2, vec![0x04u8]),
            (p1, vec![0x05u8]),
        ];

        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(PutEchoDispatch);
        let mut resps = coord_put(&write_params(), vec![], 10, 42, &cells, &dispatch, &cb())
            .await
            .expect("coord_put should succeed");

        resps.sort_by_key(|r| r.shard_id);
        assert_eq!(resps.len(), 3, "should fan out to 3 shards");
        assert_eq!(resps[0].shard_id, 0);
        assert_eq!(resps[1].shard_id, 1);
        assert_eq!(resps[2].shard_id, 2);
        // Each shard echoes back wal_lsn=42.
        for r in &resps {
            assert_eq!(r.applied_lsn, 42);
        }
    }

    #[tokio::test]
    async fn coord_put_aggregates_partial_failures() {
        // A failing dispatch must surface as an error, not silent partial success.
        let cells = vec![(0u64, vec![0xAAu8])];
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(FailDispatch);
        let err = coord_put(&write_params(), vec![], 10, 1, &cells, &dispatch, &cb())
            .await
            .expect_err("coord_put with failing shard should return error");
        assert!(
            matches!(err, ClusterError::Codec { .. }),
            "expected Codec error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn coord_delete_partitions_by_tile() {
        let p0 = 0x0000_0000_0000_0000u64;
        let p1 = 0x0040_0000_0000_0000u64;

        let coords = vec![(p0, vec![0xAAu8]), (p1, vec![0xBBu8]), (p0, vec![0xCCu8])];

        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(DeleteEchoDispatch);
        let mut resps = coord_delete(&write_params(), vec![], 10, 55, &coords, &dispatch, &cb())
            .await
            .expect("coord_delete should succeed");

        resps.sort_by_key(|r| r.shard_id);
        assert_eq!(resps.len(), 2, "should fan out to 2 shards");
        assert_eq!(resps[0].shard_id, 0);
        assert_eq!(resps[1].shard_id, 1);
        for r in &resps {
            assert_eq!(r.applied_lsn, 55);
        }
    }
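
    // Hedged companion to `coord_put_aggregates_partial_failures` above: a
    // sketch assuming the same error-propagation contract holds for
    // coord_delete. It reuses `FailDispatch` and the signatures already
    // exercised in this module.
    #[tokio::test]
    async fn coord_delete_propagates_dispatch_failure() {
        let coords = vec![(0u64, vec![0xAAu8])];
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(FailDispatch);
        let err = coord_delete(&write_params(), vec![], 10, 1, &coords, &dispatch, &cb())
            .await
            .expect_err("coord_delete with failing shard should return error");
        assert!(
            matches!(err, ClusterError::Codec { .. }),
            "expected Codec error, got {err:?}"
        );
    }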
}