// distributed_scheduler/driver/redis_zset.rs

1use std::{
2    fmt::Debug,
3    sync::{
4        atomic::{AtomicBool, Ordering},
5        Arc,
6    },
7};
8
9use redis::{aio::ConnectionLike, AsyncCommands};
10
11use super::{utils, Driver};
12
13const DEFAULT_TIMEOUT: u64 = 3;
14
15#[derive(Clone)]
16pub struct RedisZSetDriver<C>
17where
18    C: ConnectionLike,
19{
20    con: C,
21
22    service_name: String,
23    node_id: String,
24    started: Arc<AtomicBool>,
25    timeout: u64,
26    notify: Arc<tokio::sync::Notify>,
27}
28
29impl<C> Debug for RedisZSetDriver<C>
30where
31    C: ConnectionLike,
32{
33    fn fmt(
34        &self,
35        f: &mut std::fmt::Formatter<'_>,
36    ) -> std::fmt::Result {
37        f.debug_struct("RedisDriver")
38            .field("service_name", &self.service_name)
39            .field("node_id", &self.node_id)
40            .field("started", &self.started)
41            .field("timeout", &self.timeout)
42            .finish()
43    }
44}
45
46#[derive(Debug, thiserror::Error)]
47pub enum Error {
48    #[error("Redis error: {0}")]
49    Redis(#[from] redis::RedisError),
50    #[error("Empty service name")]
51    EmptyServiceName,
52    #[error("Empty node id")]
53    EmptyNodeId,
54}
55
56impl<C> RedisZSetDriver<C>
57where
58    C: ConnectionLike,
59{
60    pub async fn new(
61        con: C,
62        service_name: &str,
63        node_id: &str,
64    ) -> Result<Self, Error> {
65        if service_name.is_empty() {
66            return Err(Error::EmptyServiceName);
67        }
68
69        if node_id.is_empty() {
70            return Err(Error::EmptyNodeId);
71        }
72
73        Ok(Self {
74            con,
75            service_name: service_name.into(),
76            node_id: utils::get_key_prefix(service_name) + node_id,
77            started: Arc::new(AtomicBool::new(false)),
78            timeout: DEFAULT_TIMEOUT,
79            notify: Arc::new(tokio::sync::Notify::new()),
80        })
81    }
82
83    pub fn with_timeout(
84        mut self,
85        timeout: u64,
86    ) -> Self {
87        self.timeout = timeout;
88        self
89    }
90
91    fn shutdown(&self) {
92        self.started.store(false, Ordering::SeqCst);
93        self.notify.notify_waiters();
94    }
95}
96
97#[async_trait::async_trait]
98impl<C> Driver for RedisZSetDriver<C>
99where
100    C: ConnectionLike + Send + Sync + Clone + 'static,
101{
102    type Error = Error;
103
104    fn node_id(&self) -> String {
105        self.node_id.clone()
106    }
107
108    /// Get the list of nodes from the redis server, use `ZRANGEBYSCORE` to get the latest nodes
109    async fn get_nodes(&self) -> Result<Vec<String>, Self::Error> {
110        let key = utils::get_zset_key(&self.service_name);
111
112        let mut con = self.con.clone();
113        let min = (chrono::Utc::now() - chrono::Duration::seconds(self.timeout as i64)).timestamp(); // current timestamp - timeout
114        let nodes: Vec<String> = con.zrangebyscore(key, min, "+inf").await?;
115        Ok(nodes)
116    }
117
118    /// Start a routine to send the heartbeat to the redis server
119    async fn start(&mut self) -> Result<(), Self::Error> {
120        // check if the driver has already started
121        if self.started.swap(true, Ordering::SeqCst) {
122            return Ok(());
123        }
124
125        // start the heartbeat
126        let mut con = self.con.clone();
127        let service_name = self.service_name.clone();
128        let node_id = self.node_id.clone();
129        let started = self.started.clone();
130        let shutdown_notify = self.notify.clone();
131        tokio::spawn(async move {
132            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
133
134            while started.load(Ordering::SeqCst) {
135                tokio::select! {
136                    _ = interval.tick() => {
137                        if let Err(e) = register_node(&service_name, &node_id, &mut con, chrono::Utc::now().timestamp()).await {
138                            tracing::error!("Failed to register node: {:?}", e);
139                        }
140                    }
141                    _ = shutdown_notify.notified() => {
142                        tracing::info!("Heartbeat task received shutdown signal");
143                        break;
144                    }
145                }
146            }
147        });
148
149        Ok(())
150    }
151}
152
153impl<C> Drop for RedisZSetDriver<C>
154where
155    C: ConnectionLike,
156{
157    fn drop(&mut self) {
158        self.shutdown();
159    }
160}
161
162/// Register the node in the redis.
163/// Use redis command `ZADD` to add the node to the zset.
164///
165/// # Arguments
166///
167/// * `service_name` - The name of the service
168/// * `node_id` - The id of the node
169/// * `con` - The redis connection
170/// * `time` - The time to register the node
171async fn register_node<C: ConnectionLike + Send + Sync>(
172    service_name: &str,
173    node_id: &str,
174    con: &mut C,
175    time: i64,
176) -> Result<(), redis::RedisError> {
177    con.zadd(utils::get_zset_key(service_name), node_id, time).await?;
178
179    Ok(())
180}
181
#[cfg(test)]
mod tests {
    use redis_test::{MockCmd, MockRedisConnection};

    use super::*;

    #[tokio::test]
    async fn test_register_node_success() {
        let service_name = "test-service";
        let node_id = "test-node";
        let ts = chrono::Utc::now().timestamp();

        let mut mock_con = MockRedisConnection::new(vec![MockCmd::new(
            redis::cmd("ZADD")
                .arg(utils::get_zset_key(service_name))
                .arg(ts)
                .arg(node_id),
            Ok(redis::Value::Okay),
        )]);

        // Perform the node registration
        let result = register_node(service_name, node_id, &mut mock_con, ts).await;

        if let Err(e) = result {
            panic!("Register node should be successful: {}", e);
        }
    }

    /// Builds a driver over a mock connection that expects exactly one
    /// `ZRANGEBYSCORE <zset> <min> +inf` and replies with `nodes`, then calls `get_nodes`.
    async fn get_nodes_with_mock(
        service_name: &str,
        node_id: &str,
        nodes: &[&str],
        min: i64,
    ) -> Result<Vec<String>, Error> {
        let reply: Vec<redis::Value> = nodes
            .iter()
            .map(|k| redis::Value::BulkString(k.as_bytes().to_vec()))
            .collect();

        let mock_con = MockRedisConnection::new(vec![MockCmd::new(
            redis::cmd("ZRANGEBYSCORE")
                .arg(utils::get_zset_key(service_name))
                .arg(min)
                .arg("+inf"),
            Ok(redis::Value::Array(reply)),
        )]);

        let driver = RedisZSetDriver::new(mock_con, service_name, node_id).await.unwrap();
        driver.get_nodes().await
    }

    #[tokio::test]
    async fn test_get_nodes_success() {
        let service_name = "test-service";
        let node_id = "test-node";
        let keys = ["node1", "node2", "node3"];

        // `get_nodes` reads the clock itself, so the `min` expected by the mock is off by
        // one whenever a second boundary passes between building the mock and the call.
        // Retry (a bounded number of times) only in that case, so real failures still fail.
        let mut attempts = 0;
        let result = loop {
            attempts += 1;
            let second = chrono::Utc::now().timestamp();
            let min = second - DEFAULT_TIMEOUT as i64;
            let result = get_nodes_with_mock(service_name, node_id, &keys, min).await;
            let crossed_boundary = chrono::Utc::now().timestamp() != second;
            if result.is_ok() || !crossed_boundary || attempts >= 3 {
                break result;
            }
        };

        let nodes = result.unwrap_or_else(|e| panic!("Get nodes should be successful: {}", e));
        assert_eq!(nodes, keys, "The nodes should match");
    }
}