armature_distributed/
leader.rs

1//! Distributed leader election using Redis
2
3use redis::AsyncCommands;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::RwLock;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
/// Leader election errors
///
/// Returned by the fallible operations of the leader election coordinator
/// in this module.
#[derive(Debug, Error)]
pub enum LeaderError {
    /// The election could not be carried out; the payload is a
    /// human-readable description that is appended to the error message.
    #[error("Election failed: {0}")]
    ElectionFailed(String),

    /// A Redis command or script failed. Thanks to `#[from]`, a
    /// `redis::RedisError` is converted into this variant automatically
    /// when propagated with `?`.
    #[error("Redis error: {0}")]
    RedisError(#[from] redis::RedisError),

    /// Presumably reported when an operation that requires leadership is
    /// attempted by a node that does not hold it.
    // NOTE(review): no code in this file constructs this variant — confirm
    // the intended use against callers.
    #[error("Not the leader")]
    NotLeader,
}
24
/// Boxed, sendable future produced by one invocation of a [`LeaderCallback`].
type LeaderCallbackFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;

/// Leader election callback
///
/// A shareable, thread-safe closure that takes no arguments and returns a
/// boxed future resolving to `()`. It can be called any number of times; each
/// call yields a fresh future.
pub type LeaderCallback = Arc<dyn Fn() -> LeaderCallbackFuture + Send + Sync>;
28
/// Leader election coordinator
///
/// Holds the configuration and shared state of one node taking part in an
/// election over a single Redis key. The node that manages to store its
/// `node_id` under `key` is the leader for as long as the key is kept alive.
pub struct LeaderElection {
    /// Election key in Redis
    ///
    /// The value stored under this key is the node ID of the current leader.
    key: String,

    /// Unique node ID
    ///
    /// Generated as a random UUID when the coordinator is created.
    node_id: String,

    /// TTL for leadership
    ///
    /// Expiry applied to the Redis key each time leadership is acquired or
    /// refreshed.
    ttl: Duration,

    /// Refresh interval (should be less than TTL)
    ///
    /// Pause between two election rounds; derived from the TTL (roughly one
    /// third of it) when the coordinator is created.
    refresh_interval: Duration,

    /// Redis connection
    ///
    /// Shared behind an async lock; every Redis call in this file takes the
    /// write half of the lock.
    conn: Arc<RwLock<redis::aio::ConnectionManager>>,

    /// Is this node the leader?
    ///
    /// Local view maintained by the election loop; read through
    /// `is_leader()`.
    is_leader: Arc<AtomicBool>,

    /// Callback when becoming leader
    on_elected: Option<LeaderCallback>,

    /// Callback when losing leadership
    on_revoked: Option<LeaderCallback>,

    /// Running flag
    ///
    /// Set by `start` and cleared by `stop`; the election loop exits once it
    /// observes `false`.
    running: Arc<AtomicBool>,
}
58
59impl LeaderElection {
60    /// Create new leader election coordinator
61    ///
62    /// # Examples
63    ///
64    /// ```rust,ignore
65    /// use armature_distributed::LeaderElection;
66    /// use std::time::Duration;
67    ///
68    /// let client = redis::Client::open("redis://127.0.0.1/")?;
69    /// let conn = client.get_connection_manager().await?;
70    ///
71    /// let election = LeaderElection::new(
72    ///     "my-service-leader",
73    ///     Duration::from_secs(30),
74    ///     conn,
75    /// );
76    /// ```
77    pub fn new(key: impl Into<String>, ttl: Duration, conn: redis::aio::ConnectionManager) -> Self {
78        let refresh_interval = Duration::from_millis((ttl.as_millis() / 3) as u64);
79
80        Self {
81            key: key.into(),
82            node_id: Uuid::new_v4().to_string(),
83            ttl,
84            refresh_interval,
85            conn: Arc::new(RwLock::new(conn)),
86            is_leader: Arc::new(AtomicBool::new(false)),
87            on_elected: None,
88            on_revoked: None,
89            running: Arc::new(AtomicBool::new(false)),
90        }
91    }
92
93    /// Set callback for when this node becomes leader
94    pub fn on_elected<F, Fut>(mut self, callback: F) -> Self
95    where
96        F: Fn() -> Fut + Send + Sync + 'static,
97        Fut: std::future::Future<Output = ()> + Send + 'static,
98    {
99        self.on_elected = Some(Arc::new(move || Box::pin(callback())));
100        self
101    }
102
103    /// Set callback for when this node loses leadership
104    pub fn on_revoked<F, Fut>(mut self, callback: F) -> Self
105    where
106        F: Fn() -> Fut + Send + Sync + 'static,
107        Fut: std::future::Future<Output = ()> + Send + 'static,
108    {
109        self.on_revoked = Some(Arc::new(move || Box::pin(callback())));
110        self
111    }
112
113    /// Check if this node is the leader
114    pub fn is_leader(&self) -> bool {
115        self.is_leader.load(Ordering::Acquire)
116    }
117
118    /// Get the node ID
119    pub fn node_id(&self) -> &str {
120        &self.node_id
121    }
122
123    /// Start participating in leader election
124    pub async fn start(self: Arc<Self>) -> Result<(), LeaderError> {
125        self.running.store(true, Ordering::Release);
126
127        info!(
128            "Starting leader election for key: {} (node: {})",
129            self.key, self.node_id
130        );
131
132        loop {
133            if !self.running.load(Ordering::Acquire) {
134                break;
135            }
136
137            // Try to become leader
138            match self.try_become_leader().await {
139                Ok(became_leader) => {
140                    let was_leader = self.is_leader.load(Ordering::Acquire);
141
142                    if became_leader && !was_leader {
143                        // Newly elected
144                        self.is_leader.store(true, Ordering::Release);
145                        info!("Node {} became leader for {}", self.node_id, self.key);
146
147                        if let Some(callback) = &self.on_elected {
148                            callback().await;
149                        }
150                    } else if !became_leader && was_leader {
151                        // Lost leadership
152                        self.is_leader.store(false, Ordering::Release);
153                        warn!("Node {} lost leadership for {}", self.node_id, self.key);
154
155                        if let Some(callback) = &self.on_revoked {
156                            callback().await;
157                        }
158                    } else if became_leader {
159                        // Still leader, just refreshed
160                        debug!(
161                            "Node {} refreshed leadership for {}",
162                            self.node_id, self.key
163                        );
164                    }
165                }
166                Err(e) => {
167                    error!("Leader election error: {}", e);
168
169                    // If we were leader but encountered an error, we're no longer leader
170                    if self.is_leader.swap(false, Ordering::Release) {
171                        if let Some(callback) = &self.on_revoked {
172                            callback().await;
173                        }
174                    }
175                }
176            }
177
178            // Wait before next attempt
179            tokio::time::sleep(self.refresh_interval).await;
180        }
181
182        // Clean up on stop
183        if self.is_leader.load(Ordering::Acquire) {
184            let _ = self.resign().await;
185        }
186
187        Ok(())
188    }
189
190    /// Stop participating in leader election
191    pub async fn stop(&self) {
192        self.running.store(false, Ordering::Release);
193    }
194
195    /// Try to become or maintain leadership
196    async fn try_become_leader(&self) -> Result<bool, LeaderError> {
197        let mut conn = self.conn.write().await;
198        let ttl_ms = self.ttl.as_millis() as usize;
199
200        // Use Lua script for atomic operation
201        let script = r#"
202            local current = redis.call("get", KEYS[1])
203            if current == false or current == ARGV[1] then
204                redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[2])
205                return 1
206            else
207                return 0
208            end
209        "#;
210
211        let result: i32 = redis::Script::new(script)
212            .key(&self.key)
213            .arg(&self.node_id)
214            .arg(ttl_ms)
215            .invoke_async(&mut *conn)
216            .await?;
217
218        Ok(result == 1)
219    }
220
221    /// Resign from leadership
222    async fn resign(&self) -> Result<(), LeaderError> {
223        let mut conn = self.conn.write().await;
224
225        // Only delete if we're still the leader
226        let script = r#"
227            if redis.call("get", KEYS[1]) == ARGV[1] then
228                return redis.call("del", KEYS[1])
229            else
230                return 0
231            end
232        "#;
233
234        let _: i32 = redis::Script::new(script)
235            .key(&self.key)
236            .arg(&self.node_id)
237            .invoke_async(&mut *conn)
238            .await?;
239
240        self.is_leader.store(false, Ordering::Release);
241        info!("Node {} resigned from leadership", self.node_id);
242
243        Ok(())
244    }
245
246    /// Get current leader node ID
247    pub async fn get_leader(&self) -> Result<Option<String>, LeaderError> {
248        let mut conn = self.conn.write().await;
249        let leader: Option<String> = conn.get(&self.key).await?;
250        Ok(leader)
251    }
252}
253
/// Leader election builder
///
/// Collects the election settings before a Redis connection is available;
/// `build` turns it into a `LeaderElection`.
pub struct LeaderElectionBuilder {
    /// Redis key the election is held on.
    key: String,
    /// TTL for leadership; the builder starts out with 30 seconds.
    ttl: Duration,
    /// Callback to run when the node becomes leader, if any.
    on_elected: Option<LeaderCallback>,
    /// Callback to run when the node loses leadership, if any.
    on_revoked: Option<LeaderCallback>,
}
261
262impl LeaderElectionBuilder {
263    /// Create new builder
264    pub fn new(key: impl Into<String>) -> Self {
265        Self {
266            key: key.into(),
267            ttl: Duration::from_secs(30),
268            on_elected: None,
269            on_revoked: None,
270        }
271    }
272
273    /// Set TTL
274    pub fn with_ttl(mut self, ttl: Duration) -> Self {
275        self.ttl = ttl;
276        self
277    }
278
279    /// Set elected callback
280    pub fn on_elected<F, Fut>(mut self, callback: F) -> Self
281    where
282        F: Fn() -> Fut + Send + Sync + 'static,
283        Fut: std::future::Future<Output = ()> + Send + 'static,
284    {
285        self.on_elected = Some(Arc::new(move || Box::pin(callback())));
286        self
287    }
288
289    /// Set revoked callback
290    pub fn on_revoked<F, Fut>(mut self, callback: F) -> Self
291    where
292        F: Fn() -> Fut + Send + Sync + 'static,
293        Fut: std::future::Future<Output = ()> + Send + 'static,
294    {
295        self.on_revoked = Some(Arc::new(move || Box::pin(callback())));
296        self
297    }
298
299    /// Build the leader election coordinator
300    pub fn build(self, conn: redis::aio::ConnectionManager) -> LeaderElection {
301        let mut election = LeaderElection::new(self.key, self.ttl, conn);
302        election.on_elected = self.on_elected;
303        election.on_revoked = self.on_revoked;
304        election
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the key it was created with and the TTL set through
    /// `with_ttl`.
    #[test]
    fn test_leader_election_builder() {
        let expected_ttl = Duration::from_secs(60);

        let LeaderElectionBuilder { key, ttl, .. } =
            LeaderElectionBuilder::new("test-leader").with_ttl(expected_ttl);

        assert_eq!(key, "test-leader");
        assert_eq!(ttl, expected_ttl);
    }
}
319}