atomr_testkit/
multinode.rs1use std::collections::HashMap;
29use std::sync::{Arc, Mutex};
30use std::time::Duration;
31
32use atomr_config::Config;
33use atomr_core::actor::{ActorSystem, ActorSystemError};
34use thiserror::Error;
35use tokio::sync::Barrier;
36
37#[derive(Debug, Error)]
38#[non_exhaustive]
39pub enum MultiNodeError {
40 #[error("failed to boot node `{name}`: {source}")]
41 Boot {
42 name: String,
43 #[source]
44 source: ActorSystemError,
45 },
46 #[error("barrier `{name}` timed out (got {got}/{expected})")]
47 BarrierTimeout { name: String, got: usize, expected: usize },
48}
49
50pub struct MultiNodeSpec {
52 name: String,
53 node_count: usize,
54 barriers: Arc<Mutex<HashMap<String, Arc<Barrier>>>>,
55 arrivals: Arc<Mutex<HashMap<String, usize>>>,
56}
57
58impl MultiNodeSpec {
59 pub fn new(name: impl Into<String>, node_count: usize) -> Self {
60 assert!(node_count >= 1, "node_count must be ≥ 1");
61 Self {
62 name: name.into(),
63 node_count,
64 barriers: Arc::new(Mutex::new(HashMap::new())),
65 arrivals: Arc::new(Mutex::new(HashMap::new())),
66 }
67 }
68
69 pub fn name(&self) -> &str {
70 &self.name
71 }
72
73 pub fn node_count(&self) -> usize {
74 self.node_count
75 }
76
77 pub fn node_address(&self, i: usize) -> String {
80 format!("{}@node-{}", self.name, i)
81 }
82
83 pub async fn boot(&self) -> Result<Vec<ActorSystem>, MultiNodeError> {
87 let mut out = Vec::with_capacity(self.node_count);
88 for i in 0..self.node_count {
89 let name = format!("{}-{}", self.name, i);
90 let sys = ActorSystem::create(&name, Config::reference())
91 .await
92 .map_err(|e| MultiNodeError::Boot { name, source: e })?;
93 out.push(sys);
94 }
95 Ok(out)
96 }
97
98 pub async fn barrier(&self, label: &str, timeout: Duration) -> Result<(), MultiNodeError> {
106 let bar = {
107 let mut g = self.barriers.lock().unwrap();
108 g.entry(label.to_string()).or_insert_with(|| Arc::new(Barrier::new(self.node_count))).clone()
109 };
110 {
111 let mut a = self.arrivals.lock().unwrap();
112 *a.entry(label.to_string()).or_insert(0) += 1;
113 }
114 match tokio::time::timeout(timeout, bar.wait()).await {
115 Ok(_) => Ok(()),
116 Err(_) => {
117 let arrivals = *self.arrivals.lock().unwrap().get(label).unwrap_or(&0);
118 Err(MultiNodeError::BarrierTimeout {
119 name: label.to_string(),
120 got: arrivals,
121 expected: self.node_count,
122 })
123 }
124 }
125 }
126
127 pub async fn shutdown(&self, nodes: Vec<ActorSystem>) {
129 for sys in nodes {
130 sys.terminate().await;
131 }
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn boot_three_nodes_and_barrier() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTest", 3));
        let nodes = spec.boot().await.unwrap();
        assert_eq!(nodes.len(), 3);

        let mut handles = Vec::new();
        for _ in 0..3 {
            let s = Arc::clone(&spec);
            handles.push(tokio::spawn(async move {
                s.barrier("step1", Duration::from_secs(2)).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        spec.shutdown(nodes).await;
    }

    #[tokio::test]
    async fn barrier_label_can_be_reused() {
        // Barriers do not need booted nodes; two rounds on the same label must both release.
        let spec = Arc::new(MultiNodeSpec::new("BarrierReuseTest", 2));
        for _round in 0..2 {
            let s = Arc::clone(&spec);
            let h = tokio::spawn(async move { s.barrier("again", Duration::from_secs(2)).await });
            spec.barrier("again", Duration::from_secs(2)).await.unwrap();
            h.await.unwrap().unwrap();
        }
    }

    #[tokio::test]
    async fn barrier_times_out_when_only_some_arrive() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTimeoutTest", 3));
        let nodes = spec.boot().await.unwrap();

        // Only two of the three participants ever arrive, so both must time out.
        let s2 = Arc::clone(&spec);
        let h = tokio::spawn(async move { s2.barrier("only-two", Duration::from_millis(50)).await });
        let mine = spec.barrier("only-two", Duration::from_millis(50)).await;
        let theirs = h.await.unwrap();

        for r in [mine, theirs] {
            assert!(
                matches!(
                    &r,
                    Err(MultiNodeError::BarrierTimeout { name, got: 2, expected: 3 })
                        if name == "only-two"
                ),
                "unexpected barrier result: {:?}",
                r
            );
        }

        spec.shutdown(nodes).await;
    }

    #[test]
    fn node_addresses_are_distinct() {
        let s = MultiNodeSpec::new("X", 4);
        let addrs: Vec<String> = (0..4).map(|i| s.node_address(i)).collect();
        let unique: std::collections::HashSet<_> = addrs.iter().cloned().collect();
        assert_eq!(unique.len(), 4);
        assert_eq!(addrs[2], "X@node-2");
    }
}