distributed_scheduler/driver/
redis.rs1use std::{
2 fmt::Debug,
3 sync::{atomic::AtomicBool, Arc},
4};
5
6use redis::{aio::ConnectionLike, AsyncCommands};
7
8use super::{utils, Driver};
9
/// Default time-to-live, in seconds, of a node's registration key in Redis.
///
/// The heartbeat rewrites the key once per second with this TTL, so a node whose
/// heartbeat stops has its key expire roughly this many seconds later.
const DEFAULT_TIMEOUT: u64 = 3;
11
12#[derive(Clone)]
13pub struct RedisDriver<C>
14where
15 C: ConnectionLike,
16{
17 con: C,
18
19 service_name: String,
20 node_id: String,
21 started: Arc<AtomicBool>,
22 timeout: u64,
23}
24
25impl<C> Debug for RedisDriver<C>
26where
27 C: ConnectionLike,
28{
29 fn fmt(
30 &self,
31 f: &mut std::fmt::Formatter<'_>,
32 ) -> std::fmt::Result {
33 f.debug_struct("RedisDriver")
34 .field("service_name", &self.service_name)
35 .field("node_id", &self.node_id)
36 .field("started", &self.started)
37 .field("timeout", &self.timeout)
38 .finish()
39 }
40}
41
42#[derive(Debug, thiserror::Error)]
43pub enum Error {
44 #[error("Redis error: {0}")]
45 Redis(#[from] redis::RedisError),
46 #[error("Empty service name")]
47 EmptyServiceName,
48 #[error("Empty node id")]
49 EmptyNodeId,
50}
51
52impl<C> RedisDriver<C>
53where
54 C: ConnectionLike,
55{
56 pub async fn new(
57 con: C,
58 service_name: &str,
59 node_id: &str,
60 ) -> Result<Self, Error> {
61 if service_name.is_empty() {
62 return Err(Error::EmptyServiceName);
63 }
64
65 if node_id.is_empty() {
66 return Err(Error::EmptyNodeId);
67 }
68
69 Ok(Self {
70 con,
71 service_name: service_name.into(),
72 node_id: utils::get_key_prefix(service_name) + node_id,
73 started: Arc::new(AtomicBool::new(false)),
74 timeout: DEFAULT_TIMEOUT,
75 })
76 }
77
78 pub fn with_timeout(
79 mut self,
80 timeout: u64,
81 ) -> Self {
82 self.timeout = timeout;
83 self
84 }
85
86 pub fn timeout(&self) -> u64 {
87 self.timeout
88 }
89}
90
91#[async_trait::async_trait]
92impl<C> Driver for RedisDriver<C>
93where
94 C: ConnectionLike + Send + Sync + Clone + 'static,
95{
96 type Error = Error;
97
98 fn node_id(&self) -> String {
99 self.node_id.clone()
100 }
101
102 async fn get_nodes(&self) -> Result<Vec<String>, Self::Error> {
104 let pattern = utils::get_key_prefix(&self.service_name) + "*";
105
106 let mut con = self.con.clone();
107 let mut res = con.scan_match(pattern).await?;
108
109 let mut nodes: Vec<String> = Vec::new();
110 while let Some(key) = res.next_item().await {
111 nodes.push(key);
112 }
113 Ok(nodes)
114 }
115
116 async fn start(&mut self) -> Result<(), Self::Error> {
118 if self.started.load(std::sync::atomic::Ordering::SeqCst) {
120 tracing::warn!("Driver has already started");
121 return Ok(());
122 }
123
124 self.started.store(true, std::sync::atomic::Ordering::SeqCst);
126
127 tokio::spawn({
129 let con = self.con.clone();
130 let node_id = self.node_id.clone();
131 let timeout = self.timeout;
132 let started = self.started.clone();
133
134 async move {
135 heartbeat(&node_id, timeout, con, started)
136 .await
137 .expect("Failed to start scheduler driver heartbeat")
138 }
139 });
140
141 Ok(())
142 }
143}
144
145impl<C> Drop for RedisDriver<C>
146where
147 C: ConnectionLike,
148{
149 fn drop(&mut self) {
150 self.started.store(false, std::sync::atomic::Ordering::SeqCst);
151 }
152}
153
154async fn register_node<C: ConnectionLike + Send + Sync>(
162 node_id: &str,
163 timeout: u64,
164 con: &mut C,
165) -> Result<(), redis::RedisError> {
166 con.set_ex(node_id, node_id, timeout).await?;
167 Ok(())
168}
169
170async fn heartbeat<C: ConnectionLike + Send + Sync>(
184 node_id: &str,
185 timeout: u64,
186 con: C,
187 started: Arc<AtomicBool>,
188) -> Result<(), Box<dyn std::error::Error>> {
189 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
190 let mut con = con;
191 let mut error_time = 0;
192
193 tracing::debug!("Started heartbeat");
194
195 loop {
196 if !started.load(std::sync::atomic::Ordering::SeqCst) {
198 break;
199 }
200
201 interval.tick().await;
203
204 register_node(node_id, timeout, &mut con)
206 .await
207 .map_err(|e| {
208 error_time += 1;
209 tracing::error!("Failed to register node: {:?}", e);
210 })
211 .ok();
212
213 if error_time >= 5 {
215 panic!("Failed to register node 5 times, stopping heartbeat");
216 }
217 }
218
219 tracing::info!("Heartbeat stopped");
220 Ok(())
221}
222
#[cfg(test)]
mod tests {
    use redis_test::{MockCmd, MockRedisConnection};

    use super::*;

    /// A mock connection that expects no commands at all.
    fn idle_connection() -> MockRedisConnection {
        MockRedisConnection::new(Vec::<MockCmd>::new())
    }

    #[tokio::test]
    async fn test_register_node_success() {
        let node_id = "test-node";
        let timeout = 10_u64;

        let mut mock_con = MockRedisConnection::new(vec![MockCmd::new(
            redis::cmd("SETEX").arg(node_id).arg(timeout as usize).arg(node_id),
            Ok(redis::Value::Okay),
        )]);

        let result = register_node(node_id, timeout, &mut mock_con).await;

        assert!(result.is_ok(), "Node registration should be successful");
    }

    #[tokio::test]
    async fn test_get_nodes_success() {
        let service_name = "test-service";
        let node_id = "test-node";
        let pattern = utils::get_key_prefix(service_name) + "*";

        let keys = ["test-service-node1", "test-service-node2", "test-service-node3"];
        let keys_as_redis_values = keys
            .iter()
            .map(|k| redis::Value::BulkString(k.to_string().into_bytes()))
            .collect::<Vec<_>>();

        let mock_con = MockRedisConnection::new(vec![MockCmd::new(
            redis::cmd("SCAN").arg("0").arg("MATCH").arg(&pattern),
            Ok(redis::Value::Array(keys_as_redis_values)),
        )]);

        let driver = RedisDriver::new(mock_con, service_name, node_id).await.unwrap();
        let result = driver.get_nodes().await;

        assert!(result.is_ok(), "Get nodes should be successful");
        assert_eq!(result.unwrap(), keys, "The nodes should match");
    }

    #[tokio::test]
    async fn test_new_rejects_empty_service_name() {
        let result = RedisDriver::new(idle_connection(), "", "test-node").await;
        assert!(
            matches!(result, Err(Error::EmptyServiceName)),
            "An empty service name must be rejected"
        );
    }

    #[tokio::test]
    async fn test_new_rejects_empty_node_id() {
        let result = RedisDriver::new(idle_connection(), "test-service", "").await;
        assert!(
            matches!(result, Err(Error::EmptyNodeId)),
            "An empty node id must be rejected"
        );
    }

    #[tokio::test]
    async fn test_new_defaults_and_with_timeout() {
        let service_name = "test-service";
        let node_id = "test-node";

        let driver = RedisDriver::new(idle_connection(), service_name, node_id)
            .await
            .unwrap();
        assert_eq!(driver.timeout(), DEFAULT_TIMEOUT, "Default timeout should apply");
        assert_eq!(
            driver.node_id(),
            utils::get_key_prefix(service_name) + node_id,
            "Node id should be namespaced with the service key prefix"
        );

        let driver = driver.with_timeout(10);
        assert_eq!(driver.timeout(), 10, "with_timeout should override the TTL");
    }
}