// SPDX-License-Identifier: BUSL-1.1

//! Fan-out helpers for array distributed operations.
//!
//! Two entry points:
//! - `fan_out` — broadcast one request to all listed shards (slice, agg,
//!   delete, surrogate scan).
//! - `fan_out_partitioned` — send a different payload to each shard
//!   (used by `coord_put` where cells are routed by Hilbert prefix).
//!
//! Both functions dispatch concurrently via `FuturesUnordered`, apply a
//! per-shard `tokio::time::timeout`, and wrap each call in the cluster
//! `CircuitBreaker`. Any shard failure is propagated as `Err` — partial
//! results are not silently dropped.

use std::sync::Arc;

use futures::StreamExt;

use crate::circuit_breaker::CircuitBreaker;
use crate::error::{ClusterError, Result};
use crate::wire::{VShardEnvelope, VShardMessageType};
23
24/// Dispatch one shard call with one automatic retry on `WrongOwner`.
25///
26/// When a shard returns `WrongOwner` it means the coordinator's routing table
27/// was stale at the time the request was built. The `ShardRpcDispatch`
28/// implementor holds a live reference to the `RoutingTable` (updated by the
29/// `CacheApplier` on every committed `RoutingChange`), so re-issuing the same
30/// envelope lets the implementor route it to the current owner without any
31/// explicit routing-refresh API call on the coordinator side. A second
32/// `WrongOwner` propagates as an error — we do not loop.
33async fn call_with_wrong_owner_retry(
34    dispatch: &Arc<dyn ShardRpcDispatch>,
35    env: VShardEnvelope,
36    timeout_ms: u64,
37) -> std::result::Result<VShardEnvelope, ClusterError> {
38    let result = tokio::time::timeout(
39        std::time::Duration::from_millis(timeout_ms),
40        dispatch.call(env.clone(), timeout_ms),
41    )
42    .await;
43
44    match result {
45        Ok(Ok(resp)) => return Ok(resp),
46        Ok(Err(ClusterError::WrongOwner { .. })) => {
47            // Retry once: the dispatch impl will re-read the live routing table.
48        }
49        Ok(Err(e)) => return Err(e),
50        Err(_elapsed) => {
51            return Err(ClusterError::Transport {
52                detail: format!(
53                    "array shard {}: RPC timed out after {timeout_ms}ms",
54                    env.vshard_id
55                ),
56            });
57        }
58    }
59
60    // Second attempt — propagate whatever error arises, including a second WrongOwner.
61    let result = tokio::time::timeout(
62        std::time::Duration::from_millis(timeout_ms),
63        dispatch.call(env.clone(), timeout_ms),
64    )
65    .await;
66
67    match result {
68        Ok(Ok(resp)) => Ok(resp),
69        Ok(Err(e)) => Err(e),
70        Err(_elapsed) => Err(ClusterError::Transport {
71            detail: format!(
72                "array shard {}: RPC timed out after {timeout_ms}ms (retry)",
73                env.vshard_id
74            ),
75        }),
76    }
77}
78
use super::rpc::ShardRpcDispatch;
80
/// Parameters governing a single fan-out round.
pub struct FanOutParams {
    /// Shard IDs to contact (broadcast target list).
    pub shard_ids: Vec<u32>,
    /// Per-shard RPC timeout in milliseconds. Applies to each dispatch
    /// attempt individually (including the automatic `WrongOwner` retry),
    /// not to the round as a whole.
    pub timeout_ms: u64,
    /// Source node ID (used to tag outgoing envelopes).
    pub source_node: u64,
}
90
/// Parameters for a partitioned fan-out where each shard receives a
/// different payload.
///
/// The per-shard `(vshard_id, payload_bytes)` pairs are not stored here —
/// they are passed separately as the `per_shard` argument of
/// `fan_out_partitioned`.
pub struct FanOutPartitionedParams {
    /// Per-shard RPC timeout in milliseconds (per attempt, as in
    /// `FanOutParams`).
    pub timeout_ms: u64,
    /// Source node ID (used to tag outgoing envelopes).
    pub source_node: u64,
}
99
100/// Send `req_bytes` to every shard listed in `params` and collect responses.
101///
102/// Each shard RPC runs concurrently via `FuturesUnordered`. Per-shard
103/// timeouts are enforced by `tokio::time::timeout`. The circuit breaker
104/// is checked before each call and updated on success/failure. Any shard
105/// error causes the whole fan-out to return `Err` — the coordinator
106/// decides whether to retry.
107///
108/// Returns `Vec<(shard_id, response_payload_bytes)>` in arrival order.
109pub async fn fan_out(
110    params: &FanOutParams,
111    opcode: u32,
112    req_bytes: &[u8],
113    dispatch: &Arc<dyn ShardRpcDispatch>,
114    circuit_breaker: &CircuitBreaker,
115) -> Result<Vec<(u32, Vec<u8>)>> {
116    if params.shard_ids.is_empty() {
117        return Ok(Vec::new());
118    }
119
120    // Build one future per shard and collect via FuturesUnordered for
121    // true concurrency (no sequential .await loop).
122    let mut futs = futures::stream::FuturesUnordered::new();
123
124    for &shard_id in &params.shard_ids {
125        // Circuit-breaker gate: treat shard_id as the peer identifier.
126        circuit_breaker.check(shard_id as u64)?;
127
128        let env = VShardEnvelope::new(
129            msg_type_from_opcode(opcode)?,
130            params.source_node,
131            0, // target_node resolved by the dispatch impl
132            shard_id,
133            req_bytes.to_vec(),
134        );
135        let timeout_ms = params.timeout_ms;
136        let dispatch = Arc::clone(dispatch);
137        let cb_shard = shard_id;
138
139        futs.push(async move {
140            match call_with_wrong_owner_retry(&dispatch, env, timeout_ms).await {
141                Ok(resp) => Ok((cb_shard, resp.payload)),
142                Err(e) => Err((cb_shard, e)),
143            }
144        });
145    }
146
147    let mut results = Vec::with_capacity(params.shard_ids.len());
148    while let Some(outcome) = futs.next().await {
149        match outcome {
150            Ok((shard_id, payload)) => {
151                circuit_breaker.record_success(shard_id as u64);
152                results.push((shard_id, payload));
153            }
154            Err((shard_id, e)) => {
155                circuit_breaker.record_failure(shard_id as u64);
156                return Err(e);
157            }
158        }
159    }
160
161    Ok(results)
162}
163
164/// Send a distinct payload to each shard and collect responses.
165///
166/// `per_shard` — `(vshard_id, payload_bytes)` pairs, one per target shard.
167/// Returns `(shard_id, response_payload_bytes)` in arrival order.
168pub async fn fan_out_partitioned(
169    params: &FanOutPartitionedParams,
170    opcode: u32,
171    per_shard: &[(u32, Vec<u8>)],
172    dispatch: &Arc<dyn ShardRpcDispatch>,
173    circuit_breaker: &CircuitBreaker,
174) -> Result<Vec<(u32, Vec<u8>)>> {
175    if per_shard.is_empty() {
176        return Ok(Vec::new());
177    }
178
179    let mut futs = futures::stream::FuturesUnordered::new();
180
181    for (shard_id, payload) in per_shard {
182        circuit_breaker.check(*shard_id as u64)?;
183
184        let env = VShardEnvelope::new(
185            msg_type_from_opcode(opcode)?,
186            params.source_node,
187            0,
188            *shard_id,
189            payload.clone(),
190        );
191        let timeout_ms = params.timeout_ms;
192        let dispatch = Arc::clone(dispatch);
193        let cb_shard = *shard_id;
194
195        futs.push(async move {
196            match call_with_wrong_owner_retry(&dispatch, env, timeout_ms).await {
197                Ok(resp) => Ok((cb_shard, resp.payload)),
198                Err(e) => Err((cb_shard, e)),
199            }
200        });
201    }
202
203    let mut results = Vec::with_capacity(per_shard.len());
204    while let Some(outcome) = futs.next().await {
205        match outcome {
206            Ok((shard_id, payload)) => {
207                circuit_breaker.record_success(shard_id as u64);
208                results.push((shard_id, payload));
209            }
210            Err((shard_id, e)) => {
211                circuit_breaker.record_failure(shard_id as u64);
212                return Err(e);
213            }
214        }
215    }
216
217    Ok(results)
218}
219
220/// Map an opcode constant to a `VShardMessageType`.
221///
222/// All array opcodes are in the range 80–89. Any other value is a
223/// programming error in the coordinator, not a runtime condition.
224fn msg_type_from_opcode(opcode: u32) -> Result<VShardMessageType> {
225    match opcode {
226        80 => Ok(VShardMessageType::ArrayShardSliceReq),
227        81 => Ok(VShardMessageType::ArrayShardSliceResp),
228        82 => Ok(VShardMessageType::ArrayShardAggReq),
229        83 => Ok(VShardMessageType::ArrayShardAggResp),
230        84 => Ok(VShardMessageType::ArrayShardPutReq),
231        85 => Ok(VShardMessageType::ArrayShardPutResp),
232        86 => Ok(VShardMessageType::ArrayShardDeleteReq),
233        87 => Ok(VShardMessageType::ArrayShardDeleteResp),
234        88 => Ok(VShardMessageType::ArrayShardSurrogateBitmapReq),
235        89 => Ok(VShardMessageType::ArrayShardSurrogateBitmapResp),
236        other => Err(ClusterError::Codec {
237            detail: format!("msg_type_from_opcode: unknown array opcode {other}"),
238        }),
239    }
240}
241
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;

    use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
    use crate::error::{ClusterError, Result};
    use crate::wire::VShardEnvelope;

    use super::super::rpc::ShardRpcDispatch;
    use super::{FanOutParams, FanOutPartitionedParams, fan_out, fan_out_partitioned};

    /// Mock that echoes the request payload back as the response payload,
    /// stamped with the shard's id in the response envelope.
    struct EchoDispatch;

    #[async_trait]
    impl ShardRpcDispatch for EchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            // Echo back a response envelope of the corresponding resp opcode.
            // Relies on the opcode layout in `msg_type_from_opcode`: each
            // resp opcode is its req opcode + 1.
            let resp_opcode = req.msg_type as u32 + 1;
            let resp_type = super::msg_type_from_opcode(resp_opcode)?;
            // Source/target swapped, mimicking a real reply envelope.
            Ok(VShardEnvelope::new(
                resp_type,
                req.target_node,
                req.source_node,
                req.vshard_id,
                req.payload,
            ))
        }
    }

    /// Mock that always returns a transport error.
    struct FailDispatch;

    #[async_trait]
    impl ShardRpcDispatch for FailDispatch {
        async fn call(&self, _req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            Err(ClusterError::Transport {
                detail: "injected failure".into(),
            })
        }
    }

    // Fresh default-config breaker per test so trips never leak across tests.
    fn cb() -> CircuitBreaker {
        CircuitBreaker::new(CircuitBreakerConfig::default())
    }

    #[tokio::test]
    async fn fan_out_broadcasts_to_all_shards() {
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        let params = FanOutParams {
            shard_ids: vec![0, 1, 2],
            timeout_ms: 1000,
            source_node: 42,
        };
        let req_bytes = b"test-payload";
        let results = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            req_bytes,
            &dispatch,
            &cb(),
        )
        .await
        .expect("fan_out should succeed");

        assert_eq!(results.len(), 3);
        // All shards should echo back the same payload.
        for (_, payload) in &results {
            assert_eq!(payload.as_slice(), req_bytes);
        }
        // All three shard IDs should be present.
        // Results arrive in completion order, so sort before comparing.
        let mut ids: Vec<u32> = results.iter().map(|(id, _)| *id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn fan_out_empty_shards_returns_empty() {
        // Empty shard list short-circuits: no RPCs, Ok(vec![]).
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        let params = FanOutParams {
            shard_ids: vec![],
            timeout_ms: 1000,
            source_node: 1,
        };
        let results = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"",
            &dispatch,
            &cb(),
        )
        .await
        .expect("empty fan_out should succeed");
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn fan_out_propagates_shard_error() {
        // Any single shard failure fails the whole round.
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(FailDispatch);
        let params = FanOutParams {
            shard_ids: vec![0, 1],
            timeout_ms: 1000,
            source_node: 1,
        };
        let err = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"",
            &dispatch,
            &cb(),
        )
        .await
        .expect_err("fan_out should propagate shard failure");
        assert!(
            matches!(err, ClusterError::Transport { .. }),
            "expected Transport error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn fan_out_partitioned_dispatches_distinct_payloads() {
        // Each shard must receive (and echo back) its own payload, not a
        // shared broadcast body.
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        let params = FanOutPartitionedParams {
            timeout_ms: 1000,
            source_node: 1,
        };
        let per_shard = vec![
            (0u32, b"shard0-data".to_vec()),
            (1u32, b"shard1-data".to_vec()),
        ];
        let results = fan_out_partitioned(
            &params,
            super::super::opcodes::ARRAY_SHARD_PUT_REQ,
            &per_shard,
            &dispatch,
            &cb(),
        )
        .await
        .expect("fan_out_partitioned should succeed");

        assert_eq!(results.len(), 2);
        // Arrival order is nondeterministic — sort by shard id first.
        let mut sorted = results.clone();
        sorted.sort_unstable_by_key(|(id, _)| *id);
        assert_eq!(sorted[0].1, b"shard0-data");
        assert_eq!(sorted[1].1, b"shard1-data");
    }

    #[tokio::test]
    async fn circuit_breaker_open_blocks_fan_out() {
        use crate::circuit_breaker::CircuitBreakerConfig;
        use std::time::Duration;

        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        // threshold 1 + long cooldown: a single failure opens the breaker
        // and it stays open for the whole test.
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            cooldown: Duration::from_secs(60),
        });
        // Trip the breaker for shard 0.
        cb.record_failure(0);

        let params = FanOutParams {
            shard_ids: vec![0],
            timeout_ms: 1000,
            source_node: 1,
        };
        let err = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"",
            &dispatch,
            &cb,
        )
        .await
        .expect_err("open circuit should block fan_out");
        assert!(
            matches!(err, ClusterError::CircuitOpen { .. }),
            "expected CircuitOpen, got {err:?}"
        );
    }

    // ── WrongOwner retry tests ────────────────────────────────────────────

    use std::sync::atomic::{AtomicU32, Ordering};

    /// Dispatch that returns `WrongOwner` exactly `fail_count` times, then
    /// echoes the request payload back as a success.
    struct WrongOwnerThenEchoDispatch {
        // Shared counter so the test can assert how many calls were made.
        call_count: Arc<AtomicU32>,
        // Number of leading calls that report WrongOwner before succeeding.
        fail_count: u32,
    }

    #[async_trait]
    impl ShardRpcDispatch for WrongOwnerThenEchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            // fetch_add returns the pre-increment value, so the first call
            // observes n == 0.
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                return Err(ClusterError::WrongOwner {
                    vshard_id: req.vshard_id,
                    expected_owner_node: None,
                });
            }
            // Past the failure budget: behave like EchoDispatch.
            let resp_opcode = req.msg_type as u32 + 1;
            let resp_type = super::msg_type_from_opcode(resp_opcode)?;
            Ok(VShardEnvelope::new(
                resp_type,
                req.target_node,
                req.source_node,
                req.vshard_id,
                req.payload,
            ))
        }
    }

    #[tokio::test]
    async fn wrong_owner_triggers_retry_once() {
        // First call returns WrongOwner; retry (second call) succeeds.
        let call_count = Arc::new(AtomicU32::new(0));
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(WrongOwnerThenEchoDispatch {
            call_count: call_count.clone(),
            fail_count: 1,
        });
        let params = FanOutParams {
            shard_ids: vec![0],
            timeout_ms: 1000,
            source_node: 1,
        };
        let result = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"payload",
            &dispatch,
            &cb(),
        )
        .await
        .expect("fan_out should succeed after one WrongOwner retry");

        assert_eq!(result.len(), 1);
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "must have called dispatch twice"
        );
    }

    #[tokio::test]
    async fn wrong_owner_twice_propagates() {
        // Both attempts return WrongOwner → fan_out surfaces the error.
        // Also proves there is no third attempt (exactly one retry).
        let call_count = Arc::new(AtomicU32::new(0));
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(WrongOwnerThenEchoDispatch {
            call_count: call_count.clone(),
            fail_count: 2,
        });
        let params = FanOutParams {
            shard_ids: vec![0],
            timeout_ms: 1000,
            source_node: 1,
        };
        let err = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"payload",
            &dispatch,
            &cb(),
        )
        .await
        .expect_err("fan_out should propagate WrongOwner when both attempts fail");

        assert!(
            matches!(err, ClusterError::WrongOwner { .. }),
            "expected WrongOwner, got {err:?}"
        );
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "must have called dispatch twice"
        );
    }
}