Skip to main content

atomr_testkit/
multinode.rs

1//! `MultiNodeSpec` — shared-barrier harness for multi-node tests.
2//!
//! Rather than spawning N OS processes coordinated by an external controller,
//! atomr boots N `ActorSystem`s inside the same Tokio runtime (each under a
//! distinct system name, `"<spec>-N"`) and synchronizes them through
//! in-process barriers. That covers the cluster / sharding / persistence
//! integration suites without needing a separate test runner.
//!
//! For genuine OS-process isolation (TCP loopback, real sockets),
//! `MultiNodeSpec` exposes `node_address(i)`, a per-node identity that
//! callers can feed into a `RemoteSystem` builder. The Phase 5 remote-depth
//! pass adds a real cross-process variant on top.
13//!
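//! Typical pattern. Each node's driver reaches the barrier from its own task,
//! because a barrier only releases once all `node_count` callers have arrived:
//!
//! ```no_run
//! # use std::sync::Arc;
//! # use std::time::Duration;
//! # use atomr_testkit::MultiNodeSpec;
//! # async fn run() {
//! let spec = Arc::new(MultiNodeSpec::new("ClusterTest", 3));
//! let nodes = spec.boot().await.unwrap();
//! let drivers: Vec<_> = (0..spec.node_count())
//!     .map(|_| {
//!         let spec = Arc::clone(&spec);
//!         tokio::spawn(async move {
//!             // ...do this node's work...
//!             spec.barrier("converged", Duration::from_secs(2)).await.unwrap();
//!         })
//!     })
//!     .collect();
//! for driver in drivers {
//!     driver.await.unwrap();
//! }
//! spec.shutdown(nodes).await;
//! # }
//! ```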
27
28use std::collections::HashMap;
29use std::sync::{Arc, Mutex};
30use std::time::Duration;
31
32use atomr_config::Config;
33use atomr_core::actor::{ActorSystem, ActorSystemError};
34use thiserror::Error;
35use tokio::sync::Barrier;
36
37#[derive(Debug, Error)]
38#[non_exhaustive]
39pub enum MultiNodeError {
40    #[error("failed to boot node `{name}`: {source}")]
41    Boot {
42        name: String,
43        #[source]
44        source: ActorSystemError,
45    },
46    #[error("barrier `{name}` timed out (got {got}/{expected})")]
47    BarrierTimeout { name: String, got: usize, expected: usize },
48}
49
50/// Multi-node test specification.
51pub struct MultiNodeSpec {
52    name: String,
53    node_count: usize,
54    barriers: Arc<Mutex<HashMap<String, Arc<Barrier>>>>,
55    arrivals: Arc<Mutex<HashMap<String, usize>>>,
56}
57
58impl MultiNodeSpec {
59    pub fn new(name: impl Into<String>, node_count: usize) -> Self {
60        assert!(node_count >= 1, "node_count must be ≥ 1");
61        Self {
62            name: name.into(),
63            node_count,
64            barriers: Arc::new(Mutex::new(HashMap::new())),
65            arrivals: Arc::new(Mutex::new(HashMap::new())),
66        }
67    }
68
69    pub fn name(&self) -> &str {
70        &self.name
71    }
72
73    pub fn node_count(&self) -> usize {
74        self.node_count
75    }
76
77    /// Synthesize a per-node identity. Real cross-process tests can
78    /// derive a TCP address from this string.
79    pub fn node_address(&self, i: usize) -> String {
80        format!("{}@node-{}", self.name, i)
81    }
82
83    /// Boot `node_count` distinct in-process `ActorSystem`s. Each
84    /// gets a name `"<spec>-N"`. The reference config is used
85    /// because per-node config knobs come into play in Phase 6.
86    pub async fn boot(&self) -> Result<Vec<ActorSystem>, MultiNodeError> {
87        let mut out = Vec::with_capacity(self.node_count);
88        for i in 0..self.node_count {
89            let name = format!("{}-{}", self.name, i);
90            let sys = ActorSystem::create(&name, Config::reference())
91                .await
92                .map_err(|e| MultiNodeError::Boot { name, source: e })?;
93            out.push(sys);
94        }
95        Ok(out)
96    }
97
98    /// Each node calls `barrier(label, timeout)` with the same label;
99    /// the future resolves once all `node_count` callers have arrived
100    /// or `timeout` elapses (whichever is first).
101    ///
102    /// Backed by [`tokio::sync::Barrier`] per label; this avoids the
103    /// `Notify::notify_waiters` race where late waiters miss an
104    /// already-fired notification.
105    pub async fn barrier(&self, label: &str, timeout: Duration) -> Result<(), MultiNodeError> {
106        let bar = {
107            let mut g = self.barriers.lock().unwrap();
108            g.entry(label.to_string()).or_insert_with(|| Arc::new(Barrier::new(self.node_count))).clone()
109        };
110        {
111            let mut a = self.arrivals.lock().unwrap();
112            *a.entry(label.to_string()).or_insert(0) += 1;
113        }
114        match tokio::time::timeout(timeout, bar.wait()).await {
115            Ok(_) => Ok(()),
116            Err(_) => {
117                let arrivals = *self.arrivals.lock().unwrap().get(label).unwrap_or(&0);
118                Err(MultiNodeError::BarrierTimeout {
119                    name: label.to_string(),
120                    got: arrivals,
121                    expected: self.node_count,
122                })
123            }
124        }
125    }
126
127    /// Convenience: terminate every node booted by [`Self::boot`].
128    pub async fn shutdown(&self, nodes: Vec<ActorSystem>) {
129        for sys in nodes {
130            sys.terminate().await;
131        }
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn boot_three_nodes_and_barrier() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTest", 3));
        let nodes = spec.boot().await.unwrap();
        assert_eq!(nodes.len(), 3);

        let mut handles = Vec::new();
        for _ in 0..3 {
            let s = spec.clone();
            handles.push(tokio::spawn(async move {
                s.barrier("step1", Duration::from_secs(2)).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        spec.shutdown(nodes).await;
    }

    #[tokio::test]
    async fn barrier_label_is_reusable() {
        // The same label must release correctly for consecutive rounds.
        let spec = Arc::new(MultiNodeSpec::new("BarrierReuseTest", 3));
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let s = Arc::clone(&spec);
                tokio::spawn(async move {
                    for _ in 0..2 {
                        s.barrier("round", Duration::from_secs(2)).await.unwrap();
                    }
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }
    }

    #[tokio::test]
    async fn barrier_times_out_when_only_some_arrive() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTimeoutTest", 3));
        let nodes = spec.boot().await.unwrap();

        // Only 2 of 3 arrive, so both callers must time out and report 2/3.
        let s2 = Arc::clone(&spec);
        let h = tokio::spawn(async move { s2.barrier("only-two", Duration::from_millis(50)).await });
        let mine = spec.barrier("only-two", Duration::from_millis(50)).await;
        let theirs = h.await.unwrap();
        for r in [mine, theirs] {
            assert!(
                matches!(r, Err(MultiNodeError::BarrierTimeout { got: 2, expected: 3, .. })),
                "unexpected barrier result: {r:?}"
            );
        }

        spec.shutdown(nodes).await;
    }

    #[test]
    fn node_addresses_are_distinct() {
        let s = MultiNodeSpec::new("X", 4);
        let addrs: Vec<String> = (0..4).map(|i| s.node_address(i)).collect();
        let unique: std::collections::HashSet<_> = addrs.iter().cloned().collect();
        assert_eq!(unique.len(), 4);
    }
}