atomr_testkit/
multinode_oop.rs1use std::collections::HashMap;
23use std::io;
24use std::net::SocketAddr;
25use std::sync::Arc;
26use std::time::Duration;
27
28use parking_lot::Mutex;
29use thiserror::Error;
30use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
31use tokio::net::{TcpListener, TcpStream};
32use tokio::sync::oneshot;
33use tokio::task::JoinHandle;
34
35#[derive(Debug, Error)]
36#[non_exhaustive]
37pub enum MultiNodeOopError {
38 #[error("io: {0}")]
39 Io(#[from] io::Error),
40 #[error("controller already shut down")]
41 ControllerDown,
42 #[error("barrier `{label}` timed out at controller (got {got}/{expected})")]
43 BarrierTimeout { label: String, got: usize, expected: usize },
44 #[error("malformed line from peer: {0}")]
45 Malformed(String),
46 #[error("unexpected reply: {0}")]
47 UnexpectedReply(String),
48}
49
50struct LabelState {
52 expected: usize,
53 waiters: Vec<oneshot::Sender<bool>>,
56 arrived: usize,
57 completed: bool,
58}
59
60pub struct MultiNodeOopController {
63 addr: SocketAddr,
64 inner: Arc<ControllerInner>,
65 handle: JoinHandle<()>,
66}
67
68struct ControllerInner {
69 expected: usize,
70 labels: Mutex<HashMap<String, LabelState>>,
71}
72
73impl MultiNodeOopController {
74 pub async fn start(expected_nodes: usize) -> Result<Self, MultiNodeOopError> {
78 assert!(expected_nodes >= 1, "expected_nodes must be ≥ 1");
79 let listener = TcpListener::bind("127.0.0.1:0").await?;
80 let addr = listener.local_addr()?;
81 let inner =
82 Arc::new(ControllerInner { expected: expected_nodes, labels: Mutex::new(HashMap::new()) });
83 let inner_a = inner.clone();
84 let handle = tokio::spawn(async move {
85 loop {
86 match listener.accept().await {
87 Ok((s, _)) => {
88 let i = inner_a.clone();
89 tokio::spawn(async move {
90 handle_child(s, i).await;
91 });
92 }
93 Err(_) => return,
94 }
95 }
96 });
97 Ok(Self { addr, inner, handle })
98 }
99
100 pub fn local_addr(&self) -> SocketAddr {
101 self.addr
102 }
103
104 pub async fn timeout_barrier(&self, label: &str, timeout: Duration) -> Result<usize, MultiNodeOopError> {
109 tokio::time::sleep(timeout).await;
110 let mut g = self.inner.labels.lock();
111 let state = g.entry(label.to_string()).or_insert_with(|| LabelState {
112 expected: self.inner.expected,
113 waiters: Vec::new(),
114 arrived: 0,
115 completed: false,
116 });
117 if state.completed {
118 return Ok(state.arrived);
119 }
120 let arrived = state.arrived;
122 for w in state.waiters.drain(..) {
123 let _ = w.send(false);
124 }
125 state.completed = true;
126 if arrived < state.expected {
127 return Err(MultiNodeOopError::BarrierTimeout {
128 label: label.into(),
129 got: arrived,
130 expected: state.expected,
131 });
132 }
133 Ok(arrived)
134 }
135
136 pub fn shutdown(self) {
140 self.handle.abort();
141 }
142}
143
144async fn handle_child(stream: TcpStream, inner: Arc<ControllerInner>) {
145 let (r, mut w) = stream.into_split();
146 let mut lines = BufReader::new(r).lines();
147 while let Ok(Some(line)) = lines.next_line().await {
148 let trimmed = line.trim().to_string();
149 if let Some(label) = trimmed.strip_prefix("BARRIER ") {
150 let rx = enroll(&inner, label);
151 let outcome = match rx.await {
153 Ok(true) => format!("OK {label}\n"),
154 Ok(false) => format!("TIMEOUT {label}\n"),
155 Err(_) => format!("TIMEOUT {label}\n"),
156 };
157 if w.write_all(outcome.as_bytes()).await.is_err() {
158 return;
159 }
160 }
161 }
162}
163
164fn enroll(inner: &Arc<ControllerInner>, label: &str) -> oneshot::Receiver<bool> {
165 let (tx, rx) = oneshot::channel();
166 let mut g = inner.labels.lock();
167 let state = g.entry(label.to_string()).or_insert_with(|| LabelState {
168 expected: inner.expected,
169 waiters: Vec::new(),
170 arrived: 0,
171 completed: false,
172 });
173 if state.completed {
174 let _ = tx.send(true);
176 return rx;
177 }
178 state.arrived += 1;
179 state.waiters.push(tx);
180 if state.arrived >= state.expected {
181 for w in state.waiters.drain(..) {
183 let _ = w.send(true);
184 }
185 state.completed = true;
186 }
187 rx
188}
189
190pub struct MultiNodeOopNode {
193 stream: tokio::sync::Mutex<TcpStream>,
194}
195
196impl MultiNodeOopNode {
197 pub async fn connect(controller: SocketAddr) -> Result<Self, MultiNodeOopError> {
198 let s = TcpStream::connect(controller).await?;
199 s.set_nodelay(true)?;
200 Ok(Self { stream: tokio::sync::Mutex::new(s) })
201 }
202
203 pub async fn barrier(&self, label: &str) -> Result<(), MultiNodeOopError> {
207 let mut g = self.stream.lock().await;
208 g.write_all(format!("BARRIER {label}\n").as_bytes()).await?;
209 let mut buf = String::new();
210 let mut reader = BufReader::new(&mut *g);
211 reader.read_line(&mut buf).await?;
212 let trimmed = buf.trim();
213 if let Some(rest) = trimmed.strip_prefix("OK ") {
214 if rest == label {
215 return Ok(());
216 }
217 }
218 if let Some(rest) = trimmed.strip_prefix("TIMEOUT ") {
219 return Err(MultiNodeOopError::BarrierTimeout { label: rest.to_string(), got: 0, expected: 0 });
220 }
221 Err(MultiNodeOopError::UnexpectedReply(trimmed.to_string()))
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn three_nodes_meet_at_barrier() {
        let ctrl = MultiNodeOopController::start(3).await.unwrap();
        let addr = ctrl.local_addr();

        let mut nodes = Vec::with_capacity(3);
        for _ in 0..3 {
            nodes.push(tokio::spawn(async move {
                let n = MultiNodeOopNode::connect(addr).await.unwrap();
                n.barrier("converged").await.unwrap();
            }));
        }
        for node in nodes {
            node.await.unwrap();
        }
        ctrl.shutdown();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn barrier_times_out_when_only_some_arrive() {
        let ctrl = MultiNodeOopController::start(3).await.unwrap();
        let addr = ctrl.local_addr();
        let label = "incomplete";

        // Only two of the three expected nodes ever show up.
        let mut nodes = Vec::with_capacity(2);
        for _ in 0..2 {
            nodes.push(tokio::spawn(async move {
                let n = MultiNodeOopNode::connect(addr).await.unwrap();
                // The verdict depends on whether this node enrolled before or
                // after the controller-side timeout fired, so it is not
                // asserted here.
                let _ = n.barrier(label).await;
            }));
        }

        let to = ctrl.timeout_barrier(label, Duration::from_millis(50)).await;
        assert!(
            matches!(&to, Err(MultiNodeOopError::BarrierTimeout { got, expected: 3, .. }) if *got < 3),
            "unexpected timeout_barrier result: {to:?}"
        );

        // Surface panics from the node tasks (e.g. a failed connect) instead
        // of silently discarding them.
        for node in nodes {
            node.await.unwrap();
        }
        ctrl.shutdown();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn multiple_independent_labels() {
        let ctrl = MultiNodeOopController::start(2).await.unwrap();
        let addr = ctrl.local_addr();

        let mut nodes = Vec::with_capacity(2);
        for _ in 0..2 {
            nodes.push(tokio::spawn(async move {
                let n = MultiNodeOopNode::connect(addr).await.unwrap();
                n.barrier("phase-a").await.unwrap();
                n.barrier("phase-b").await.unwrap();
            }));
        }
        for node in nodes {
            node.await.unwrap();
        }
        ctrl.shutdown();
    }

    #[tokio::test]
    async fn controller_addr_is_loopback() {
        let ctrl = MultiNodeOopController::start(1).await.unwrap();
        let addr = ctrl.local_addr();
        assert!(addr.ip().is_loopback());
        assert_ne!(addr.port(), 0);
        ctrl.shutdown();
    }
}