use distributed_lock_core::error::{LockError, LockResult};
use fred::prelude::*;
use super::helper::RedLockHelper;
use super::timeouts::RedLockTimeouts;
/// Per-client outcome of a RedLock acquisition attempt.
///
/// `acquire_redlock` fills entry `i` for the `i`-th client it was given; `true` means that
/// client reported taking the lock.
#[derive(Debug)]
pub struct RedLockAcquireResult {
    /// One flag per client; `true` means that client acquired the lock.
    pub acquire_results: Vec<bool>,
}

impl RedLockAcquireResult {
    /// Wraps the per-client outcome flags.
    pub fn new(acquire_results: Vec<bool>) -> Self {
        RedLockAcquireResult { acquire_results }
    }

    /// Whether enough clients succeeded for the lock to count as held, as decided by
    /// `RedLockHelper::has_sufficient_successes` for a deployment of `total_clients` nodes.
    pub fn is_successful(&self, total_clients: usize) -> bool {
        let acquired = self.success_count();
        RedLockHelper::has_sufficient_successes(acquired, total_clients)
    }

    /// Number of clients whose flag is `true`.
    pub fn success_count(&self) -> usize {
        self.acquire_results
            .iter()
            .copied()
            .filter(|&acquired| acquired)
            .count()
    }
}
use tokio::task::JoinSet;
/// Attempts to acquire a RedLock across `clients`, running one concurrent attempt per client.
///
/// `try_acquire_fn` is invoked once per client on its own spawned task. `Ok(true)` counts as a
/// success. `Ok(false)`, `Err(_)` and a panicked or aborted task all count as failures. A
/// single client is delegated to `acquire_single_client`, which propagates that client's error
/// instead of counting it.
///
/// # Returns
/// * `Ok(Some(result))` as soon as `RedLockHelper::has_sufficient_successes` is satisfied.
///   Clients whose attempt had not finished yet are recorded as `false`.
/// * `Ok(None)` in any of these cases:
///   - the acquire timeout elapses;
///   - `RedLockHelper::has_too_many_failures_or_faults` is satisfied;
///   - every attempt has finished without either threshold being reached.
///
/// # Errors
/// * `LockError::InvalidName` if `clients` is empty.
/// * `LockError::Cancelled` if `cancel_token` holds `true` on entry or becomes `true` while
///   waiting.
///
/// Returning drops the `JoinSet`, which aborts any attempts still in flight.
/// NOTE(review): an aborted attempt may already have reached its Redis node, so the lock can be
/// held on nodes recorded as `false`. Presumably callers release on every client; confirm.
pub async fn acquire_redlock<F, Fut>(
    try_acquire_fn: F,
    clients: &[RedisClient],
    timeouts: &RedLockTimeouts,
    cancel_token: &tokio::sync::watch::Receiver<bool>,
) -> LockResult<Option<RedLockAcquireResult>>
where
    F: Fn(&RedisClient) -> Fut + Send + Sync + Clone + 'static,
    Fut: std::future::Future<Output = LockResult<bool>> + Send,
{
    /// Resolves once `rx` holds `true`. It never resolves if the sender is dropped without
    /// cancelling. This avoids re-polling `changed()`, which returns `Err` immediately after
    /// the sender is dropped and would turn a select loop into a busy spin.
    async fn wait_cancelled(mut rx: tokio::sync::watch::Receiver<bool>) {
        loop {
            if *rx.borrow_and_update() {
                return;
            }
            if rx.changed().await.is_err() {
                // The sender is gone, but a final `true` may still be stored.
                if *rx.borrow() {
                    return;
                }
                std::future::pending::<()>().await;
            }
        }
    }

    if clients.is_empty() {
        return Err(LockError::InvalidName(
            "no Redis clients provided".to_string(),
        ));
    }
    // Check the current value, not just unseen changes: an already-observed `true` still means
    // the caller wants to stop.
    if *cancel_token.borrow() {
        return Err(LockError::Cancelled);
    }
    if clients.len() == 1 {
        return acquire_single_client(try_acquire_fn, &clients[0], timeouts, cancel_token).await;
    }

    let acquire_timeout = timeouts.acquire_timeout();
    let timeout_duration = acquire_timeout.as_duration();

    let mut join_set = JoinSet::new();
    for (idx, client) in clients.iter().enumerate() {
        let client = client.clone();
        let try_acquire = try_acquire_fn.clone();
        join_set.spawn(async move { (idx, try_acquire(&client).await) });
    }

    let mut results: Vec<Option<bool>> = vec![None; clients.len()];
    let mut success_count = 0;
    let mut fail_count = 0;

    // A `None` duration means "no timeout": that branch simply never completes.
    let timeout_fut = async {
        match timeout_duration {
            Some(dur) => tokio::time::sleep(dur).await,
            None => std::future::pending::<()>().await,
        }
    };
    tokio::pin!(timeout_fut);
    let cancelled = wait_cancelled(cancel_token.clone());
    tokio::pin!(cancelled);

    loop {
        tokio::select! {
            _ = &mut cancelled => return Err(LockError::Cancelled),
            _ = &mut timeout_fut => return Ok(None),
            joined = join_set.join_next() => {
                let acquired = match joined {
                    // Every attempt has reported and neither threshold fired. Without this arm,
                    // an unbounded timeout would leave the loop waiting forever.
                    None => return Ok(None),
                    Some(Ok((idx, outcome))) => {
                        let acquired = matches!(outcome, Ok(true));
                        results[idx] = Some(acquired);
                        acquired
                    }
                    // The task panicked or was aborted; its client index is unknown.
                    Some(Err(_)) => false,
                };
                if acquired {
                    success_count += 1;
                    if RedLockHelper::has_sufficient_successes(success_count, clients.len()) {
                        let final_results: Vec<bool> =
                            results.iter().map(|r| r.unwrap_or(false)).collect();
                        return Ok(Some(RedLockAcquireResult::new(final_results)));
                    }
                } else {
                    fail_count += 1;
                }
                if RedLockHelper::has_too_many_failures_or_faults(fail_count, clients.len()) {
                    return Ok(None);
                }
            }
        }
    }
}
/// Single-node acquisition path used when exactly one Redis client is configured.
///
/// Runs exactly one `try_acquire_fn(client)` attempt and races it against the acquire timeout
/// (if any) and `cancel_token`.
///
/// # Returns
/// * `Ok(Some(result))` with `acquire_results == [true]` when the attempt returns `Ok(true)`.
/// * `Ok(None)` when the attempt returns `Ok(false)` or the timeout elapses first.
///
/// # Errors
/// * The attempt's own error, propagated unchanged.
/// * `LockError::Cancelled` if `cancel_token` holds `true` on entry or becomes `true` before the
///   attempt finishes.
async fn acquire_single_client<F, Fut>(
    try_acquire_fn: F,
    client: &RedisClient,
    timeouts: &RedLockTimeouts,
    cancel_token: &tokio::sync::watch::Receiver<bool>,
) -> LockResult<Option<RedLockAcquireResult>>
where
    F: Fn(&RedisClient) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = LockResult<bool>> + Send,
{
    /// Resolves once `rx` holds `true`. It never resolves if the sender is dropped without
    /// cancelling, instead of spinning on the immediate `Err` that `changed()` then returns.
    async fn wait_cancelled(mut rx: tokio::sync::watch::Receiver<bool>) {
        loop {
            if *rx.borrow_and_update() {
                return;
            }
            if rx.changed().await.is_err() {
                // The sender is gone, but a final `true` may still be stored.
                if *rx.borrow() {
                    return;
                }
                std::future::pending::<()>().await;
            }
        }
    }

    // Check the current value, not just unseen changes: an already-observed `true` still means
    // the caller wants to stop.
    if *cancel_token.borrow() {
        return Err(LockError::Cancelled);
    }

    let acquire_timeout = timeouts.acquire_timeout();
    let timeout_duration = acquire_timeout.as_duration();
    // A `None` duration means "no timeout": that branch simply never completes.
    let timeout_fut = async {
        match timeout_duration {
            Some(dur) => tokio::time::sleep(dur).await,
            None => std::future::pending::<()>().await,
        }
    };
    tokio::pin!(timeout_fut);
    let cancelled = wait_cancelled(cancel_token.clone());
    tokio::pin!(cancelled);

    // Create the attempt exactly once. Re-creating it per loop iteration would re-issue the
    // acquire command on every non-cancelling token change.
    let attempt = try_acquire_fn(client);
    tokio::pin!(attempt);

    tokio::select! {
        // Poll the attempt first (as `tokio::time::timeout` does) so a lock that was actually
        // taken is reported rather than silently dropped when cancel/timeout fire together.
        biased;
        result = &mut attempt => match result {
            Ok(true) => Ok(Some(RedLockAcquireResult::new(vec![true]))),
            Ok(false) => Ok(None),
            Err(e) => Err(e),
        },
        _ = &mut cancelled => Err(LockError::Cancelled),
        _ = &mut timeout_fut => Ok(None),
    }
}