datacake_node/
node.rs

1use std::borrow::Cow;
2use std::collections::BTreeMap;
3use std::fmt::Debug;
4use std::net::SocketAddr;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8
9use chitchat::transport::Transport;
10use chitchat::{
11    spawn_chitchat,
12    ChitchatConfig,
13    ChitchatHandle,
14    ClusterStateSnapshot,
15    FailureDetectorConfig,
16    NodeId,
17};
18use tokio::sync::watch;
19use tokio_stream::wrappers::WatchStream;
20use tokio_stream::StreamExt;
21use tracing::{debug, error, info};
22
23use crate::error::NodeError;
24use crate::statistics::ClusterStatistics;
25use crate::DEFAULT_DATA_CENTER;
26
/// Key under which each node stores its data center in its chitchat key-value state.
static DATA_CENTER_KEY: &str = "data_center";

/// Gossip interval handed to chitchat; test builds use a shorter interval so
/// clusters converge faster.
const GOSSIP_INTERVAL: Duration = Duration::from_millis(if cfg!(test) { 500 } else { 1_000 });
33pub type NodeMembership = BTreeMap<crate::NodeId, ClusterMember>;
34
35#[derive(Clone, Debug, Eq, PartialEq)]
36pub struct ClusterMember {
37    /// A unique ID for the given node in the cluster.
38    pub node_id: crate::NodeId,
39    /// The public address of the nod.
40    pub public_addr: SocketAddr,
41    /// The data center / availability zone the node is in.
42    ///
43    /// This is used to select nodes for sending consistency tasks to.
44    pub data_center: String,
45}
46
47impl ClusterMember {
48    pub fn new(
49        node_id: crate::NodeId,
50        public_addr: SocketAddr,
51        data_center: String,
52    ) -> Self {
53        Self {
54            node_id,
55            public_addr,
56            data_center,
57        }
58    }
59
60    pub fn chitchat_id(&self) -> String {
61        self.node_id.to_string()
62    }
63}
64
65impl From<ClusterMember> for NodeId {
66    fn from(member: ClusterMember) -> Self {
67        Self::new(member.chitchat_id(), member.public_addr)
68    }
69}
70
71pub struct ChitchatNode {
72    pub me: Cow<'static, ClusterMember>,
73    statistics: ClusterStatistics,
74    chitchat_handle: ChitchatHandle,
75    members: watch::Receiver<NodeMembership>,
76    stop: Arc<AtomicBool>,
77}
78
79impl ChitchatNode {
80    pub async fn connect(
81        me: ClusterMember,
82        listen_addr: SocketAddr,
83        cluster_id: String,
84        seed_nodes: Vec<String>,
85        failure_detector_config: FailureDetectorConfig,
86        transport: &dyn Transport,
87        statistics: ClusterStatistics,
88    ) -> Result<Self, NodeError> {
89        info!(
90            cluster_id = %cluster_id,
91            node_id = %me.node_id,
92            public_addr = %me.public_addr,
93            listen_gossip_addr = %listen_addr,
94            peer_seed_addrs = %seed_nodes.join(", "),
95            "Joining cluster."
96        );
97
98        statistics.num_live_members.store(1, Ordering::Relaxed);
99        statistics.num_data_centers.store(1, Ordering::Relaxed);
100
101        let cfg = ChitchatConfig {
102            node_id: NodeId::from(me.clone()),
103            cluster_id: cluster_id.clone(),
104            gossip_interval: GOSSIP_INTERVAL,
105            listen_addr,
106            seed_nodes,
107            failure_detector_config,
108            is_ready_predicate: None,
109        };
110
111        let chitchat_handle = spawn_chitchat(
112            cfg,
113            vec![(DATA_CENTER_KEY.to_string(), me.data_center.clone())],
114            transport,
115        )
116        .await
117        .map_err(|e| NodeError::ChitChat(e.to_string()))?;
118
119        let chitchat = chitchat_handle.chitchat();
120        let (members_tx, members_rx) = watch::channel(BTreeMap::new());
121
122        let cluster = ChitchatNode {
123            me: Cow::Owned(me.clone()),
124            chitchat_handle,
125            statistics: statistics.clone(),
126            members: members_rx,
127            stop: Arc::new(Default::default()),
128        };
129
130        let initial_members: BTreeMap<crate::NodeId, ClusterMember> =
131            BTreeMap::from_iter([(me.node_id, me.clone())]);
132        if members_tx.send(initial_members).is_err() {
133            error!("Failed to add itself as the initial member of the cluster.");
134        }
135
136        let stop_flag = cluster.stop.clone();
137        tokio::spawn(async move {
138            let mut node_change_rx = chitchat.lock().await.ready_nodes_watcher();
139
140            while let Some(members_set) = node_change_rx.next().await {
141                let state_snapshot = {
142                    let lock = chitchat.lock().await;
143                    let dead_member_count = lock.dead_nodes().count();
144
145                    statistics
146                        .num_dead_members
147                        .store(dead_member_count as u64, Ordering::Relaxed);
148                    lock.state_snapshot()
149                };
150
151                let mut members = members_set
152                    .into_iter()
153                    .map(|node_id| build_cluster_member(&node_id, &state_snapshot))
154                    .filter_map(|member_res| {
155                        match member_res {
156                            Ok(member) => {
157                                Some((member.node_id, member))
158                            },
159                            Err(error) => {
160                                error!(
161                                    error = ?error,
162                                    "Failed to build cluster member from cluster state, ignoring member.",
163                                );
164                                None
165                            },
166                        }
167                    })
168                    .collect::<BTreeMap<_, _>>();
169                members.insert(me.node_id, me.clone());
170
171                statistics
172                    .num_live_members
173                    .store(members.len() as u64, Ordering::Relaxed);
174
175                if stop_flag.load(Ordering::Relaxed) {
176                    debug!("Received a stop signal. Stopping.");
177                    break;
178                }
179
180                if members_tx.send(members).is_err() {
181                    // Somehow the cluster has been dropped.
182                    error!("Failed to update members list. Stopping.");
183                    break;
184                }
185            }
186
187            Result::<(), NodeError>::Ok(())
188        });
189
190        Ok(cluster)
191    }
192
193    /// Return [WatchStream] for monitoring change of node members.
194    pub fn member_change_watcher(&self) -> WatchStream<NodeMembership> {
195        WatchStream::new(self.members.clone())
196    }
197
198    /// Returns a handle to the members watcher channel.
199    pub fn members_watcher(&self) -> watch::Receiver<NodeMembership> {
200        self.members.clone()
201    }
202
203    #[cfg(test)]
204    /// Returns a list of node members.
205    pub fn members(&self) -> NodeMembership {
206        self.members.borrow().clone()
207    }
208
209    #[inline]
210    /// Get a handle to the live statistics.
211    pub fn statistics(&self) -> ClusterStatistics {
212        self.statistics.clone()
213    }
214
215    /// Leave the cluster.
216    pub async fn shutdown(self) {
217        info!(self_addr = ?self.me.public_addr, "Shutting down the cluster.");
218        let result = self.chitchat_handle.shutdown().await;
219        if let Err(error) = result {
220            error!(self_addr = ?self.me.public_addr, error = ?error, "Error while shutting down.");
221        }
222
223        self.stop.store(true, Ordering::Relaxed);
224    }
225
226    /// Convenience method for testing that waits for the predicate to hold true for the cluster's
227    /// members.
228    pub async fn wait_for_members<F>(
229        self: &ChitchatNode,
230        mut predicate: F,
231        timeout_after: Duration,
232    ) -> Result<(), anyhow::Error>
233    where
234        F: FnMut(&NodeMembership) -> bool,
235    {
236        use tokio::time::timeout;
237
238        timeout(
239            timeout_after,
240            self.member_change_watcher()
241                .skip_while(|members| !predicate(members))
242                .next(),
243        )
244        .await?;
245        Ok(())
246    }
247}
248
249fn build_cluster_member<'a>(
250    chitchat_id: &'a NodeId,
251    state: &'a ClusterStateSnapshot,
252) -> Result<ClusterMember, String> {
253    let node_state = state.node_states.get(&chitchat_id.id).ok_or_else(|| {
254        format!(
255            "Could not find node ID `{}` in ChitChat state.",
256            chitchat_id.id
257        )
258    })?;
259
260    let data_center = node_state
261        .get(DATA_CENTER_KEY)
262        .unwrap_or(DEFAULT_DATA_CENTER);
263
264    let node_id = chitchat_id
265        .id
266        .parse::<crate::NodeId>()
267        .map_err(|e| format!("Invalid node ID: {e}"))?;
268
269    Ok(ClusterMember::new(
270        node_id,
271        chitchat_id.gossip_public_address,
272        data_center.to_owned(),
273    ))
274}
275
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU8;

    use anyhow::Result;
    use chitchat::transport::{ChannelTransport, Transport};

    use super::*;

    #[tokio::test]
    async fn test_cluster_single_node() -> Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let transport = ChannelTransport::default();
        let cluster = create_node_for_test(Vec::new(), &transport).await?;

        let members: Vec<SocketAddr> = cluster
            .members()
            .values()
            .map(|member| member.public_addr)
            .collect();
        let expected_members = vec![cluster.me.public_addr];
        assert_eq!(members, expected_members);
        cluster.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_cluster_propagated_state() -> Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let transport = ChannelTransport::default();
        let node1 = create_node_for_test(Vec::new(), &transport).await?;
        let node_1_gossip_addr = node1.me.public_addr.to_string();
        let node2 =
            create_node_for_test(vec![node_1_gossip_addr.clone()], &transport).await?;
        let node3 = create_node_for_test(vec![node_1_gossip_addr], &transport).await?;

        let wait_secs = Duration::from_secs(30);
        for cluster in [&node1, &node2, &node3] {
            cluster
                .wait_for_members(|members| members.len() == 3, wait_secs)
                .await?;
        }

        // Node 1 must see exactly the three nodes, with the IDs and public
        // addresses they joined with.
        let mut expected: Vec<(crate::NodeId, SocketAddr)> = [&node1, &node2, &node3]
            .iter()
            .map(|node| (node.me.node_id, node.me.public_addr))
            .collect();
        expected.sort();
        let observed: Vec<(crate::NodeId, SocketAddr)> = node1
            .members()
            .values()
            .map(|member| (member.node_id, member.public_addr))
            .collect();
        assert_eq!(observed, expected);

        for node in [node1, node2, node3] {
            node.shutdown().await;
        }

        Ok(())
    }

    /// Failure-detector settings whose initial interval matches the test gossip
    /// interval.
    fn create_failure_detector_config_for_test() -> FailureDetectorConfig {
        FailureDetectorConfig {
            phi_threshold: 6.0,
            initial_interval: GOSSIP_INTERVAL,
            ..Default::default()
        }
    }

    /// Starts a node with the given ID on `transport`.
    ///
    /// The node both listens on and advertises `127.0.0.1:<node_id>`, and its
    /// data center is `"unknown"`.
    pub async fn create_node_for_test_with_id(
        node_id: crate::NodeId,
        cluster_id: String,
        seeds: Vec<String>,
        transport: &dyn Transport,
    ) -> Result<ChitchatNode> {
        let public_addr: SocketAddr = ([127, 0, 0, 1], node_id as u16).into();
        let failure_detector_config = create_failure_detector_config_for_test();
        let node = ChitchatNode::connect(
            ClusterMember::new(node_id, public_addr, "unknown".to_string()),
            public_addr,
            cluster_id,
            seeds,
            failure_detector_config,
            transport,
            ClusterStatistics::default(),
        )
        .await?;
        Ok(node)
    }

    /// Starts a node in `test-cluster` with an ID taken from a process-wide
    /// counter, so nodes created by different tests get distinct IDs and ports.
    pub async fn create_node_for_test(
        seeds: Vec<String>,
        transport: &dyn Transport,
    ) -> Result<ChitchatNode> {
        static NODE_AUTO_INCREMENT: AtomicU8 = AtomicU8::new(1);
        let node_id = NODE_AUTO_INCREMENT.fetch_add(1, Ordering::Relaxed);
        let node = create_node_for_test_with_id(
            node_id,
            "test-cluster".to_string(),
            seeds,
            transport,
        )
        .await?;
        Ok(node)
    }
}