distributed_scheduler/driver/redis.rs

1use std::{
2    fmt::Debug,
3    sync::{atomic::AtomicBool, Arc},
4};
5
6use redis::{aio::ConnectionLike, AsyncCommands};
7
8use super::{utils, Driver};
9
10const DEFAULT_TIMEOUT: u64 = 3;
11
12#[derive(Clone)]
13pub struct RedisDriver<C>
14where
15    C: ConnectionLike,
16{
17    con: C,
18
19    service_name: String,
20    node_id: String,
21    started: Arc<AtomicBool>,
22    timeout: u64,
23}
24
25impl<C> Debug for RedisDriver<C>
26where
27    C: ConnectionLike,
28{
29    fn fmt(
30        &self,
31        f: &mut std::fmt::Formatter<'_>,
32    ) -> std::fmt::Result {
33        f.debug_struct("RedisDriver")
34            .field("service_name", &self.service_name)
35            .field("node_id", &self.node_id)
36            .field("started", &self.started)
37            .field("timeout", &self.timeout)
38            .finish()
39    }
40}
41
42#[derive(Debug, thiserror::Error)]
43pub enum Error {
44    #[error("Redis error: {0}")]
45    Redis(#[from] redis::RedisError),
46    #[error("Empty service name")]
47    EmptyServiceName,
48    #[error("Empty node id")]
49    EmptyNodeId,
50}
51
52impl<C> RedisDriver<C>
53where
54    C: ConnectionLike,
55{
56    pub async fn new(
57        con: C,
58        service_name: &str,
59        node_id: &str,
60    ) -> Result<Self, Error> {
61        if service_name.is_empty() {
62            return Err(Error::EmptyServiceName);
63        }
64
65        if node_id.is_empty() {
66            return Err(Error::EmptyNodeId);
67        }
68
69        Ok(Self {
70            con,
71            service_name: service_name.into(),
72            node_id: utils::get_key_prefix(service_name) + node_id,
73            started: Arc::new(AtomicBool::new(false)),
74            timeout: DEFAULT_TIMEOUT,
75        })
76    }
77
78    pub fn with_timeout(
79        mut self,
80        timeout: u64,
81    ) -> Self {
82        self.timeout = timeout;
83        self
84    }
85
86    pub fn timeout(&self) -> u64 {
87        self.timeout
88    }
89}
90
91#[async_trait::async_trait]
92impl<C> Driver for RedisDriver<C>
93where
94    C: ConnectionLike + Send + Sync + Clone + 'static,
95{
96    type Error = Error;
97
98    fn node_id(&self) -> String {
99        self.node_id.clone()
100    }
101
102    /// Scan the redis server to get the nodes
103    async fn get_nodes(&self) -> Result<Vec<String>, Self::Error> {
104        let pattern = utils::get_key_prefix(&self.service_name) + "*";
105
106        let mut con = self.con.clone();
107        let mut res = con.scan_match(pattern).await?;
108
109        let mut nodes: Vec<String> = Vec::new();
110        while let Some(key) = res.next_item().await {
111            nodes.push(key);
112        }
113        Ok(nodes)
114    }
115
116    /// Start a routine to send heartbeat to the redis server
117    async fn start(&mut self) -> Result<(), Self::Error> {
118        // check if the driver has already started
119        if self.started.load(std::sync::atomic::Ordering::SeqCst) {
120            tracing::warn!("Driver has already started");
121            return Ok(());
122        }
123
124        // set the driver as started
125        self.started.store(true, std::sync::atomic::Ordering::SeqCst);
126
127        // start the heartbeat
128        tokio::spawn({
129            let con = self.con.clone();
130            let node_id = self.node_id.clone();
131            let timeout = self.timeout;
132            let started = self.started.clone();
133
134            async move {
135                heartbeat(&node_id, timeout, con, started)
136                    .await
137                    .expect("Failed to start scheduler driver heartbeat")
138            }
139        });
140
141        Ok(())
142    }
143}
144
145impl<C> Drop for RedisDriver<C>
146where
147    C: ConnectionLike,
148{
149    fn drop(&mut self) {
150        self.started.store(false, std::sync::atomic::Ordering::SeqCst);
151    }
152}
153
154/// Register the node in the redis
155///
156/// # Arguments
157///
158/// * `node_id` - The id of the node
159/// * `timeout` - The timeout of the node
160/// * `con` - The redis connection
161async fn register_node<C: ConnectionLike + Send + Sync>(
162    node_id: &str,
163    timeout: u64,
164    con: &mut C,
165) -> Result<(), redis::RedisError> {
166    con.set_ex(node_id, node_id, timeout).await?;
167    Ok(())
168}
169
170/// Heartbeat function to keep the node alive
171///
172/// # Arguments
173///
174/// * `service_name` - The name of the service
175/// * `node_id` - The id of the node
176/// * `timeout` - The timeout of the node
177/// * `con` - The redis connection
178/// * `started` - The atomic bool to check if the driver has started
179///
180/// # Returns
181///
182/// * `Result<(), Box<dyn std::error::Error>` - The result of the function
183async fn heartbeat<C: ConnectionLike + Send + Sync>(
184    node_id: &str,
185    timeout: u64,
186    con: C,
187    started: Arc<AtomicBool>,
188) -> Result<(), Box<dyn std::error::Error>> {
189    let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
190    let mut con = con;
191    let mut error_time = 0;
192
193    tracing::debug!("Started heartbeat");
194
195    loop {
196        // check if the driver has stopped
197        if !started.load(std::sync::atomic::Ordering::SeqCst) {
198            break;
199        }
200
201        // tick the interval
202        interval.tick().await;
203
204        // register the node
205        register_node(node_id, timeout, &mut con)
206            .await
207            .map_err(|e| {
208                error_time += 1;
209                tracing::error!("Failed to register node: {:?}", e);
210            })
211            .ok();
212
213        // check if the error time is greater than 5
214        if error_time >= 5 {
215            panic!("Failed to register node 5 times, stopping heartbeat");
216        }
217    }
218
219    tracing::info!("Heartbeat stopped");
220    Ok(())
221}
222
#[cfg(test)]
mod tests {
    use redis_test::{MockCmd, MockRedisConnection};

    use super::*;

    /// `register_node` should send exactly one `SETEX <node> <timeout> <node>`.
    #[tokio::test]
    async fn test_register_node_success() {
        let node = "test-node";
        let ttl = 10_u64;

        let expected = MockCmd::new(
            redis::cmd("SETEX").arg(node).arg(ttl as usize).arg(node),
            Ok(redis::Value::Okay),
        );
        let mut con = MockRedisConnection::new(vec![expected]);

        // Register the node against the mocked connection
        let outcome = register_node(node, ttl, &mut con).await;

        assert!(outcome.is_ok(), "Node registration should be successful");
    }

    /// `get_nodes` should return every key replied to `SCAN 0 MATCH <prefix>*`.
    #[tokio::test]
    async fn test_get_nodes_success() {
        let service = "test-service";
        let node = "test-node";
        let pattern = utils::get_key_prefix(service) + "*";

        let keys = ["test-service-node1", "test-service-node2", "test-service-node3"];
        let replies: Vec<redis::Value> = keys
            .iter()
            .map(|key| redis::Value::BulkString(key.as_bytes().to_vec()))
            .collect();

        let con = MockRedisConnection::new(vec![MockCmd::new(
            redis::cmd("SCAN").arg("0").arg("MATCH").arg(&pattern),
            Ok(redis::Value::Array(replies)),
        )]);

        // Build a driver over the mocked connection and list the nodes
        let driver = RedisDriver::new(con, service, node).await.unwrap();
        let nodes = driver.get_nodes().await;

        assert!(nodes.is_ok(), "Get nodes should be successful");
        assert_eq!(nodes.unwrap(), keys, "The nodes should match");
    }
}