1use std::borrow::Cow;
2use std::collections::BTreeMap;
3use std::fmt::Debug;
4use std::net::SocketAddr;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8
9use chitchat::transport::Transport;
10use chitchat::{
11 spawn_chitchat,
12 ChitchatConfig,
13 ChitchatHandle,
14 ClusterStateSnapshot,
15 FailureDetectorConfig,
16 NodeId,
17};
18use tokio::sync::watch;
19use tokio_stream::wrappers::WatchStream;
20use tokio_stream::StreamExt;
21use tracing::{debug, error, info};
22
23use crate::error::NodeError;
24use crate::statistics::ClusterStatistics;
25use crate::DEFAULT_DATA_CENTER;
26
/// Key of the chitchat node-state entry in which every node publishes its data center.
const DATA_CENTER_KEY: &str = "data_center";

/// Delay between two gossip rounds: 500 ms when compiled with `cfg(test)`, one second otherwise.
const GOSSIP_INTERVAL: Duration = Duration::from_millis(if cfg!(test) { 500 } else { 1_000 });
33pub type NodeMembership = BTreeMap<crate::NodeId, ClusterMember>;
34
35#[derive(Clone, Debug, Eq, PartialEq)]
36pub struct ClusterMember {
37 pub node_id: crate::NodeId,
39 pub public_addr: SocketAddr,
41 pub data_center: String,
45}
46
47impl ClusterMember {
48 pub fn new(
49 node_id: crate::NodeId,
50 public_addr: SocketAddr,
51 data_center: String,
52 ) -> Self {
53 Self {
54 node_id,
55 public_addr,
56 data_center,
57 }
58 }
59
60 pub fn chitchat_id(&self) -> String {
61 self.node_id.to_string()
62 }
63}
64
65impl From<ClusterMember> for NodeId {
66 fn from(member: ClusterMember) -> Self {
67 Self::new(member.chitchat_id(), member.public_addr)
68 }
69}
70
/// Handle on the local node's participation in a chitchat cluster, created by
/// `ChitchatNode::connect`.
// NOTE: fields are dropped in declaration order; keep the order unless that is intended.
pub struct ChitchatNode {
    /// Description of the local node.
    pub me: Cow<'static, ClusterMember>,
    /// Counters (live/dead members, data centers) updated by the membership task.
    statistics: ClusterStatistics,
    /// Handle on the running chitchat instance; consumed by `shutdown`.
    chitchat_handle: ChitchatHandle,
    /// Latest membership published by the background membership task.
    members: watch::Receiver<NodeMembership>,
    /// Raised by `shutdown` to tell the membership task to stop publishing.
    stop: Arc<AtomicBool>,
}
78
79impl ChitchatNode {
80 pub async fn connect(
81 me: ClusterMember,
82 listen_addr: SocketAddr,
83 cluster_id: String,
84 seed_nodes: Vec<String>,
85 failure_detector_config: FailureDetectorConfig,
86 transport: &dyn Transport,
87 statistics: ClusterStatistics,
88 ) -> Result<Self, NodeError> {
89 info!(
90 cluster_id = %cluster_id,
91 node_id = %me.node_id,
92 public_addr = %me.public_addr,
93 listen_gossip_addr = %listen_addr,
94 peer_seed_addrs = %seed_nodes.join(", "),
95 "Joining cluster."
96 );
97
98 statistics.num_live_members.store(1, Ordering::Relaxed);
99 statistics.num_data_centers.store(1, Ordering::Relaxed);
100
101 let cfg = ChitchatConfig {
102 node_id: NodeId::from(me.clone()),
103 cluster_id: cluster_id.clone(),
104 gossip_interval: GOSSIP_INTERVAL,
105 listen_addr,
106 seed_nodes,
107 failure_detector_config,
108 is_ready_predicate: None,
109 };
110
111 let chitchat_handle = spawn_chitchat(
112 cfg,
113 vec![(DATA_CENTER_KEY.to_string(), me.data_center.clone())],
114 transport,
115 )
116 .await
117 .map_err(|e| NodeError::ChitChat(e.to_string()))?;
118
119 let chitchat = chitchat_handle.chitchat();
120 let (members_tx, members_rx) = watch::channel(BTreeMap::new());
121
122 let cluster = ChitchatNode {
123 me: Cow::Owned(me.clone()),
124 chitchat_handle,
125 statistics: statistics.clone(),
126 members: members_rx,
127 stop: Arc::new(Default::default()),
128 };
129
130 let initial_members: BTreeMap<crate::NodeId, ClusterMember> =
131 BTreeMap::from_iter([(me.node_id, me.clone())]);
132 if members_tx.send(initial_members).is_err() {
133 error!("Failed to add itself as the initial member of the cluster.");
134 }
135
136 let stop_flag = cluster.stop.clone();
137 tokio::spawn(async move {
138 let mut node_change_rx = chitchat.lock().await.ready_nodes_watcher();
139
140 while let Some(members_set) = node_change_rx.next().await {
141 let state_snapshot = {
142 let lock = chitchat.lock().await;
143 let dead_member_count = lock.dead_nodes().count();
144
145 statistics
146 .num_dead_members
147 .store(dead_member_count as u64, Ordering::Relaxed);
148 lock.state_snapshot()
149 };
150
151 let mut members = members_set
152 .into_iter()
153 .map(|node_id| build_cluster_member(&node_id, &state_snapshot))
154 .filter_map(|member_res| {
155 match member_res {
156 Ok(member) => {
157 Some((member.node_id, member))
158 },
159 Err(error) => {
160 error!(
161 error = ?error,
162 "Failed to build cluster member from cluster state, ignoring member.",
163 );
164 None
165 },
166 }
167 })
168 .collect::<BTreeMap<_, _>>();
169 members.insert(me.node_id, me.clone());
170
171 statistics
172 .num_live_members
173 .store(members.len() as u64, Ordering::Relaxed);
174
175 if stop_flag.load(Ordering::Relaxed) {
176 debug!("Received a stop signal. Stopping.");
177 break;
178 }
179
180 if members_tx.send(members).is_err() {
181 error!("Failed to update members list. Stopping.");
183 break;
184 }
185 }
186
187 Result::<(), NodeError>::Ok(())
188 });
189
190 Ok(cluster)
191 }
192
193 pub fn member_change_watcher(&self) -> WatchStream<NodeMembership> {
195 WatchStream::new(self.members.clone())
196 }
197
198 pub fn members_watcher(&self) -> watch::Receiver<NodeMembership> {
200 self.members.clone()
201 }
202
203 #[cfg(test)]
204 pub fn members(&self) -> NodeMembership {
206 self.members.borrow().clone()
207 }
208
209 #[inline]
210 pub fn statistics(&self) -> ClusterStatistics {
212 self.statistics.clone()
213 }
214
215 pub async fn shutdown(self) {
217 info!(self_addr = ?self.me.public_addr, "Shutting down the cluster.");
218 let result = self.chitchat_handle.shutdown().await;
219 if let Err(error) = result {
220 error!(self_addr = ?self.me.public_addr, error = ?error, "Error while shutting down.");
221 }
222
223 self.stop.store(true, Ordering::Relaxed);
224 }
225
226 pub async fn wait_for_members<F>(
229 self: &ChitchatNode,
230 mut predicate: F,
231 timeout_after: Duration,
232 ) -> Result<(), anyhow::Error>
233 where
234 F: FnMut(&NodeMembership) -> bool,
235 {
236 use tokio::time::timeout;
237
238 timeout(
239 timeout_after,
240 self.member_change_watcher()
241 .skip_while(|members| !predicate(members))
242 .next(),
243 )
244 .await?;
245 Ok(())
246 }
247}
248
249fn build_cluster_member<'a>(
250 chitchat_id: &'a NodeId,
251 state: &'a ClusterStateSnapshot,
252) -> Result<ClusterMember, String> {
253 let node_state = state.node_states.get(&chitchat_id.id).ok_or_else(|| {
254 format!(
255 "Could not find node ID `{}` in ChitChat state.",
256 chitchat_id.id
257 )
258 })?;
259
260 let data_center = node_state
261 .get(DATA_CENTER_KEY)
262 .unwrap_or(DEFAULT_DATA_CENTER);
263
264 let node_id = chitchat_id
265 .id
266 .parse::<crate::NodeId>()
267 .map_err(|e| format!("Invalid node ID: {e}"))?;
268
269 Ok(ClusterMember::new(
270 node_id,
271 chitchat_id.gossip_public_address,
272 data_center.to_owned(),
273 ))
274}
275
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::sync::atomic::AtomicU8;

    use anyhow::Result;
    use chitchat::transport::{ChannelTransport, Transport};

    use super::*;

    /// Public addresses of every member of `members`, as an ordered set for comparisons.
    fn member_addrs(members: &NodeMembership) -> BTreeSet<SocketAddr> {
        members.values().map(|member| member.public_addr).collect()
    }

    #[tokio::test]
    async fn test_cluster_single_node() -> Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let transport = ChannelTransport::default();
        let cluster = create_node_for_test(Vec::new(), &transport).await?;

        let members: Vec<SocketAddr> = cluster
            .members()
            .values()
            .map(|member| member.public_addr)
            .collect();
        let expected_members = vec![cluster.me.public_addr];
        assert_eq!(members, expected_members);
        cluster.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_cluster_propagated_state() -> Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let transport = ChannelTransport::default();
        let node1 = create_node_for_test(Vec::new(), &transport).await?;
        let node_1_gossip_addr = node1.me.public_addr.to_string();
        let node2 =
            create_node_for_test(vec![node_1_gossip_addr.clone()], &transport).await?;
        let node3 = create_node_for_test(vec![node_1_gossip_addr], &transport).await?;

        let expected_addrs: BTreeSet<SocketAddr> = [&node1, &node2, &node3]
            .iter()
            .map(|node| node.me.public_addr)
            .collect();

        // Wait until every node sees exactly the three expected members, not merely any three.
        let wait_secs = Duration::from_secs(30);
        for node in [&node1, &node2, &node3] {
            node.wait_for_members(|members| member_addrs(members) == expected_addrs, wait_secs)
                .await?;
        }

        for node in [node1, node2, node3] {
            node.shutdown().await;
        }
        Ok(())
    }

    fn create_failure_detector_config_for_test() -> FailureDetectorConfig {
        FailureDetectorConfig {
            phi_threshold: 6.0,
            initial_interval: GOSSIP_INTERVAL,
            ..Default::default()
        }
    }

    /// Starts a test node whose public (and listen) address is `127.0.0.1:<node_id>`.
    pub async fn create_node_for_test_with_id(
        node_id: crate::NodeId,
        cluster_id: String,
        seeds: Vec<String>,
        transport: &dyn Transport,
    ) -> Result<ChitchatNode> {
        let public_addr: SocketAddr = ([127, 0, 0, 1], node_id as u16).into();
        let failure_detector_config = create_failure_detector_config_for_test();
        let node = ChitchatNode::connect(
            ClusterMember::new(node_id, public_addr, "unknown".to_string()),
            public_addr,
            cluster_id,
            seeds,
            failure_detector_config,
            transport,
            ClusterStatistics::default(),
        )
        .await?;
        Ok(node)
    }

    /// Starts a test node in `test-cluster` with a process-wide auto-incremented node ID.
    pub async fn create_node_for_test(
        seeds: Vec<String>,
        transport: &dyn Transport,
    ) -> Result<ChitchatNode> {
        static NODE_AUTO_INCREMENT: AtomicU8 = AtomicU8::new(1);
        let node_id = NODE_AUTO_INCREMENT.fetch_add(1, Ordering::Relaxed);
        let node = create_node_for_test_with_id(
            node_id,
            "test-cluster".to_string(),
            seeds,
            transport,
        )
        .await?;
        Ok(node)
    }
}