1use std::net::SocketAddr;
20use std::sync::Arc;
21
22use nodedb_types::NodeId;
23use tokio::sync::watch;
24use tokio::task::JoinHandle;
25
26use super::config::SwimConfig;
27use super::detector::{FailureDetector, ProbeScheduler, Transport};
28use super::dissemination::DisseminationQueue;
29use super::error::SwimError;
30use super::incarnation::Incarnation;
31use super::member::MemberState;
32use super::member::record::MemberUpdate;
33use super::membership::MembershipList;
34use super::subscriber::MembershipSubscriber;
35
36pub struct SwimHandle {
41 detector: Arc<FailureDetector>,
42 membership: Arc<MembershipList>,
43 shutdown_tx: watch::Sender<bool>,
44 join: JoinHandle<()>,
45}
46
47impl SwimHandle {
48 pub fn detector(&self) -> &Arc<FailureDetector> {
51 &self.detector
52 }
53
54 pub fn membership(&self) -> &Arc<MembershipList> {
57 &self.membership
58 }
59
60 pub fn dissemination(&self) -> &Arc<DisseminationQueue> {
64 self.detector.dissemination()
65 }
66
67 pub async fn shutdown(self) {
70 let _ = self.shutdown_tx.send(true);
71 let _ = self.join.await;
72 }
73}
74
75pub async fn spawn(
89 cfg: SwimConfig,
90 local_id: NodeId,
91 local_addr: SocketAddr,
92 seeds: Vec<SocketAddr>,
93 transport: Arc<dyn Transport>,
94) -> Result<SwimHandle, SwimError> {
95 spawn_with_subscribers(cfg, local_id, local_addr, seeds, transport, Vec::new()).await
96}
97
98pub async fn spawn_with_subscribers(
102 cfg: SwimConfig,
103 local_id: NodeId,
104 local_addr: SocketAddr,
105 seeds: Vec<SocketAddr>,
106 transport: Arc<dyn Transport>,
107 subscribers: Vec<Arc<dyn MembershipSubscriber>>,
108) -> Result<SwimHandle, SwimError> {
109 cfg.validate()?;
110
111 let membership = Arc::new(MembershipList::new_local(
112 local_id.clone(),
113 local_addr,
114 cfg.initial_incarnation,
115 ));
116
117 for seed_addr in &seeds {
120 if *seed_addr == local_addr {
121 continue;
122 }
123 membership.apply(&MemberUpdate {
124 node_id: NodeId::from_validated(format!("seed:{seed_addr}")),
126 addr: seed_addr.to_string(),
127 state: MemberState::Alive,
128 incarnation: Incarnation::ZERO,
129 });
130 }
131
132 let initial_inc = cfg.initial_incarnation;
133 let detector = Arc::new(FailureDetector::with_subscribers(
134 cfg,
135 Arc::clone(&membership),
136 transport,
137 ProbeScheduler::new(),
138 subscribers,
139 ));
140
141 detector.dissemination().enqueue(MemberUpdate {
147 node_id: local_id.clone(),
148 addr: local_addr.to_string(),
149 state: MemberState::Alive,
150 incarnation: initial_inc,
151 });
152
153 let (shutdown_tx, shutdown_rx) = watch::channel(false);
154 let join = tokio::spawn({
155 let detector = Arc::clone(&detector);
156 async move { detector.run(shutdown_rx).await }
157 });
158
159 Ok(SwimHandle {
160 detector,
161 membership,
162 shutdown_tx,
163 join,
164 })
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::swim::detector::TransportFabric;
    use std::net::{IpAddr, Ipv4Addr};
    use std::time::Duration;

    /// Loopback address on port `p`.
    fn addr(p: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), p)
    }

    /// Short timings so the background task reacts quickly in tests.
    fn cfg() -> SwimConfig {
        SwimConfig {
            probe_interval: Duration::from_millis(100),
            probe_timeout: Duration::from_millis(40),
            indirect_probes: 2,
            suspicion_mult: 4,
            min_suspicion: Duration::from_millis(500),
            initial_incarnation: Incarnation::ZERO,
            max_piggyback: 6,
            fanout_lambda: 3,
        }
    }

    /// Spawn node `"a"` at loopback `port` over `transport`, using `cfg` and `seeds`.
    ///
    /// The caller keeps ownership of the `TransportFabric` so it outlives the
    /// spawned task for the whole test.
    async fn spawn_a(
        cfg: SwimConfig,
        port: u16,
        seeds: Vec<SocketAddr>,
        transport: Arc<dyn Transport>,
    ) -> Result<SwimHandle, SwimError> {
        spawn(
            cfg,
            NodeId::try_new("a").expect("test fixture"),
            addr(port),
            seeds,
            transport,
        )
        .await
    }

    #[tokio::test]
    async fn spawn_solo_cluster_has_only_local() {
        let fab = TransportFabric::new();
        let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7100)).await);
        let handle = spawn_a(cfg(), 7100, vec![], transport)
            .await
            .expect("spawn");
        assert_eq!(handle.membership().len(), 1);
        assert!(handle.membership().is_solo());
        handle.shutdown().await;
    }

    #[tokio::test]
    async fn spawn_seeds_populates_membership() {
        let fab = TransportFabric::new();
        let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7110)).await);
        let handle = spawn_a(cfg(), 7110, vec![addr(7111), addr(7112)], transport)
            .await
            .expect("spawn");
        // Local node plus both seeds.
        assert_eq!(handle.membership().len(), 3);
        handle.shutdown().await;
    }

    #[tokio::test]
    async fn spawn_skips_local_addr_in_seeds() {
        let fab = TransportFabric::new();
        let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7120)).await);
        let handle = spawn_a(cfg(), 7120, vec![addr(7120), addr(7121)], transport)
            .await
            .expect("spawn");
        // The seed equal to our own address is skipped: local node + 7121 only.
        assert_eq!(handle.membership().len(), 2);
        handle.shutdown().await;
    }

    #[tokio::test]
    async fn invalid_config_rejected_before_task_spawned() {
        let fab = TransportFabric::new();
        let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7130)).await);
        let mut bad = cfg();
        // A probe timeout equal to the probe interval is expected to fail
        // `SwimConfig::validate`.
        bad.probe_timeout = bad.probe_interval;
        let res = spawn_a(bad, 7130, vec![], transport).await;
        match res {
            Err(SwimError::InvalidConfig { .. }) => {}
            Err(other) => panic!("expected InvalidConfig, got {other:?}"),
            Ok(_) => panic!("expected InvalidConfig error"),
        }
    }

    #[tokio::test]
    async fn shutdown_joins_promptly() {
        let fab = TransportFabric::new();
        let transport: Arc<dyn Transport> = Arc::new(fab.bind(addr(7140)).await);
        let handle = spawn_a(cfg(), 7140, vec![], transport)
            .await
            .expect("spawn");
        let start = std::time::Instant::now();
        tokio::time::timeout(Duration::from_millis(500), handle.shutdown())
            .await
            .expect("shutdown did not join within budget");
        assert!(start.elapsed() < Duration::from_millis(500));
    }
}