distributed_scheduler/driver/
redis_zset.rs1use std::{
2 fmt::Debug,
3 sync::{
4 atomic::{AtomicBool, Ordering},
5 Arc,
6 },
7};
8
9use redis::{aio::ConnectionLike, AsyncCommands};
10
11use super::{utils, Driver};
12
/// Default liveness window in seconds: the value a freshly constructed driver
/// stores in its `timeout` field until `with_timeout` overrides it.
const DEFAULT_TIMEOUT: u64 = 3;
14
15#[derive(Clone)]
16pub struct RedisZSetDriver<C>
17where
18 C: ConnectionLike,
19{
20 con: C,
21
22 service_name: String,
23 node_id: String,
24 started: Arc<AtomicBool>,
25 timeout: u64,
26 notify: Arc<tokio::sync::Notify>,
27}
28
29impl<C> Debug for RedisZSetDriver<C>
30where
31 C: ConnectionLike,
32{
33 fn fmt(
34 &self,
35 f: &mut std::fmt::Formatter<'_>,
36 ) -> std::fmt::Result {
37 f.debug_struct("RedisDriver")
38 .field("service_name", &self.service_name)
39 .field("node_id", &self.node_id)
40 .field("started", &self.started)
41 .field("timeout", &self.timeout)
42 .finish()
43 }
44}
45
46#[derive(Debug, thiserror::Error)]
47pub enum Error {
48 #[error("Redis error: {0}")]
49 Redis(#[from] redis::RedisError),
50 #[error("Empty service name")]
51 EmptyServiceName,
52 #[error("Empty node id")]
53 EmptyNodeId,
54}
55
56impl<C> RedisZSetDriver<C>
57where
58 C: ConnectionLike,
59{
60 pub async fn new(
61 con: C,
62 service_name: &str,
63 node_id: &str,
64 ) -> Result<Self, Error> {
65 if service_name.is_empty() {
66 return Err(Error::EmptyServiceName);
67 }
68
69 if node_id.is_empty() {
70 return Err(Error::EmptyNodeId);
71 }
72
73 Ok(Self {
74 con,
75 service_name: service_name.into(),
76 node_id: utils::get_key_prefix(service_name) + node_id,
77 started: Arc::new(AtomicBool::new(false)),
78 timeout: DEFAULT_TIMEOUT,
79 notify: Arc::new(tokio::sync::Notify::new()),
80 })
81 }
82
83 pub fn with_timeout(
84 mut self,
85 timeout: u64,
86 ) -> Self {
87 self.timeout = timeout;
88 self
89 }
90
91 fn shutdown(&self) {
92 self.started.store(false, Ordering::SeqCst);
93 self.notify.notify_waiters();
94 }
95}
96
97#[async_trait::async_trait]
98impl<C> Driver for RedisZSetDriver<C>
99where
100 C: ConnectionLike + Send + Sync + Clone + 'static,
101{
102 type Error = Error;
103
104 fn node_id(&self) -> String {
105 self.node_id.clone()
106 }
107
108 async fn get_nodes(&self) -> Result<Vec<String>, Self::Error> {
110 let key = utils::get_zset_key(&self.service_name);
111
112 let mut con = self.con.clone();
113 let min = (chrono::Utc::now() - chrono::Duration::seconds(self.timeout as i64)).timestamp(); let nodes: Vec<String> = con.zrangebyscore(key, min, "+inf").await?;
115 Ok(nodes)
116 }
117
118 async fn start(&mut self) -> Result<(), Self::Error> {
120 if self.started.swap(true, Ordering::SeqCst) {
122 return Ok(());
123 }
124
125 let mut con = self.con.clone();
127 let service_name = self.service_name.clone();
128 let node_id = self.node_id.clone();
129 let started = self.started.clone();
130 let shutdown_notify = self.notify.clone();
131 tokio::spawn(async move {
132 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
133
134 while started.load(Ordering::SeqCst) {
135 tokio::select! {
136 _ = interval.tick() => {
137 if let Err(e) = register_node(&service_name, &node_id, &mut con, chrono::Utc::now().timestamp()).await {
138 tracing::error!("Failed to register node: {:?}", e);
139 }
140 }
141 _ = shutdown_notify.notified() => {
142 tracing::info!("Heartbeat task received shutdown signal");
143 break;
144 }
145 }
146 }
147 });
148
149 Ok(())
150 }
151}
152
153impl<C> Drop for RedisZSetDriver<C>
154where
155 C: ConnectionLike,
156{
157 fn drop(&mut self) {
158 self.shutdown();
159 }
160}
161
162async fn register_node<C: ConnectionLike + Send + Sync>(
172 service_name: &str,
173 node_id: &str,
174 con: &mut C,
175 time: i64,
176) -> Result<(), redis::RedisError> {
177 con.zadd(utils::get_zset_key(service_name), node_id, time).await?;
178
179 Ok(())
180}
181
#[cfg(test)]
mod tests {
    use redis_test::{MockCmd, MockRedisConnection};

    use super::*;

    /// `register_node` must send `ZADD <key> <score> <member>` and succeed
    /// when the server replies OK.
    #[tokio::test]
    async fn test_register_node_success() {
        let service = "test-service";
        let node = "test-node";
        let now = chrono::Utc::now().timestamp();

        let expected_commands = vec![MockCmd::new(
            redis::cmd("ZADD")
                .arg(utils::get_zset_key(service))
                .arg(now)
                .arg(node),
            Ok(redis::Value::Okay),
        )];
        let mut con = MockRedisConnection::new(expected_commands);

        if let Err(e) = register_node(service, node, &mut con, now).await {
            panic!("Register node should be successful: {}", e);
        }
    }

    /// `get_nodes` must query the score range `[now - timeout, +inf]` and
    /// return the members in the order Redis reports them.
    #[tokio::test]
    async fn test_get_nodes_success() {
        let service = "test-service";
        let node = "test-node";

        let expected_nodes = ["node1", "node2", "node3"];
        let mut reply: Vec<redis::Value> = Vec::with_capacity(expected_nodes.len());
        for name in expected_nodes.iter() {
            reply.push(redis::Value::BulkString(name.as_bytes().to_vec()));
        }

        let con = MockRedisConnection::new(vec![MockCmd::new(
            redis::cmd("ZRANGEBYSCORE")
                .arg(utils::get_zset_key(service))
                .arg((chrono::Utc::now() - chrono::Duration::seconds(DEFAULT_TIMEOUT as i64)).timestamp())
                .arg("+inf"),
            Ok(redis::Value::Array(reply)),
        )]);

        let driver = RedisZSetDriver::new(con, service, node).await.unwrap();

        let nodes = match driver.get_nodes().await {
            Ok(nodes) => nodes,
            Err(e) => panic!("Get nodes should be successful: {}", e),
        };
        assert_eq!(nodes, expected_nodes, "The nodes should match");
    }
}