Skip to main content

atomr_testkit/
multinode.rs

//! `MultiNodeSpec` — shared-barrier harness for multi-node tests.
//!
//! Phase 4 of `docs/full-port-plan.md`. Akka.NET's
//! `Akka.Remote.TestKit.MultiNodeSpec` spawns N OS processes and
//! synchronizes them via a controller; we instead spawn N
//! `ActorSystem`s in the same Tokio runtime (each under a distinct
//! system name, `<spec>-N`) and synchronize them via in-process barriers.
//! That covers the cluster / sharding / persistence integration
//! suites without needing a separate test runner.
//!
//! For genuine OS-process isolation (TCP loopback, real sockets),
//! `MultiNodeSpec` exposes `node_address(i)` so callers can feed
//! that into a `RemoteSystem` builder. The Phase 5 remote-depth pass
//! adds a real cross-process variant on top.
//!
//! Typical pattern — every node must enter a barrier with the same
//! label, so each node runs in its own task:
//!
//! ```no_run
//! # use std::sync::Arc;
//! # use std::time::Duration;
//! # use atomr_testkit::MultiNodeSpec;
//! # async fn run() {
//! let spec = Arc::new(MultiNodeSpec::new("ClusterTest", 3));
//! let nodes = spec.boot().await.unwrap();
//! let waits: Vec<_> = (0..spec.node_count())
//!     .map(|_| {
//!         let spec = Arc::clone(&spec);
//!         tokio::spawn(async move {
//!             // ...do this node's work...
//!             spec.barrier("converged", Duration::from_secs(2)).await
//!         })
//!     })
//!     .collect();
//! for w in waits {
//!     w.await.unwrap().unwrap();
//! }
//! spec.shutdown(nodes).await;
//! # }
//! ```

30use std::collections::HashMap;
31use std::sync::{Arc, Mutex};
32use std::time::Duration;
33
34use atomr_config::Config;
35use atomr_core::actor::{ActorSystem, ActorSystemError};
36use thiserror::Error;
37use tokio::sync::Barrier;
38
39#[derive(Debug, Error)]
40#[non_exhaustive]
41pub enum MultiNodeError {
42    #[error("failed to boot node `{name}`: {source}")]
43    Boot {
44        name: String,
45        #[source]
46        source: ActorSystemError,
47    },
48    #[error("barrier `{name}` timed out (got {got}/{expected})")]
49    BarrierTimeout { name: String, got: usize, expected: usize },
50}
51
52/// Multi-node test specification.
53pub struct MultiNodeSpec {
54    name: String,
55    node_count: usize,
56    barriers: Arc<Mutex<HashMap<String, Arc<Barrier>>>>,
57    arrivals: Arc<Mutex<HashMap<String, usize>>>,
58}
59
60impl MultiNodeSpec {
61    pub fn new(name: impl Into<String>, node_count: usize) -> Self {
62        assert!(node_count >= 1, "node_count must be ≥ 1");
63        Self {
64            name: name.into(),
65            node_count,
66            barriers: Arc::new(Mutex::new(HashMap::new())),
67            arrivals: Arc::new(Mutex::new(HashMap::new())),
68        }
69    }
70
71    pub fn name(&self) -> &str {
72        &self.name
73    }
74
75    pub fn node_count(&self) -> usize {
76        self.node_count
77    }
78
79    /// Synthesize a per-node identity. Real cross-process tests can
80    /// derive a TCP address from this string.
81    pub fn node_address(&self, i: usize) -> String {
82        format!("{}@node-{}", self.name, i)
83    }
84
85    /// Boot `node_count` distinct in-process `ActorSystem`s. Each
86    /// gets a name `"<spec>-N"`. The reference config is used
87    /// because per-node config knobs come into play in Phase 6.
88    pub async fn boot(&self) -> Result<Vec<ActorSystem>, MultiNodeError> {
89        let mut out = Vec::with_capacity(self.node_count);
90        for i in 0..self.node_count {
91            let name = format!("{}-{}", self.name, i);
92            let sys = ActorSystem::create(&name, Config::reference())
93                .await
94                .map_err(|e| MultiNodeError::Boot { name, source: e })?;
95            out.push(sys);
96        }
97        Ok(out)
98    }
99
100    /// Each node calls `barrier(label, timeout)` with the same label;
101    /// the future resolves once all `node_count` callers have arrived
102    /// or `timeout` elapses (whichever is first).
103    ///
104    /// Backed by [`tokio::sync::Barrier`] per label; this avoids the
105    /// `Notify::notify_waiters` race where late waiters miss an
106    /// already-fired notification.
107    pub async fn barrier(&self, label: &str, timeout: Duration) -> Result<(), MultiNodeError> {
108        let bar = {
109            let mut g = self.barriers.lock().unwrap();
110            g.entry(label.to_string()).or_insert_with(|| Arc::new(Barrier::new(self.node_count))).clone()
111        };
112        {
113            let mut a = self.arrivals.lock().unwrap();
114            *a.entry(label.to_string()).or_insert(0) += 1;
115        }
116        match tokio::time::timeout(timeout, bar.wait()).await {
117            Ok(_) => Ok(()),
118            Err(_) => {
119                let arrivals = *self.arrivals.lock().unwrap().get(label).unwrap_or(&0);
120                Err(MultiNodeError::BarrierTimeout {
121                    name: label.to_string(),
122                    got: arrivals,
123                    expected: self.node_count,
124                })
125            }
126        }
127    }
128
129    /// Convenience: terminate every node booted by [`Self::boot`].
130    pub async fn shutdown(&self, nodes: Vec<ActorSystem>) {
131        for sys in nodes {
132            sys.terminate().await;
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn boot_three_nodes_and_barrier() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTest", 3));
        let nodes = spec.boot().await.unwrap();
        assert_eq!(nodes.len(), 3);

        let mut handles = Vec::new();
        for _ in 0..3 {
            let s = spec.clone();
            handles.push(tokio::spawn(async move {
                s.barrier("step1", Duration::from_secs(2)).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        spec.shutdown(nodes).await;
    }

    #[tokio::test]
    async fn barrier_label_can_be_reused_for_consecutive_rounds() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierReuseTest", 3));
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let s = Arc::clone(&spec);
                tokio::spawn(async move {
                    for _ in 0..3 {
                        s.barrier("step", Duration::from_secs(2)).await.unwrap();
                    }
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }
    }

    #[tokio::test]
    async fn barrier_times_out_when_only_some_arrive() {
        let spec = Arc::new(MultiNodeSpec::new("BarrierTimeoutTest", 3));
        let nodes = spec.boot().await.unwrap();
        // Only 2 of 3 arrive — the barrier must time out for both of them.
        let s2 = spec.clone();
        let h = tokio::spawn(async move { s2.barrier("only-two", Duration::from_millis(50)).await });
        let local = spec.barrier("only-two", Duration::from_millis(50)).await;
        let remote = h.await.unwrap();
        assert!(
            matches!(local, Err(MultiNodeError::BarrierTimeout { got: 2, expected: 3, .. })),
            "unexpected result: {local:?}"
        );
        assert!(
            matches!(remote, Err(MultiNodeError::BarrierTimeout { got: 2, expected: 3, .. })),
            "unexpected result: {remote:?}"
        );
        spec.shutdown(nodes).await;
    }

    #[test]
    fn node_addresses_are_distinct() {
        let s = MultiNodeSpec::new("X", 4);
        let addrs: Vec<String> = (0..4).map(|i| s.node_address(i)).collect();
        let unique: std::collections::HashSet<_> = addrs.iter().cloned().collect();
        assert_eq!(unique.len(), 4);
    }
}