armature_distributed/
leader.rs1use redis::AsyncCommands;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::RwLock;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12#[derive(Debug, Error)]
14pub enum LeaderError {
15 #[error("Election failed: {0}")]
16 ElectionFailed(String),
17
18 #[error("Redis error: {0}")]
19 RedisError(#[from] redis::RedisError),
20
21 #[error("Not the leader")]
22 NotLeader,
23}
24
/// Type-erased async hook fired on leadership transitions.
///
/// Each call returns a freshly boxed and pinned future resolving to `()`.
/// Wrapping the closure in an `Arc` makes the hook cheap to clone, and the
/// `Send + Sync` bounds allow it to be shared with and invoked from other
/// threads.
pub type LeaderCallback = Arc<
    dyn Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
        + Send
        + Sync,
>;
28
29pub struct LeaderElection {
31 key: String,
33
34 node_id: String,
36
37 ttl: Duration,
39
40 refresh_interval: Duration,
42
43 conn: Arc<RwLock<redis::aio::ConnectionManager>>,
45
46 is_leader: Arc<AtomicBool>,
48
49 on_elected: Option<LeaderCallback>,
51
52 on_revoked: Option<LeaderCallback>,
54
55 running: Arc<AtomicBool>,
57}
58
59impl LeaderElection {
60 pub fn new(key: impl Into<String>, ttl: Duration, conn: redis::aio::ConnectionManager) -> Self {
78 let refresh_interval = Duration::from_millis((ttl.as_millis() / 3) as u64);
79
80 Self {
81 key: key.into(),
82 node_id: Uuid::new_v4().to_string(),
83 ttl,
84 refresh_interval,
85 conn: Arc::new(RwLock::new(conn)),
86 is_leader: Arc::new(AtomicBool::new(false)),
87 on_elected: None,
88 on_revoked: None,
89 running: Arc::new(AtomicBool::new(false)),
90 }
91 }
92
93 pub fn on_elected<F, Fut>(mut self, callback: F) -> Self
95 where
96 F: Fn() -> Fut + Send + Sync + 'static,
97 Fut: std::future::Future<Output = ()> + Send + 'static,
98 {
99 self.on_elected = Some(Arc::new(move || Box::pin(callback())));
100 self
101 }
102
103 pub fn on_revoked<F, Fut>(mut self, callback: F) -> Self
105 where
106 F: Fn() -> Fut + Send + Sync + 'static,
107 Fut: std::future::Future<Output = ()> + Send + 'static,
108 {
109 self.on_revoked = Some(Arc::new(move || Box::pin(callback())));
110 self
111 }
112
113 pub fn is_leader(&self) -> bool {
115 self.is_leader.load(Ordering::Acquire)
116 }
117
118 pub fn node_id(&self) -> &str {
120 &self.node_id
121 }
122
123 pub async fn start(self: Arc<Self>) -> Result<(), LeaderError> {
125 self.running.store(true, Ordering::Release);
126
127 info!(
128 "Starting leader election for key: {} (node: {})",
129 self.key, self.node_id
130 );
131
132 loop {
133 if !self.running.load(Ordering::Acquire) {
134 break;
135 }
136
137 match self.try_become_leader().await {
139 Ok(became_leader) => {
140 let was_leader = self.is_leader.load(Ordering::Acquire);
141
142 if became_leader && !was_leader {
143 self.is_leader.store(true, Ordering::Release);
145 info!("Node {} became leader for {}", self.node_id, self.key);
146
147 if let Some(callback) = &self.on_elected {
148 callback().await;
149 }
150 } else if !became_leader && was_leader {
151 self.is_leader.store(false, Ordering::Release);
153 warn!("Node {} lost leadership for {}", self.node_id, self.key);
154
155 if let Some(callback) = &self.on_revoked {
156 callback().await;
157 }
158 } else if became_leader {
159 debug!(
161 "Node {} refreshed leadership for {}",
162 self.node_id, self.key
163 );
164 }
165 }
166 Err(e) => {
167 error!("Leader election error: {}", e);
168
169 if self.is_leader.swap(false, Ordering::Release) {
171 if let Some(callback) = &self.on_revoked {
172 callback().await;
173 }
174 }
175 }
176 }
177
178 tokio::time::sleep(self.refresh_interval).await;
180 }
181
182 if self.is_leader.load(Ordering::Acquire) {
184 let _ = self.resign().await;
185 }
186
187 Ok(())
188 }
189
190 pub async fn stop(&self) {
192 self.running.store(false, Ordering::Release);
193 }
194
195 async fn try_become_leader(&self) -> Result<bool, LeaderError> {
197 let mut conn = self.conn.write().await;
198 let ttl_ms = self.ttl.as_millis() as usize;
199
200 let script = r#"
202 local current = redis.call("get", KEYS[1])
203 if current == false or current == ARGV[1] then
204 redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[2])
205 return 1
206 else
207 return 0
208 end
209 "#;
210
211 let result: i32 = redis::Script::new(script)
212 .key(&self.key)
213 .arg(&self.node_id)
214 .arg(ttl_ms)
215 .invoke_async(&mut *conn)
216 .await?;
217
218 Ok(result == 1)
219 }
220
221 async fn resign(&self) -> Result<(), LeaderError> {
223 let mut conn = self.conn.write().await;
224
225 let script = r#"
227 if redis.call("get", KEYS[1]) == ARGV[1] then
228 return redis.call("del", KEYS[1])
229 else
230 return 0
231 end
232 "#;
233
234 let _: i32 = redis::Script::new(script)
235 .key(&self.key)
236 .arg(&self.node_id)
237 .invoke_async(&mut *conn)
238 .await?;
239
240 self.is_leader.store(false, Ordering::Release);
241 info!("Node {} resigned from leadership", self.node_id);
242
243 Ok(())
244 }
245
246 pub async fn get_leader(&self) -> Result<Option<String>, LeaderError> {
248 let mut conn = self.conn.write().await;
249 let leader: Option<String> = conn.get(&self.key).await?;
250 Ok(leader)
251 }
252}
253
254pub struct LeaderElectionBuilder {
256 key: String,
257 ttl: Duration,
258 on_elected: Option<LeaderCallback>,
259 on_revoked: Option<LeaderCallback>,
260}
261
262impl LeaderElectionBuilder {
263 pub fn new(key: impl Into<String>) -> Self {
265 Self {
266 key: key.into(),
267 ttl: Duration::from_secs(30),
268 on_elected: None,
269 on_revoked: None,
270 }
271 }
272
273 pub fn with_ttl(mut self, ttl: Duration) -> Self {
275 self.ttl = ttl;
276 self
277 }
278
279 pub fn on_elected<F, Fut>(mut self, callback: F) -> Self
281 where
282 F: Fn() -> Fut + Send + Sync + 'static,
283 Fut: std::future::Future<Output = ()> + Send + 'static,
284 {
285 self.on_elected = Some(Arc::new(move || Box::pin(callback())));
286 self
287 }
288
289 pub fn on_revoked<F, Fut>(mut self, callback: F) -> Self
291 where
292 F: Fn() -> Fut + Send + Sync + 'static,
293 Fut: std::future::Future<Output = ()> + Send + 'static,
294 {
295 self.on_revoked = Some(Arc::new(move || Box::pin(callback())));
296 self
297 }
298
299 pub fn build(self, conn: redis::aio::ConnectionManager) -> LeaderElection {
301 let mut election = LeaderElection::new(self.key, self.ttl, conn);
302 election.on_elected = self.on_elected;
303 election.on_revoked = self.on_revoked;
304 election
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit TTL and the key are stored as given.
    #[test]
    fn test_leader_election_builder() {
        let builder = LeaderElectionBuilder::new("test-leader").with_ttl(Duration::from_secs(60));

        assert_eq!(builder.key, "test-leader");
        assert_eq!(builder.ttl, Duration::from_secs(60));
    }

    /// Without overrides, the builder uses a 30s TTL and has no hooks.
    #[test]
    fn test_leader_election_builder_defaults() {
        let builder = LeaderElectionBuilder::new("defaults");

        assert_eq!(builder.key, "defaults");
        assert_eq!(builder.ttl, Duration::from_secs(30));
        assert!(builder.on_elected.is_none());
        assert!(builder.on_revoked.is_none());
    }

    /// Registering hooks stores them for installation by `build`.
    #[test]
    fn test_leader_election_builder_registers_callbacks() {
        let builder = LeaderElectionBuilder::new("callbacks")
            .on_elected(|| async {})
            .on_revoked(|| async {});

        assert!(builder.on_elected.is_some());
        assert!(builder.on_revoked.is_some());
    }
}