1use std::sync::Arc;
17
18use futures::StreamExt;
19
20use crate::circuit_breaker::CircuitBreaker;
21use crate::error::{ClusterError, Result};
22use crate::wire::{VShardEnvelope, VShardMessageType};
23
24async fn call_with_wrong_owner_retry(
34 dispatch: &Arc<dyn ShardRpcDispatch>,
35 env: VShardEnvelope,
36 timeout_ms: u64,
37) -> std::result::Result<VShardEnvelope, ClusterError> {
38 let result = tokio::time::timeout(
39 std::time::Duration::from_millis(timeout_ms),
40 dispatch.call(env.clone(), timeout_ms),
41 )
42 .await;
43
44 match result {
45 Ok(Ok(resp)) => return Ok(resp),
46 Ok(Err(ClusterError::WrongOwner { .. })) => {
47 }
49 Ok(Err(e)) => return Err(e),
50 Err(_elapsed) => {
51 return Err(ClusterError::Transport {
52 detail: format!(
53 "array shard {}: RPC timed out after {timeout_ms}ms",
54 env.vshard_id
55 ),
56 });
57 }
58 }
59
60 let result = tokio::time::timeout(
62 std::time::Duration::from_millis(timeout_ms),
63 dispatch.call(env.clone(), timeout_ms),
64 )
65 .await;
66
67 match result {
68 Ok(Ok(resp)) => Ok(resp),
69 Ok(Err(e)) => Err(e),
70 Err(_elapsed) => Err(ClusterError::Transport {
71 detail: format!(
72 "array shard {}: RPC timed out after {timeout_ms}ms (retry)",
73 env.vshard_id
74 ),
75 }),
76 }
77}
78
79use super::rpc::ShardRpcDispatch;
80
/// Parameters for [`fan_out`], which sends one identical payload to a set of
/// shards.
#[derive(Debug, Clone)]
pub struct FanOutParams {
    /// Shards to contact; `fan_out` issues one request per entry.
    pub shard_ids: Vec<u32>,
    /// Timeout in milliseconds applied to each RPC attempt.
    pub timeout_ms: u64,
    /// Node id passed as the second argument of `VShardEnvelope::new` for every
    /// outgoing envelope (presumably the sender's id — confirm against
    /// `VShardEnvelope`).
    pub source_node: u64,
}
90
/// Parameters for [`fan_out_partitioned`].
///
/// There is no shard list here: the target shards are taken from the
/// `per_shard` slice passed to `fan_out_partitioned`.
#[derive(Debug, Clone)]
pub struct FanOutPartitionedParams {
    /// Timeout in milliseconds applied to each RPC attempt.
    pub timeout_ms: u64,
    /// Node id passed as the second argument of `VShardEnvelope::new` for every
    /// outgoing envelope (presumably the sender's id — confirm against
    /// `VShardEnvelope`).
    pub source_node: u64,
}
99
100pub async fn fan_out(
110 params: &FanOutParams,
111 opcode: u32,
112 req_bytes: &[u8],
113 dispatch: &Arc<dyn ShardRpcDispatch>,
114 circuit_breaker: &CircuitBreaker,
115) -> Result<Vec<(u32, Vec<u8>)>> {
116 if params.shard_ids.is_empty() {
117 return Ok(Vec::new());
118 }
119
120 let mut futs = futures::stream::FuturesUnordered::new();
123
124 for &shard_id in ¶ms.shard_ids {
125 circuit_breaker.check(shard_id as u64)?;
127
128 let env = VShardEnvelope::new(
129 msg_type_from_opcode(opcode)?,
130 params.source_node,
131 0, shard_id,
133 req_bytes.to_vec(),
134 );
135 let timeout_ms = params.timeout_ms;
136 let dispatch = Arc::clone(dispatch);
137 let cb_shard = shard_id;
138
139 futs.push(async move {
140 match call_with_wrong_owner_retry(&dispatch, env, timeout_ms).await {
141 Ok(resp) => Ok((cb_shard, resp.payload)),
142 Err(e) => Err((cb_shard, e)),
143 }
144 });
145 }
146
147 let mut results = Vec::with_capacity(params.shard_ids.len());
148 while let Some(outcome) = futs.next().await {
149 match outcome {
150 Ok((shard_id, payload)) => {
151 circuit_breaker.record_success(shard_id as u64);
152 results.push((shard_id, payload));
153 }
154 Err((shard_id, e)) => {
155 circuit_breaker.record_failure(shard_id as u64);
156 return Err(e);
157 }
158 }
159 }
160
161 Ok(results)
162}
163
164pub async fn fan_out_partitioned(
169 params: &FanOutPartitionedParams,
170 opcode: u32,
171 per_shard: &[(u32, Vec<u8>)],
172 dispatch: &Arc<dyn ShardRpcDispatch>,
173 circuit_breaker: &CircuitBreaker,
174) -> Result<Vec<(u32, Vec<u8>)>> {
175 if per_shard.is_empty() {
176 return Ok(Vec::new());
177 }
178
179 let mut futs = futures::stream::FuturesUnordered::new();
180
181 for (shard_id, payload) in per_shard {
182 circuit_breaker.check(*shard_id as u64)?;
183
184 let env = VShardEnvelope::new(
185 msg_type_from_opcode(opcode)?,
186 params.source_node,
187 0,
188 *shard_id,
189 payload.clone(),
190 );
191 let timeout_ms = params.timeout_ms;
192 let dispatch = Arc::clone(dispatch);
193 let cb_shard = *shard_id;
194
195 futs.push(async move {
196 match call_with_wrong_owner_retry(&dispatch, env, timeout_ms).await {
197 Ok(resp) => Ok((cb_shard, resp.payload)),
198 Err(e) => Err((cb_shard, e)),
199 }
200 });
201 }
202
203 let mut results = Vec::with_capacity(per_shard.len());
204 while let Some(outcome) = futs.next().await {
205 match outcome {
206 Ok((shard_id, payload)) => {
207 circuit_breaker.record_success(shard_id as u64);
208 results.push((shard_id, payload));
209 }
210 Err((shard_id, e)) => {
211 circuit_breaker.record_failure(shard_id as u64);
212 return Err(e);
213 }
214 }
215 }
216
217 Ok(results)
218}
219
220fn msg_type_from_opcode(opcode: u32) -> Result<VShardMessageType> {
225 match opcode {
226 80 => Ok(VShardMessageType::ArrayShardSliceReq),
227 81 => Ok(VShardMessageType::ArrayShardSliceResp),
228 82 => Ok(VShardMessageType::ArrayShardAggReq),
229 83 => Ok(VShardMessageType::ArrayShardAggResp),
230 84 => Ok(VShardMessageType::ArrayShardPutReq),
231 85 => Ok(VShardMessageType::ArrayShardPutResp),
232 86 => Ok(VShardMessageType::ArrayShardDeleteReq),
233 87 => Ok(VShardMessageType::ArrayShardDeleteResp),
234 88 => Ok(VShardMessageType::ArrayShardSurrogateBitmapReq),
235 89 => Ok(VShardMessageType::ArrayShardSurrogateBitmapResp),
236 other => Err(ClusterError::Codec {
237 detail: format!("msg_type_from_opcode: unknown array opcode {other}"),
238 }),
239 }
240}
241
#[cfg(test)]
mod tests {
    //! Unit tests for `fan_out` and `fan_out_partitioned`, driven by in-process
    //! mock implementations of `ShardRpcDispatch`.

    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    use async_trait::async_trait;

    use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
    use crate::error::{ClusterError, Result};
    use crate::wire::VShardEnvelope;

    use super::super::rpc::ShardRpcDispatch;
    use super::{FanOutParams, FanOutPartitionedParams, fan_out, fan_out_partitioned};

    /// Dispatcher that answers every request with an envelope carrying the
    /// same payload and the message type for opcode `request + 1`.
    struct EchoDispatch;

    #[async_trait]
    impl ShardRpcDispatch for EchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            // Relies on the enum discriminant equalling the wire opcode, and on
            // each response opcode being its request opcode plus one.
            let resp_opcode = req.msg_type as u32 + 1;
            let resp_type = super::msg_type_from_opcode(resp_opcode)?;
            // Source and target are swapped for the reply.
            Ok(VShardEnvelope::new(
                resp_type,
                req.target_node,
                req.source_node,
                req.vshard_id,
                req.payload,
            ))
        }
    }

    /// Dispatcher that fails every call with a `Transport` error.
    struct FailDispatch;

    #[async_trait]
    impl ShardRpcDispatch for FailDispatch {
        async fn call(&self, _req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            Err(ClusterError::Transport {
                detail: "injected failure".into(),
            })
        }
    }

    /// Circuit breaker built from the default configuration.
    fn cb() -> CircuitBreaker {
        CircuitBreaker::new(CircuitBreakerConfig::default())
    }

    #[tokio::test]
    async fn fan_out_broadcasts_to_all_shards() {
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        let params = FanOutParams {
            shard_ids: vec![0, 1, 2],
            timeout_ms: 1000,
            source_node: 42,
        };
        let req_bytes = b"test-payload";
        let results = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            req_bytes,
            &dispatch,
            &cb(),
        )
        .await
        .expect("fan_out should succeed");

        assert_eq!(results.len(), 3);
        for (_, payload) in &results {
            // Every shard echoes the broadcast payload back.
            assert_eq!(payload.as_slice(), req_bytes);
        }
        // Results arrive in completion order, so compare ids after sorting.
        let mut ids: Vec<u32> = results.iter().map(|(id, _)| *id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn fan_out_empty_shards_returns_empty() {
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        let params = FanOutParams {
            shard_ids: vec![],
            timeout_ms: 1000,
            source_node: 1,
        };
        let results = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"",
            &dispatch,
            &cb(),
        )
        .await
        .expect("empty fan_out should succeed");
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn fan_out_propagates_shard_error() {
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(FailDispatch);
        let params = FanOutParams {
            shard_ids: vec![0, 1],
            timeout_ms: 1000,
            source_node: 1,
        };
        let err = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"",
            &dispatch,
            &cb(),
        )
        .await
        .expect_err("fan_out should propagate shard failure");
        assert!(
            matches!(err, ClusterError::Transport { .. }),
            "expected Transport error, got {err:?}"
        );
    }

    #[tokio::test]
    async fn fan_out_partitioned_dispatches_distinct_payloads() {
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        let params = FanOutPartitionedParams {
            timeout_ms: 1000,
            source_node: 1,
        };
        let per_shard = vec![
            (0u32, b"shard0-data".to_vec()),
            (1u32, b"shard1-data".to_vec()),
        ];
        let results = fan_out_partitioned(
            &params,
            super::super::opcodes::ARRAY_SHARD_PUT_REQ,
            &per_shard,
            &dispatch,
            &cb(),
        )
        .await
        .expect("fan_out_partitioned should succeed");

        assert_eq!(results.len(), 2);
        // Sort by shard id so each echoed payload can be matched to its shard.
        let mut sorted = results.clone();
        sorted.sort_unstable_by_key(|(id, _)| *id);
        assert_eq!(sorted[0].1, b"shard0-data");
        assert_eq!(sorted[1].1, b"shard1-data");
    }

    #[tokio::test]
    async fn circuit_breaker_open_blocks_fan_out() {
        use std::time::Duration;

        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(EchoDispatch);
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            cooldown: Duration::from_secs(60),
        });
        // With a threshold of 1, a single recorded failure is expected to open
        // the breaker for shard 0.
        cb.record_failure(0);

        let params = FanOutParams {
            shard_ids: vec![0],
            timeout_ms: 1000,
            source_node: 1,
        };
        let err = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"",
            &dispatch,
            &cb,
        )
        .await
        .expect_err("open circuit should block fan_out");
        assert!(
            matches!(err, ClusterError::CircuitOpen { .. }),
            "expected CircuitOpen, got {err:?}"
        );
    }

    /// Dispatcher that returns `WrongOwner` for its first `fail_count` calls
    /// and echoes the request afterwards.
    struct WrongOwnerThenEchoDispatch {
        /// Total number of `call` invocations, shared with the test body.
        call_count: Arc<AtomicU32>,
        /// How many leading calls fail with `WrongOwner`.
        fail_count: u32,
    }

    #[async_trait]
    impl ShardRpcDispatch for WrongOwnerThenEchoDispatch {
        async fn call(&self, req: VShardEnvelope, _timeout_ms: u64) -> Result<VShardEnvelope> {
            // `fetch_add` returns the previous value, i.e. the zero-based index
            // of this call.
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                return Err(ClusterError::WrongOwner {
                    vshard_id: req.vshard_id,
                    expected_owner_node: None,
                });
            }
            let resp_opcode = req.msg_type as u32 + 1;
            let resp_type = super::msg_type_from_opcode(resp_opcode)?;
            Ok(VShardEnvelope::new(
                resp_type,
                req.target_node,
                req.source_node,
                req.vshard_id,
                req.payload,
            ))
        }
    }

    #[tokio::test]
    async fn wrong_owner_triggers_retry_once() {
        let call_count = Arc::new(AtomicU32::new(0));
        // Fail once with WrongOwner, then succeed on the retry.
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(WrongOwnerThenEchoDispatch {
            call_count: call_count.clone(),
            fail_count: 1,
        });
        let params = FanOutParams {
            shard_ids: vec![0],
            timeout_ms: 1000,
            source_node: 1,
        };
        let result = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"payload",
            &dispatch,
            &cb(),
        )
        .await
        .expect("fan_out should succeed after one WrongOwner retry");

        assert_eq!(result.len(), 1);
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "must have called dispatch twice"
        );
    }

    #[tokio::test]
    async fn wrong_owner_twice_propagates() {
        let call_count = Arc::new(AtomicU32::new(0));
        // Both the initial attempt and the single retry fail with WrongOwner.
        let dispatch: Arc<dyn ShardRpcDispatch> = Arc::new(WrongOwnerThenEchoDispatch {
            call_count: call_count.clone(),
            fail_count: 2,
        });
        let params = FanOutParams {
            shard_ids: vec![0],
            timeout_ms: 1000,
            source_node: 1,
        };
        let err = fan_out(
            &params,
            super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
            b"payload",
            &dispatch,
            &cb(),
        )
        .await
        .expect_err("fan_out should propagate WrongOwner when both attempts fail");

        assert!(
            matches!(err, ClusterError::WrongOwner { .. }),
            "expected WrongOwner, got {err:?}"
        );
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "must have called dispatch twice"
        );
    }
}