Skip to main content

atomr_testkit/
multinode_oop.rs

//! Out-of-process `MultiNodeSpec`: a controller process and N child
//! processes that rendezvous on named barriers over loopback TCP.
4//!
5//! The line protocol is intentionally trivial (one ASCII command per
6//! line) so a child written in any language could join the rendezvous.
//! Each barrier label is a separate sync point: every node calls
//! `barrier(label)`, and the controller withholds its reply until N
//! nodes have arrived on that label, then sends `OK <label>` to each.
10//!
11//! Wire protocol (all `\n`-terminated UTF-8):
12//!   child → controller   `BARRIER <label>`
13//!   controller → child   `OK <label>` once N have arrived
14//!   controller → child   `TIMEOUT <label>` if the controller's
15//!                        per-barrier timer fires first
16//!
17//! The harness purposely does not impose any actor-system or runtime
18//! contract on the child side — a child is just "code that connects
19//! to a TCP port and exchanges barrier labels". Tests pair this with
20//! whatever node bootstrapping their assertion needs.
21
22use std::collections::HashMap;
23use std::io;
24use std::net::SocketAddr;
25use std::sync::Arc;
26use std::time::Duration;
27
28use parking_lot::Mutex;
29use thiserror::Error;
30use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
31use tokio::net::{TcpListener, TcpStream};
32use tokio::sync::oneshot;
33use tokio::task::JoinHandle;
34
35#[derive(Debug, Error)]
36#[non_exhaustive]
37pub enum MultiNodeOopError {
38    #[error("io: {0}")]
39    Io(#[from] io::Error),
40    #[error("controller already shut down")]
41    ControllerDown,
42    #[error("barrier `{label}` timed out at controller (got {got}/{expected})")]
43    BarrierTimeout { label: String, got: usize, expected: usize },
44    #[error("malformed line from peer: {0}")]
45    Malformed(String),
46    #[error("unexpected reply: {0}")]
47    UnexpectedReply(String),
48}
49
50/// Per-label rendezvous state on the controller.
51struct LabelState {
52    expected: usize,
53    /// Senders waiting to be notified; we send a single byte ('O' or
54    /// 'T') so the child handler can emit the appropriate response.
55    waiters: Vec<oneshot::Sender<bool>>,
56    arrived: usize,
57    completed: bool,
58}
59
60/// Out-of-process barrier controller. Bind it on the test driver,
61/// then point the children at `local_addr()` (e.g. via env var).
62pub struct MultiNodeOopController {
63    addr: SocketAddr,
64    inner: Arc<ControllerInner>,
65    handle: JoinHandle<()>,
66}
67
68struct ControllerInner {
69    expected: usize,
70    labels: Mutex<HashMap<String, LabelState>>,
71}
72
73impl MultiNodeOopController {
74    /// Start the controller with the expected node count. Listens on
75    /// `127.0.0.1:0` (kernel-assigned port). The accepted address can
76    /// be read via `local_addr()`.
77    pub async fn start(expected_nodes: usize) -> Result<Self, MultiNodeOopError> {
78        assert!(expected_nodes >= 1, "expected_nodes must be ≥ 1");
79        let listener = TcpListener::bind("127.0.0.1:0").await?;
80        let addr = listener.local_addr()?;
81        let inner =
82            Arc::new(ControllerInner { expected: expected_nodes, labels: Mutex::new(HashMap::new()) });
83        let inner_a = inner.clone();
84        let handle = tokio::spawn(async move {
85            loop {
86                match listener.accept().await {
87                    Ok((s, _)) => {
88                        let i = inner_a.clone();
89                        tokio::spawn(async move {
90                            handle_child(s, i).await;
91                        });
92                    }
93                    Err(_) => return,
94                }
95            }
96        });
97        Ok(Self { addr, inner, handle })
98    }
99
100    pub fn local_addr(&self) -> SocketAddr {
101        self.addr
102    }
103
104    /// Time-bound a label: if the requested label has not been
105    /// reached by `timeout`, every connected child waiter receives
106    /// `TIMEOUT label`. Returns the count of arrivals when the timer
107    /// fired.
108    pub async fn timeout_barrier(&self, label: &str, timeout: Duration) -> Result<usize, MultiNodeOopError> {
109        tokio::time::sleep(timeout).await;
110        let mut g = self.inner.labels.lock();
111        let state = g.entry(label.to_string()).or_insert_with(|| LabelState {
112            expected: self.inner.expected,
113            waiters: Vec::new(),
114            arrived: 0,
115            completed: false,
116        });
117        if state.completed {
118            return Ok(state.arrived);
119        }
120        // Trigger every waiter with `false` (timeout).
121        let arrived = state.arrived;
122        for w in state.waiters.drain(..) {
123            let _ = w.send(false);
124        }
125        state.completed = true;
126        if arrived < state.expected {
127            return Err(MultiNodeOopError::BarrierTimeout {
128                label: label.into(),
129                got: arrived,
130                expected: state.expected,
131            });
132        }
133        Ok(arrived)
134    }
135
136    /// Stop accepting new connections and drop the listener task.
137    /// Pending child connections continue but new BARRIERs against a
138    /// shut-down controller will fail.
139    pub fn shutdown(self) {
140        self.handle.abort();
141    }
142}
143
144async fn handle_child(stream: TcpStream, inner: Arc<ControllerInner>) {
145    let (r, mut w) = stream.into_split();
146    let mut lines = BufReader::new(r).lines();
147    while let Ok(Some(line)) = lines.next_line().await {
148        let trimmed = line.trim().to_string();
149        if let Some(label) = trimmed.strip_prefix("BARRIER ") {
150            let rx = enroll(&inner, label);
151            // Wait for the rendezvous resolution.
152            let outcome = match rx.await {
153                Ok(true) => format!("OK {label}\n"),
154                Ok(false) => format!("TIMEOUT {label}\n"),
155                Err(_) => format!("TIMEOUT {label}\n"),
156            };
157            if w.write_all(outcome.as_bytes()).await.is_err() {
158                return;
159            }
160        }
161    }
162}
163
164fn enroll(inner: &Arc<ControllerInner>, label: &str) -> oneshot::Receiver<bool> {
165    let (tx, rx) = oneshot::channel();
166    let mut g = inner.labels.lock();
167    let state = g.entry(label.to_string()).or_insert_with(|| LabelState {
168        expected: inner.expected,
169        waiters: Vec::new(),
170        arrived: 0,
171        completed: false,
172    });
173    if state.completed {
174        // Already-completed label: responder will see we've moved on.
175        let _ = tx.send(true);
176        return rx;
177    }
178    state.arrived += 1;
179    state.waiters.push(tx);
180    if state.arrived >= state.expected {
181        // Trigger every waiter with `true` (success).
182        for w in state.waiters.drain(..) {
183            let _ = w.send(true);
184        }
185        state.completed = true;
186    }
187    rx
188}
189
190/// Child-side handle. Construct one per node by passing the
191/// controller's `local_addr()`.
192pub struct MultiNodeOopNode {
193    stream: tokio::sync::Mutex<TcpStream>,
194}
195
196impl MultiNodeOopNode {
197    pub async fn connect(controller: SocketAddr) -> Result<Self, MultiNodeOopError> {
198        let s = TcpStream::connect(controller).await?;
199        s.set_nodelay(true)?;
200        Ok(Self { stream: tokio::sync::Mutex::new(s) })
201    }
202
203    /// Block until every node has arrived on `label`, or until the
204    /// controller's timer fires first. Returns Ok on success and
205    /// `BarrierTimeout` on failure.
206    pub async fn barrier(&self, label: &str) -> Result<(), MultiNodeOopError> {
207        let mut g = self.stream.lock().await;
208        g.write_all(format!("BARRIER {label}\n").as_bytes()).await?;
209        let mut buf = String::new();
210        let mut reader = BufReader::new(&mut *g);
211        reader.read_line(&mut buf).await?;
212        let trimmed = buf.trim();
213        if let Some(rest) = trimmed.strip_prefix("OK ") {
214            if rest == label {
215                return Ok(());
216            }
217        }
218        if let Some(rest) = trimmed.strip_prefix("TIMEOUT ") {
219            return Err(MultiNodeOopError::BarrierTimeout { label: rest.to_string(), got: 0, expected: 0 });
220        }
221        Err(MultiNodeOopError::UnexpectedReply(trimmed.to_string()))
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn three_nodes_meet_at_barrier() {
        let ctrl = MultiNodeOopController::start(3).await.unwrap();
        let addr = ctrl.local_addr();

        let handles: Vec<_> = (0..3)
            .map(|_| {
                tokio::spawn(async move {
                    let n = MultiNodeOopNode::connect(addr).await.unwrap();
                    n.barrier("converged").await.unwrap();
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn barrier_times_out_when_only_some_arrive() {
        let ctrl = MultiNodeOopController::start(3).await.unwrap();
        let addr = ctrl.local_addr();
        let label = "incomplete";

        // Only two of three nodes arrive; each must be told the barrier timed out.
        let spawn_node = move || {
            tokio::spawn(async move {
                let n = MultiNodeOopNode::connect(addr).await.unwrap();
                n.barrier(label).await
            })
        };
        let h1 = spawn_node();
        let h2 = spawn_node();

        // Drive the controller's timer; the margin lets both nodes enroll first.
        let to = ctrl.timeout_barrier(label, Duration::from_millis(200)).await;
        // We expect a timeout error because only 2/3 arrived.
        assert!(matches!(to, Err(MultiNodeOopError::BarrierTimeout { .. })));

        for h in [h1, h2] {
            let res = h.await.unwrap();
            assert!(matches!(res, Err(MultiNodeOopError::BarrierTimeout { .. })), "got {res:?}");
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn multiple_independent_labels() {
        let ctrl = MultiNodeOopController::start(2).await.unwrap();
        let addr = ctrl.local_addr();

        let h1 = tokio::spawn(async move {
            let n = MultiNodeOopNode::connect(addr).await.unwrap();
            n.barrier("phase-a").await.unwrap();
            n.barrier("phase-b").await.unwrap();
        });
        let h2 = tokio::spawn(async move {
            let n = MultiNodeOopNode::connect(addr).await.unwrap();
            n.barrier("phase-a").await.unwrap();
            n.barrier("phase-b").await.unwrap();
        });
        h1.await.unwrap();
        h2.await.unwrap();
    }

    #[tokio::test]
    async fn single_node_barrier_completes_immediately() {
        let ctrl = MultiNodeOopController::start(1).await.unwrap();
        let n = MultiNodeOopNode::connect(ctrl.local_addr()).await.unwrap();
        n.barrier("solo").await.unwrap();
        // Re-entering a label that already completed successfully is answered at once.
        n.barrier("solo").await.unwrap();
    }

    #[tokio::test]
    async fn controller_addr_is_loopback() {
        let ctrl = MultiNodeOopController::start(1).await.unwrap();
        let addr = ctrl.local_addr();
        assert!(addr.ip().is_loopback());
        assert_ne!(addr.port(), 0);
    }
}