rakka_testkit/
multinode.rs1use std::collections::HashMap;
31use std::sync::{Arc, Mutex};
32use std::time::Duration;
33
34use rakka_config::Config;
35use rakka_core::actor::{ActorSystem, ActorSystemError};
36use thiserror::Error;
37use tokio::sync::Barrier;
38
39#[derive(Debug, Error)]
40#[non_exhaustive]
41pub enum MultiNodeError {
42 #[error("failed to boot node `{name}`: {source}")]
43 Boot {
44 name: String,
45 #[source]
46 source: ActorSystemError,
47 },
48 #[error("barrier `{name}` timed out (got {got}/{expected})")]
49 BarrierTimeout { name: String, got: usize, expected: usize },
50}
51
52pub struct MultiNodeSpec {
54 name: String,
55 node_count: usize,
56 barriers: Arc<Mutex<HashMap<String, Arc<Barrier>>>>,
57 arrivals: Arc<Mutex<HashMap<String, usize>>>,
58}
59
60impl MultiNodeSpec {
61 pub fn new(name: impl Into<String>, node_count: usize) -> Self {
62 assert!(node_count >= 1, "node_count must be ≥ 1");
63 Self {
64 name: name.into(),
65 node_count,
66 barriers: Arc::new(Mutex::new(HashMap::new())),
67 arrivals: Arc::new(Mutex::new(HashMap::new())),
68 }
69 }
70
71 pub fn name(&self) -> &str {
72 &self.name
73 }
74
75 pub fn node_count(&self) -> usize {
76 self.node_count
77 }
78
79 pub fn node_address(&self, i: usize) -> String {
82 format!("{}@node-{}", self.name, i)
83 }
84
85 pub async fn boot(&self) -> Result<Vec<ActorSystem>, MultiNodeError> {
89 let mut out = Vec::with_capacity(self.node_count);
90 for i in 0..self.node_count {
91 let name = format!("{}-{}", self.name, i);
92 let sys = ActorSystem::create(&name, Config::reference())
93 .await
94 .map_err(|e| MultiNodeError::Boot { name, source: e })?;
95 out.push(sys);
96 }
97 Ok(out)
98 }
99
100 pub async fn barrier(&self, label: &str, timeout: Duration) -> Result<(), MultiNodeError> {
108 let bar = {
109 let mut g = self.barriers.lock().unwrap();
110 g.entry(label.to_string()).or_insert_with(|| Arc::new(Barrier::new(self.node_count))).clone()
111 };
112 {
113 let mut a = self.arrivals.lock().unwrap();
114 *a.entry(label.to_string()).or_insert(0) += 1;
115 }
116 match tokio::time::timeout(timeout, bar.wait()).await {
117 Ok(_) => Ok(()),
118 Err(_) => {
119 let arrivals = *self.arrivals.lock().unwrap().get(label).unwrap_or(&0);
120 Err(MultiNodeError::BarrierTimeout {
121 name: label.to_string(),
122 got: arrivals,
123 expected: self.node_count,
124 })
125 }
126 }
127 }
128
129 pub async fn shutdown(&self, nodes: Vec<ActorSystem>) {
131 for sys in nodes {
132 sys.terminate().await;
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Boots three nodes and checks that three concurrent waiters all pass the
    /// same barrier.
    #[tokio::test]
    async fn boot_three_nodes_and_barrier() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTest", 3));
        let nodes = spec.boot().await.unwrap();
        assert_eq!(nodes.len(), 3);

        let waiters: Vec<_> = (0..3)
            .map(|_| {
                let spec = spec.clone();
                tokio::spawn(async move {
                    spec.barrier("step1", Duration::from_secs(2)).await.unwrap();
                })
            })
            .collect();
        for waiter in waiters {
            waiter.await.unwrap();
        }

        spec.shutdown(nodes).await;
    }

    /// With three expected nodes and only two arrivals, the barrier reports a
    /// timeout.
    #[tokio::test]
    async fn barrier_times_out_when_only_some_arrive() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTimeoutTest", 3));
        let _ = spec.boot().await.unwrap();

        let remote = spec.clone();
        let remote_wait = tokio::spawn(async move {
            remote.barrier("only-two", Duration::from_millis(50)).await
        });

        // The local waiter's result is discarded; only the spawned waiter's
        // result is asserted.
        let _ = spec.barrier("only-two", Duration::from_millis(50)).await;

        let outcome = remote_wait.await.unwrap();
        assert!(matches!(outcome, Err(MultiNodeError::BarrierTimeout { .. })));
    }

    /// Addresses generated for different node indices never collide.
    #[test]
    fn node_addresses_are_distinct() {
        let spec = MultiNodeSpec::new("X", 4);
        let distinct: std::collections::HashSet<String> =
            (0..4).map(|i| spec.node_address(i)).collect();
        assert_eq!(distinct.len(), 4);
    }
}