distributed_lock_redis/redlock/
acquire.rs

1//! RedLock acquire algorithm implementation.
2
3use distributed_lock_core::error::{LockError, LockResult};
4use fred::prelude::*;
5
6use super::helper::RedLockHelper;
7use super::timeouts::RedLockTimeouts;
8
9/// Result of a RedLock acquire operation.
10#[derive(Debug)]
11pub struct RedLockAcquireResult {
12    /// Results indexed by client position (true = success, false = failed).
13    pub acquire_results: Vec<bool>,
14}
15
16impl RedLockAcquireResult {
17    /// Creates a new acquire result.
18    pub fn new(acquire_results: Vec<bool>) -> Self {
19        Self { acquire_results }
20    }
21
22    /// Checks if the acquire was successful (majority consensus).
23    pub fn is_successful(&self, total_clients: usize) -> bool {
24        let success_count = self.acquire_results.iter().filter(|&&v| v).count();
25        RedLockHelper::has_sufficient_successes(success_count, total_clients)
26    }
27
28    /// Returns the number of successful acquisitions.
29    pub fn success_count(&self) -> usize {
30        self.acquire_results.iter().filter(|&&v| v).count()
31    }
32}
33
34// ... imports ...
35use tokio::task::JoinSet; // Add this import
36
37// ...
38
39/// Acquires a lock using the RedLock algorithm across multiple Redis servers.
40///
41/// The algorithm requires majority consensus: for N servers, we need at least
42/// (N/2 + 1) successful acquisitions.
43///
44/// # Arguments
45///
46/// * `try_acquire_fn` - Function that attempts to acquire the lock on a single client
47/// * `clients` - List of Redis clients to acquire on
48/// * `timeouts` - Timeout configuration
49/// * `cancel_token` - Cancellation token
50pub async fn acquire_redlock<F, Fut>(
51    try_acquire_fn: F,
52    clients: &[RedisClient],
53    timeouts: &RedLockTimeouts,
54    cancel_token: &tokio::sync::watch::Receiver<bool>,
55) -> LockResult<Option<RedLockAcquireResult>>
56where
57    F: Fn(&RedisClient) -> Fut + Send + Sync + Clone + 'static,
58    Fut: std::future::Future<Output = LockResult<bool>> + Send,
59{
60    if clients.is_empty() {
61        return Err(LockError::InvalidName(
62            "no Redis clients provided".to_string(),
63        ));
64    }
65
66    // Single client case - simpler path
67    if clients.len() == 1 {
68        return acquire_single_client(try_acquire_fn, &clients[0], timeouts, cancel_token).await;
69    }
70
71    // Multi-client RedLock algorithm
72    let acquire_timeout = timeouts.acquire_timeout();
73    let timeout_duration = acquire_timeout.as_duration();
74
75    // Use JoinSet to manage concurrent acquisition tasks
76    let mut join_set = JoinSet::new();
77
78    for (idx, client) in clients.iter().enumerate() {
79        let client_clone = client.clone();
80        let try_acquire_fn_clone = try_acquire_fn.clone();
81        join_set.spawn(async move {
82            let result = try_acquire_fn_clone(&client_clone).await;
83            (idx, result)
84        });
85    }
86
87    let mut results: Vec<Option<bool>> = vec![None; clients.len()];
88    let mut success_count = 0;
89    let mut fail_count = 0;
90
91    // Create timeout future (if applicable)
92    let timeout_fut = async {
93        if let Some(dur) = timeout_duration {
94            tokio::time::sleep(dur).await;
95            true // Timed out
96        } else {
97            std::future::pending::<bool>().await;
98            false // Never times out
99        }
100    };
101    tokio::pin!(timeout_fut);
102
103    // Create cancellation future
104    let mut cancel_rx = cancel_token.clone();
105
106    loop {
107        tokio::select! {
108            // 1. Check for cancellation
109            _ = cancel_rx.changed() => {
110               if *cancel_rx.borrow() {
111                   // Abort remaining tasks automatically when join_set is dropped
112                   return Err(LockError::Cancelled);
113               }
114            }
115
116            // 2. Check for overall timeout
117            _ = &mut timeout_fut => {
118                // Abort remaining tasks
119                return Ok(None);
120            }
121
122            // 3. Process completed tasks
123            Some(join_result) = join_set.join_next() => {
124                match join_result {
125                    Ok((idx, lock_res)) => {
126                        match lock_res {
127                            Ok(true) => {
128                                results[idx] = Some(true);
129                                success_count += 1;
130                                if RedLockHelper::has_sufficient_successes(success_count, clients.len()) {
131                                    // Majority reached!
132                                    // Fill remaining with false (unknown state but logically not holding)
133                                    // Note: In RedLock, we technically might have acquired others,
134                                    // but we only claim the ones we know about + default false.
135                                    // Background release will clean up anyway.
136                                    let final_results: Vec<bool> = results.iter().map(|r| r.unwrap_or(false)).collect();
137                                    return Ok(Some(RedLockAcquireResult::new(final_results)));
138                                }
139                            }
140                            Ok(false) => {
141                                results[idx] = Some(false);
142                                fail_count += 1;
143                            }
144                            Err(_) => {
145                                // Error acquiring
146                                results[idx] = Some(false);
147                                fail_count += 1;
148                            }
149                        }
150                    }
151                    Err(_) => {
152                        // Task panicked or cancelled
153                        fail_count += 1;
154                    }
155                }
156
157                // Check if too many failures to ever succeed
158                if RedLockHelper::has_too_many_failures_or_faults(fail_count, clients.len()) {
159                     return Ok(None);
160                }
161            }
162
163            // 4. If all tasks finished and we are here (join_set empty), we failed
164            else => {
165                 return Ok(None);
166            }
167        }
168    }
169}
170
171/// Acquires a lock on a single Redis client (simpler path).
172async fn acquire_single_client<F, Fut>(
173    try_acquire_fn: F,
174    client: &RedisClient,
175    timeouts: &RedLockTimeouts,
176    cancel_token: &tokio::sync::watch::Receiver<bool>,
177) -> LockResult<Option<RedLockAcquireResult>>
178where
179    F: Fn(&RedisClient) -> Fut + Send + Sync,
180    Fut: std::future::Future<Output = LockResult<bool>> + Send,
181{
182    let acquire_timeout = timeouts.acquire_timeout();
183    let timeout_duration = acquire_timeout.as_duration();
184
185    // Check for cancellation first
186    if cancel_token.has_changed().unwrap_or(false) && *cancel_token.borrow() {
187        return Err(LockError::Cancelled);
188    }
189
190    let acquire_future = try_acquire_fn(client);
191
192    let result = if let Some(timeout_dur) = timeout_duration {
193        match tokio::time::timeout(timeout_dur, acquire_future).await {
194            Ok(Ok(true)) => true,
195            Ok(Ok(false)) => return Ok(None),
196            Ok(Err(e)) => return Err(e),
197            Err(_) => return Ok(None), // Timeout
198        }
199    } else {
200        // No timeout - wait indefinitely (but check cancellation)
201        loop {
202            let mut cancel_rx = cancel_token.clone();
203            tokio::select! {
204                result = try_acquire_fn(client) => {
205                    match result {
206                        Ok(true) => break true,
207                        Ok(false) => return Ok(None),
208                        Err(e) => return Err(e),
209                    }
210                }
211                _ = cancel_rx.changed() => {
212                    if *cancel_rx.borrow() {
213                        return Err(LockError::Cancelled);
214                    }
215                    // Continue waiting
216                }
217            }
218        }
219    };
220
221    Ok(Some(RedLockAcquireResult::new(vec![result])))
222}