// distributed_scheduler/node_pool.rs

use std::sync::{
    atomic::{AtomicBool, AtomicU8, Ordering},
    Arc,
};

use thiserror::Error;
use tokio::sync::RwLock;

use crate::driver::Driver;

11/// Node pool, used to manage nodes in a cluster.
12pub struct NodePool<D>
13where
14    D: Driver + Send + Sync,
15{
16    node_id: String,
17
18    pre_nodes: RwLock<Vec<String>>,
19    hash: RwLock<hashring::HashRing<String>>,
20    driver: Arc<D>,
21
22    state_lock: AtomicBool,
23    stop: AtomicBool,
24}
25
26impl<D> std::fmt::Debug for NodePool<D>
27where
28    D: Driver + Send + Sync + std::fmt::Debug,
29{
30    fn fmt(
31        &self,
32        f: &mut std::fmt::Formatter<'_>,
33    ) -> std::fmt::Result {
34        f.debug_struct("NodePool")
35            .field("node_id", &self.node_id)
36            .field("pre_nodes", &self.pre_nodes)
37            .field("hash", &self.hash)
38            .field("driver", &self.driver)
39            .field("state_lock", &self.state_lock)
40            .field("stop", &self.stop)
41            .finish()
42    }
43}
44
45#[derive(Debug, Error)]
46pub enum Error<D>
47where
48    D: Driver + Send + Sync,
49{
50    #[error("No node available")]
51    NoNodeAvailable,
52    #[error("Driver error: {0}")]
53    DriverError(D::Error),
54}
55
56impl<D> NodePool<D>
57where
58    D: Driver + Send + Sync,
59{
60    /// Create a new node pool with the given driver.
61    pub async fn new(mut driver: D) -> Result<Self, Error<D>> {
62        driver.start().await.map_err(Error::DriverError)?;
63
64        // Update the hash ring
65        let mut pre_nodes = Vec::new();
66        let mut hash = hashring::HashRing::new();
67        let state_lock = AtomicBool::new(false);
68
69        update_hash_ring::<D>(
70            &mut pre_nodes,
71            &state_lock,
72            &mut hash,
73            &driver.get_nodes().await.map_err(Error::DriverError)?,
74        )
75        .await?;
76
77        Ok(Self {
78            node_id: driver.node_id(),
79            pre_nodes: RwLock::new(pre_nodes),
80            hash: RwLock::new(hash),
81            driver: Arc::new(driver),
82            state_lock,
83            stop: AtomicBool::new(false),
84        })
85    }
86
87    /// Start the node pool, blocking the current thread.
88    pub async fn start(&self) -> Result<(), Error<D>> {
89        let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
90        let error_time = AtomicU8::new(0);
91
92        loop {
93            interval.tick().await;
94            if self.stop.load(std::sync::atomic::Ordering::SeqCst) {
95                return Ok(());
96            }
97
98            // independent ownership
99            {
100                let nodes = match self.driver.get_nodes().await {
101                    Ok(nodes) => nodes,
102                    Err(_) => continue,
103                };
104
105                let mut pre_nodes = self.pre_nodes.write().await;
106                let mut hash = self.hash.write().await;
107
108                match update_hash_ring::<D>(&mut pre_nodes, &self.state_lock, &mut hash, &nodes).await {
109                    Ok(_) => {
110                        error_time.store(0, std::sync::atomic::Ordering::SeqCst);
111                    }
112                    Err(_) => {
113                        error_time.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
114                        tracing::error!("Failed to update hash ring");
115                    }
116                }
117            }
118
119            if error_time.load(std::sync::atomic::Ordering::SeqCst) >= 5 {
120                panic!("Failed to update hash ring 5 times")
121            }
122        }
123    }
124
125    /// Check if the job should be executed on the current node.
126    pub(crate) async fn check_job_available(
127        &self,
128        job_name: &str,
129    ) -> Result<bool, Error<D>> {
130        let hash = self.hash.read().await;
131        match hash.get(&job_name) {
132            Some(node) if node == &self.node_id => Ok(true),
133            Some(_) => Ok(false),
134            None => Err(Error::NoNodeAvailable),
135        }
136    }
137
138    pub fn stop(&self) {
139        self.stop.store(true, std::sync::atomic::Ordering::SeqCst);
140    }
141}
142
143impl<D> Drop for NodePool<D>
144where
145    D: Driver + Send + Sync,
146{
147    fn drop(&mut self) {
148        self.stop();
149    }
150}
151
152/// Update the hash ring with the given nodes.
153///
154/// # Arguments
155///
156/// * `pre_nodes` - The previous nodes
157/// * `state_lock` - The state lock
158/// * `hash` - The hash ring
159/// * `nodes` - The new nodes
160async fn update_hash_ring<D>(
161    pre_nodes: &mut Vec<String>,
162    state_lock: &AtomicBool,
163    hash: &mut hashring::HashRing<String>,
164    nodes: &Vec<String>,
165) -> Result<(), Error<D>>
166where
167    D: Driver + Send + Sync,
168{
169    if equal_ring(nodes, pre_nodes) {
170        tracing::trace!("Nodes are equal, skipping update, nodes: {:?}", nodes);
171        return Ok(());
172    }
173
174    tracing::info!(
175        "Nodes detected, updating hash ring, pre_nodes: {:?}, now_nodes: {:?}",
176        pre_nodes,
177        nodes
178    );
179
180    // Lock the state
181    if state_lock
182        .compare_exchange(
183            false,
184            true,
185            std::sync::atomic::Ordering::SeqCst,
186            std::sync::atomic::Ordering::SeqCst,
187        )
188        .is_err()
189    {
190        return Ok(());
191    }
192
193    // Update the pre_nodes
194    pre_nodes.clone_from(nodes);
195
196    *hash = hashring::HashRing::new();
197    for node in nodes {
198        hash.add(node.clone());
199    }
200
201    // Unlock the state
202    state_lock.store(false, std::sync::atomic::Ordering::SeqCst);
203
204    Ok(())
205}
206
/// Compare two node lists as multisets, ignoring order.
///
/// Returns `true` when both slices contain the same names with the same
/// multiplicities. For example, `["a", "b"]` equals `["b", "a"]`, but
/// `["a", "a", "b"]` does not equal `["a", "b", "b"]`. Only borrowed `&str`s are
/// sorted, so no `String` is cloned.
fn equal_ring(
    a: &[String],
    b: &[String],
) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut a_sorted: Vec<&str> = a.iter().map(String::as_str).collect();
    let mut b_sorted: Vec<&str> = b.iter().map(String::as_str).collect();

    // Equal elements are indistinguishable, so an unstable sort is sufficient.
    a_sorted.sort_unstable();
    b_sorted.sort_unstable();

    a_sorted == b_sorted
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::driver::local::LocalDriver;

    /// Build an owned node list from string literals.
    fn node_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[tokio::test]
    async fn test_update_hash_ring() {
        let mut pre_nodes = Vec::new();
        let state_lock = AtomicBool::new(false);
        let mut hash = hashring::HashRing::new();

        let nodes = node_list(&["node1", "node2"]);

        update_hash_ring::<LocalDriver>(&mut pre_nodes, &state_lock, &mut hash, &nodes)
            .await
            .unwrap();
        assert_eq!(pre_nodes, nodes);
        assert_eq!(hash.get(&"test"), Some(&"node2".to_string()));
        assert!(!state_lock.load(std::sync::atomic::Ordering::SeqCst));

        let nodes = node_list(&["node1", "node2", "node3"]);

        update_hash_ring::<LocalDriver>(&mut pre_nodes, &state_lock, &mut hash, &nodes)
            .await
            .unwrap();
        assert_eq!(pre_nodes, nodes);
        assert_eq!(hash.get(&"test"), Some(&"node2".to_string()));

        let nodes = node_list(&["node1", "node3"]);

        update_hash_ring::<LocalDriver>(&mut pre_nodes, &state_lock, &mut hash, &nodes)
            .await
            .unwrap();
        assert_eq!(pre_nodes, nodes);
        assert_eq!(hash.get(&"test"), Some(&"node3".to_string()));

        // Same node set again: the update is a no-op.
        let nodes = node_list(&["node1", "node3"]);

        update_hash_ring::<LocalDriver>(&mut pre_nodes, &state_lock, &mut hash, &nodes)
            .await
            .unwrap();
        assert_eq!(pre_nodes, nodes);
        assert_eq!(hash.get(&"test"), Some(&"node3".to_string()));
        assert!(!state_lock.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_update_hash_ring_skips_when_locked() {
        let mut pre_nodes = Vec::new();
        // Simulate an update already in progress elsewhere.
        let state_lock = AtomicBool::new(true);
        let mut hash = hashring::HashRing::new();

        let nodes = node_list(&["node1"]);

        update_hash_ring::<LocalDriver>(&mut pre_nodes, &state_lock, &mut hash, &nodes)
            .await
            .unwrap();
        assert!(pre_nodes.is_empty());
        assert_eq!(hash.get(&"test"), None);
        // The lock belongs to someone else and must not be released here.
        assert!(state_lock.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[test]
    fn test_equal_ring() {
        let a = node_list(&["node1", "node2"]);
        let b = node_list(&["node1", "node2"]);
        assert!(equal_ring(&a, &b));

        let a = node_list(&["node1", "node2"]);
        let b = node_list(&["node2", "node1"]);
        assert!(equal_ring(&a, &b));

        let a = node_list(&["node1", "node2"]);
        let b = node_list(&["node1", "node3"]);
        assert!(!equal_ring(&a, &b));

        // Duplicates are compared as a multiset.
        let a = node_list(&["node1", "node1", "node2"]);
        let b = node_list(&["node1", "node2", "node2"]);
        assert!(!equal_ring(&a, &b));

        assert!(equal_ring(&[], &[]));
    }
}