use std::{
error::Error,
future::Future,
sync::{
Arc,
atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
},
time::Duration,
};
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, trace, warn};
use crate::{
SendOutsideWasm,
executor::{JoinHandle, spawn},
sleep::sleep,
};
/// Generation number of a cross-process lock, as yielded by
/// [`TryLock::try_lock`] when the lock is obtained.
pub type CrossProcessLockGeneration = u64;
/// Backend able to attempt taking a leased lock identified by a key, on
/// behalf of a named holder.
pub trait TryLock {
/// Error type returned by [`TryLock::try_lock`] (non-wasm targets: must be
/// `Send + Sync`).
#[cfg(not(target_family = "wasm"))]
type LockError: Error + Send + Sync;
/// Error type returned by [`TryLock::try_lock`] (wasm targets).
#[cfg(target_family = "wasm")]
type LockError: Error;
/// Try to take (or renew) the lock `key` for `holder`, with a lease of
/// `lease_duration_ms` milliseconds.
///
/// `CrossProcessLock` interprets the result as follows: `Ok(Some(generation))`
/// means the lock has been obtained, `Ok(None)` means it could not be
/// obtained, and `Err(_)` is propagated to the caller.
fn try_lock(
&self,
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> impl Future<Output = Result<Option<CrossProcessLockGeneration>, Self::LockError>>
+ SendOutsideWasm;
}
/// Backoff state consulted by `CrossProcessLock::spin_lock` between two
/// locking attempts.
#[derive(Clone, Debug)]
enum WaitingTime {
/// Wait this many milliseconds before the next attempt.
Some(u32),
/// Do not wait anymore: `spin_lock` reports a timeout when it sees this.
Stop,
}
/// A guard representing one in-process holder of a `CrossProcessLock`.
///
/// Every live guard accounts for exactly one unit of the shared `num_holders`
/// counter: cloning a guard increments the counter, dropping a guard
/// decrements it.
#[derive(Debug)]
#[must_use = "If unused, the `CrossProcessLock` will unlock at the end of the lease"]
pub struct CrossProcessLockGuard {
    /// Shared count of live guards for the lock this guard belongs to.
    num_holders: Arc<AtomicU32>,
    /// Shared dirty flag of the lock this guard belongs to.
    is_dirty: Arc<AtomicBool>,
}

impl Clone for CrossProcessLockGuard {
    /// Create another guard on the same lock, counted as an additional holder.
    ///
    /// A derived `Clone` would copy the `Arc`s without touching the counter,
    /// while `Drop` still decrements it once per guard: dropping the original
    /// and its clone would then wrap the counter below zero.
    fn clone(&self) -> Self {
        // `self` is alive, hence the counter is at least 1 here.
        self.num_holders.fetch_add(1, Ordering::SeqCst);

        Self { num_holders: Arc::clone(&self.num_holders), is_dirty: Arc::clone(&self.is_dirty) }
    }
}

impl CrossProcessLockGuard {
    /// Wrap the shared counters into a guard.
    ///
    /// This does NOT increment `num_holders`; the caller is responsible for
    /// having accounted for the new guard.
    fn new(num_holders: Arc<AtomicU32>, is_dirty: Arc<AtomicBool>) -> Self {
        Self { num_holders, is_dirty }
    }

    /// Whether the shared dirty flag is currently set.
    pub fn is_dirty(&self) -> bool {
        self.is_dirty.load(Ordering::SeqCst)
    }

    /// Reset the shared dirty flag.
    pub fn clear_dirty(&self) {
        self.is_dirty.store(false, Ordering::SeqCst);
    }
}

impl Drop for CrossProcessLockGuard {
    /// Give back this guard's unit of the holder count.
    fn drop(&mut self) {
        self.num_holders.fetch_sub(1, Ordering::SeqCst);
    }
}
/// A lock built on top of a [`TryLock`] backend, shared by all its clones
/// (every piece of mutable state lives behind an `Arc`).
#[derive(Clone, Debug)]
pub struct CrossProcessLock<L>
where
L: TryLock + Clone + SendOutsideWasm + 'static,
{
/// Backend used to take, renew and release the lease.
locker: L,
/// Number of live guards handed out by this lock (and its clones).
num_holders: Arc<AtomicU32>,
/// Mutex serializing locking attempts with the holder-count check done by
/// the lease extension task.
locking_attempt: Arc<Mutex<()>>,
/// Handle of the most recently spawned lease extension task, if any.
renew_task: Arc<Mutex<Option<JoinHandle<()>>>>,
/// Key passed to the backend to identify the lock.
lock_key: String,
/// Holder name passed to the backend.
lock_holder: String,
/// Backoff state used by `spin_lock`.
backoff: Arc<Mutex<WaitingTime>>,
/// Last generation returned by the backend, or
/// `NO_CROSS_PROCESS_LOCK_GENERATION` if none has been seen yet.
generation: Arc<AtomicU64>,
/// Set when the backend returns a generation different from the previously
/// stored one; reset by `clear_dirty`.
is_dirty: Arc<AtomicBool>,
}
/// Lease duration, in milliseconds, requested when taking or renewing the lock.
pub const LEASE_DURATION_MS: u32 = 500;
/// Time, in milliseconds, the lease extension task sleeps between two renewals.
pub const EXTEND_LEASE_EVERY_MS: u64 = 50;
/// First waiting time, in milliseconds, used by `spin_lock`.
const INITIAL_BACKOFF_MS: u32 = 10;
/// Maximum backoff, in milliseconds, used by `spin_lock` when the caller
/// passes `None`.
pub const MAX_BACKOFF_MS: u32 = 1000;
/// Value stored as the lock generation before any generation has been
/// received from the backend.
pub const NO_CROSS_PROCESS_LOCK_GENERATION: CrossProcessLockGeneration = 0;
/// Generation given to a lease when it is created for the first time (see
/// `memory_store_helper::try_take_leased_lock`).
pub const FIRST_CROSS_PROCESS_LOCK_GENERATION: CrossProcessLockGeneration = 1;
impl<L> CrossProcessLock<L>
where
L: TryLock + Clone + SendOutsideWasm + 'static,
{
pub fn new(locker: L, lock_key: String, lock_holder: String) -> Self {
Self {
locker,
lock_key,
lock_holder,
backoff: Arc::new(Mutex::new(WaitingTime::Some(INITIAL_BACKOFF_MS))),
num_holders: Arc::new(0.into()),
locking_attempt: Arc::new(Mutex::new(())),
renew_task: Default::default(),
generation: Arc::new(AtomicU64::new(NO_CROSS_PROCESS_LOCK_GENERATION)),
is_dirty: Arc::new(AtomicBool::new(false)),
}
}
pub fn is_dirty(&self) -> bool {
self.is_dirty.load(Ordering::SeqCst)
}
pub fn clear_dirty(&self) {
self.is_dirty.store(false, Ordering::SeqCst);
}
/// Try to take the lock exactly once, without waiting.
///
/// Returns:
/// - `Ok(Ok(state))` when a guard has been obtained,
/// - `Ok(Err(CrossProcessLockUnobtained::Busy))` when the backend answered
///   that the lock cannot be obtained right now,
/// - `Err(_)` when the backend itself failed.
///
/// When the lock is newly obtained from the backend, a task is spawned to
/// keep extending the lease until the holder count drops back to zero.
#[instrument(skip(self), fields(?self.lock_key, ?self.lock_holder))]
pub async fn try_lock_once(
&self,
) -> Result<Result<CrossProcessLockState, CrossProcessLockUnobtained>, L::LockError> {
// Held until the end of the function: serializes concurrent attempts and the
// holder-count check performed by the lease extension task.
let mut _attempt = self.locking_attempt.lock().await;
// Fast path: a guard already exists in this process, only bump the counter.
// NOTE(review): this path always reports `Clean`, even when the dirty flag is
// set, whereas the path below consults `is_dirty()` — confirm this is intended.
if self.num_holders.load(Ordering::SeqCst) > 0 {
trace!("We already had the lock, incrementing holder count");
self.num_holders.fetch_add(1, Ordering::SeqCst);
return Ok(Ok(CrossProcessLockState::Clean(CrossProcessLockGuard::new(
self.num_holders.clone(),
self.is_dirty.clone(),
))));
}
// Slow path: ask the backend for the lock.
if let Some(new_generation) =
self.locker.try_lock(LEASE_DURATION_MS, &self.lock_key, &self.lock_holder).await?
{
// Remember the generation we got, and compare it with the previous one.
match self.generation.swap(new_generation, Ordering::SeqCst) {
// No generation was recorded so far.
NO_CROSS_PROCESS_LOCK_GENERATION => {
trace!(?new_generation, "Setting the lock generation for the first time");
}
// The generation changed since our last acquisition: flag the lock as dirty.
previous_generation if previous_generation != new_generation => {
warn!(
?previous_generation,
?new_generation,
"The lock has been obtained, but it's been dirtied!"
);
self.is_dirty.store(true, Ordering::SeqCst);
}
_ => {
trace!("Same lock generation; no problem");
}
}
trace!("Lock obtained!");
} else {
trace!("Couldn't obtain the lock immediately.");
return Ok(Err(CrossProcessLockUnobtained::Busy));
}
trace!("Obtained the lock, spawning the lease extension task.");
// The task owns its own clone of the lock (all shared state is behind `Arc`s).
let this = (*self).clone();
let mut renew_task = self.renew_task.lock().await;
// Replace any previous extension task; on non-wasm targets it is aborted if
// it is still running.
if let Some(_prev) = renew_task.take() {
#[cfg(not(target_family = "wasm"))]
if !_prev.is_finished() {
trace!("aborting the previous renew task");
_prev.abort();
}
}
*renew_task = Some(spawn(async move {
loop {
{
// Check the holder count under the same mutex as locking attempts.
let _guard = this.locking_attempt.lock().await;
if this.num_holders.load(Ordering::SeqCst) == 0 {
trace!("exiting the lease extension loop");
// No holder left: request a lease of 0 ms, best effort (the result is
// ignored). With `memory_store_helper` this makes the lease expire right
// away; other backends presumably behave alike — TODO confirm.
let fut = this.locker.try_lock(0, &this.lock_key, &this.lock_holder);
let _ = fut.await;
break;
}
}
sleep(Duration::from_millis(EXTEND_LEASE_EVERY_MS)).await;
// Extend the lease for another `LEASE_DURATION_MS`.
match this
.locker
.try_lock(LEASE_DURATION_MS, &this.lock_key, &this.lock_holder)
.await
{
Ok(Some(_generation)) => {
// Lease renewed, keep looping.
}
// NOTE(review): in both failure cases below the task stops, but
// `num_holders` is left untouched, so existing guards stay counted.
Ok(None) => {
error!("Failed to renew the lock lease: the lock could not be obtained");
break;
}
Err(err) => {
error!("Error when extending the lock lease: {err:#}");
break;
}
}
}
}));
// Account for the guard we are about to hand out.
self.num_holders.fetch_add(1, Ordering::SeqCst);
let guard = CrossProcessLockGuard::new(self.num_holders.clone(), self.is_dirty.clone());
Ok(Ok(if self.is_dirty() {
CrossProcessLockState::Dirty(guard)
} else {
CrossProcessLockState::Clean(guard)
}))
}
/// Try to take the lock repeatedly, waiting with an exponential backoff
/// between two attempts.
///
/// The waiting time starts at `INITIAL_BACKOFF_MS` and doubles after every
/// failed attempt. Once the doubled value reaches `max_backoff` (defaults to
/// `MAX_BACKOFF_MS` when `None`), the next failed attempt gives up.
///
/// Returns:
/// - `Ok(Ok(state))` as soon as an attempt succeeds,
/// - `Ok(Err(CrossProcessLockUnobtained::TimedOut))` when the backoff has
///   been exhausted,
/// - `Err(_)` when the backend itself failed.
///
/// In every returning case but a backend failure, the backoff is reset so that
/// the next call starts from `INITIAL_BACKOFF_MS` again.
#[instrument(skip(self), fields(?self.lock_key, ?self.lock_holder))]
pub async fn spin_lock(
    &self,
    max_backoff: Option<u32>,
) -> Result<Result<CrossProcessLockState, CrossProcessLockUnobtained>, L::LockError> {
    let max_backoff = max_backoff.unwrap_or(MAX_BACKOFF_MS);

    loop {
        let lock_result = self.try_lock_once().await?;

        if lock_result.is_ok() {
            // Reset the backoff for the next time the lock has to be awaited.
            *self.backoff.lock().await = WaitingTime::Some(INITIAL_BACKOFF_MS);
            return Ok(lock_result);
        }

        // Exponential backoff: wait for the current value, then double it.
        let mut backoff = self.backoff.lock().await;
        let wait = match &mut *backoff {
            WaitingTime::Some(val) => {
                let wait = *val;
                *val = val.saturating_mul(2);
                if *val >= max_backoff {
                    *backoff = WaitingTime::Stop;
                }
                wait
            }
            WaitingTime::Stop => {
                // Give up, but do not leave the state stuck on `Stop`: otherwise every
                // later call would time out after a single attempt, without waiting.
                *backoff = WaitingTime::Some(INITIAL_BACKOFF_MS);
                return Ok(Err(CrossProcessLockUnobtained::TimedOut));
            }
        };
        // Release the backoff mutex before sleeping, so that a clone of this lock
        // that just succeeded is not blocked on it for the whole waiting time.
        drop(backoff);

        debug!("Waiting {wait} before re-attempting to take the lock");
        sleep(Duration::from_millis(wait.into())).await;
    }
}
/// Name of the holder this lock acts on behalf of.
pub fn lock_holder(&self) -> &str {
    self.lock_holder.as_str()
}
}
/// Result of a successful locking attempt: a guard, tagged with whether the
/// lock was flagged as dirty when the guard was created.
#[derive(Debug)]
#[must_use = "If unused, the `CrossProcessLock` will unlock at the end of the lease"]
pub enum CrossProcessLockState {
    /// The lock has been obtained and is not flagged as dirty.
    Clean(CrossProcessLockGuard),
    /// The lock has been obtained, but it is flagged as dirty.
    Dirty(CrossProcessLockGuard),
}

impl CrossProcessLockState {
    /// Extract the guard, forgetting whether the state was clean or dirty.
    pub fn into_guard(self) -> CrossProcessLockGuard {
        let (Self::Clean(inner) | Self::Dirty(inner)) = self;
        inner
    }

    /// Transform the guard with `mapper`, keeping the clean/dirty tag.
    pub fn map<F, G>(self, mapper: F) -> MappedCrossProcessLockState<G>
    where
        F: FnOnce(CrossProcessLockGuard) -> G,
    {
        let was_dirty = matches!(self, Self::Dirty(_));
        let mapped = mapper(self.into_guard());

        if was_dirty {
            MappedCrossProcessLockState::Dirty(mapped)
        } else {
            MappedCrossProcessLockState::Clean(mapped)
        }
    }
}
/// Same shape as `CrossProcessLockState`, but carrying an arbitrary value `G`
/// instead of the raw guard (see `CrossProcessLockState::map`).
#[derive(Debug)]
#[must_use = "If unused, the `CrossProcessLock` will unlock at the end of the lease"]
pub enum MappedCrossProcessLockState<G> {
    /// Value produced from a clean state.
    Clean(G),
    /// Value produced from a dirty state.
    Dirty(G),
}

impl<G> MappedCrossProcessLockState<G> {
    /// Borrow the inner value if, and only if, the state is `Clean`.
    pub fn as_clean(&self) -> Option<&G> {
        if let Self::Clean(value) = self { Some(value) } else { None }
    }
}
/// Reasons why a locking attempt ended without a guard, although the backend
/// itself did not fail.
#[derive(Debug, thiserror::Error)]
pub enum CrossProcessLockUnobtained {
/// Returned by `try_lock_once` when the backend did not grant the lock.
#[error(
"The lock couldn't be obtained immediately because it is busy, i.e. it is held by another holder"
)]
Busy,
/// Returned by `spin_lock` when the backoff has been exhausted.
#[error("The lock couldn't be obtained after several attempts: locking has timed out")]
TimedOut,
}
/// Umbrella error covering both outcomes of a failed locking attempt.
#[derive(Debug, thiserror::Error)]
pub enum CrossProcessLockError {
/// The lock could not be obtained (busy or timed out).
#[error(transparent)]
Unobtained(#[from] CrossProcessLockUnobtained),
/// Boxed backend error (non-wasm targets: must be `Send + Sync`).
#[error(transparent)]
#[cfg(not(target_family = "wasm"))]
TryLock(#[from] Box<dyn Error + Send + Sync>),
/// Boxed backend error (wasm targets).
#[error(transparent)]
#[cfg(target_family = "wasm")]
TryLock(#[from] Box<dyn Error>),
}
#[cfg(test)]
#[cfg(not(target_family = "wasm"))] mod tests {
use std::{
collections::HashMap,
ops::Not,
sync::{Arc, RwLock, atomic},
};
use assert_matches::assert_matches;
use matrix_sdk_test_macros::async_test;
use tokio::{spawn, task::yield_now};
use super::{
CrossProcessLock, CrossProcessLockError, CrossProcessLockGeneration, CrossProcessLockState,
CrossProcessLockUnobtained, TryLock,
memory_store_helper::{Lease, try_take_leased_lock},
};
/// In-memory lock backend used by the tests; clones share the same leases.
#[derive(Clone, Default)]
struct TestStore {
    /// Leases, indexed by lock key.
    leases: Arc<RwLock<HashMap<String, Lease>>>,
}

impl TestStore {
    /// Synchronous entry point delegating to the memory store helper, with the
    /// lease table write-locked for the duration of the call.
    fn try_take_leased_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Option<CrossProcessLockGeneration> {
        let mut leases = self.leases.write().unwrap();
        try_take_leased_lock(&mut leases, lease_duration_ms, key, holder)
    }
}
#[derive(Debug, thiserror::Error)]
enum DummyError {}
impl From<DummyError> for CrossProcessLockError {
fn from(value: DummyError) -> Self {
Self::TryLock(Box::new(value))
}
}
impl TryLock for TestStore {
    type LockError = DummyError;

    /// Infallible: forwards to `TestStore::try_take_leased_lock`.
    async fn try_lock(
        &self,
        lease_duration_ms: u32,
        key: &str,
        holder: &str,
    ) -> Result<Option<CrossProcessLockGeneration>, Self::LockError> {
        let generation = self.try_take_leased_lock(lease_duration_ms, key, holder);
        Ok(generation)
    }
}
/// Drop the given lock state, then yield once to the runtime so that other
/// tasks (e.g. the lease extension task) get a chance to run.
async fn release_lock(lock: CrossProcessLockState) {
    {
        // Dropped at the end of this scope, i.e. before yielding.
        let _released = lock;
    }
    yield_now().await;
}
/// Return type of the tests, so that `?` can be used on locking results.
type TestResult = Result<(), CrossProcessLockError>;
// A single holder takes and releases the lock twice: once with
// `try_lock_once`, once with `spin_lock`.
#[async_test]
async fn test_simple_lock_unlock() -> TestResult {
let store = TestStore::default();
let lock = CrossProcessLock::new(store, "key".to_owned(), "first".to_owned());
// The lock is obtained on the first attempt, and is clean.
let guard = lock.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Clean(_));
assert!(lock.is_dirty().not());
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);
// Releasing the guard brings the holder count back to zero.
release_lock(guard).await;
assert!(lock.is_dirty().not());
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 0);
// Same thing with the spinning variant.
let guard = lock.spin_lock(None).await?.expect("spin lock must be obtained successfully");
assert!(lock.is_dirty().not());
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);
release_lock(guard).await;
assert!(lock.is_dirty().not());
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 0);
Ok(())
}
// A new `CrossProcessLock` using the same store, key and holder name can
// obtain the lock again after the previous lock object has been dropped.
#[async_test]
async fn test_self_recovery() -> TestResult {
let store = TestStore::default();
let lock = CrossProcessLock::new(store.clone(), "key".to_owned(), "first".to_owned());
let guard = lock.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Clean(_));
assert!(lock.is_dirty().not());
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);
// Only the lock object is dropped here; the first guard is not released
// explicitly (its binding is merely shadowed below).
drop(lock);
// Same holder name as before: the store hands the lease over again.
let lock = CrossProcessLock::new(store.clone(), "key".to_owned(), "first".to_owned());
let guard =
lock.try_lock_once().await?.expect("lock (again) must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Clean(_));
assert!(lock.is_dirty().not());
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);
Ok(())
}
// The same lock object can hand out several guards at once; the holder count
// follows the number of live guards.
#[async_test]
async fn test_multiple_holders_same_process() -> TestResult {
let store = TestStore::default();
let lock = CrossProcessLock::new(store, "key".to_owned(), "first".to_owned());
// Two guards at the same time.
let guard1 = lock.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard1, CrossProcessLockState::Clean(_));
let guard2 = lock.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard2, CrossProcessLockState::Clean(_));
assert!(lock.is_dirty().not());
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 2);
// Each release decrements the holder count by one.
release_lock(guard1).await;
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 1);
release_lock(guard2).await;
assert_eq!(lock.num_holders.load(atomic::Ordering::SeqCst), 0);
assert!(lock.is_dirty().not());
Ok(())
}
// Two lock objects with different holder names compete for the same key in
// the same store.
#[async_test]
async fn test_multiple_processes() -> TestResult {
let store = TestStore::default();
let lock1 = CrossProcessLock::new(store.clone(), "key".to_owned(), "first".to_owned());
let lock2 = CrossProcessLock::new(store, "key".to_owned(), "second".to_owned());
// `lock1` wins.
let guard1 = lock1.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard1, CrossProcessLockState::Clean(_));
assert!(lock1.is_dirty().not());
// `lock2` cannot get it while `lock1` holds it.
let err = lock2.try_lock_once().await?.expect_err("lock must NOT be obtained");
assert_matches!(err, CrossProcessLockUnobtained::Busy);
// `lock2` spins in a separate task while `lock1` releases its guard.
let lock2_clone = lock2.clone();
let task = spawn(async move { lock2_clone.spin_lock(Some(500)).await });
yield_now().await;
drop(guard1);
// The spinning task eventually obtains the lock.
let guard2 = task
.await
.expect("join handle is properly awaited")
.expect("lock is successfully attempted")
.expect("lock must be obtained successfully");
assert_matches!(guard2, CrossProcessLockState::Clean(_));
assert!(lock1.is_dirty().not());
assert!(lock2.is_dirty().not());
// `guard2` is still alive: `lock1` spins until it times out.
assert_matches!(
lock1.spin_lock(Some(200)).await,
Ok(Err(CrossProcessLockUnobtained::TimedOut))
);
Ok(())
}
// Two lock objects take the lock in turns; a lock object is reported dirty
// when the other one obtained the lock since its own last acquisition, and
// stays dirty until `clear_dirty` is called.
#[async_test]
async fn test_multiple_processes_up_to_dirty() -> TestResult {
let store = TestStore::default();
let lock1 = CrossProcessLock::new(store.clone(), "key".to_owned(), "first".to_owned());
let lock2 = CrossProcessLock::new(store, "key".to_owned(), "second".to_owned());
// `lock1` obtains the lock once, then releases it.
{
let guard = lock1.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Clean(_));
assert!(lock1.is_dirty().not());
drop(guard);
yield_now().await;
}
// `lock2` obtains the lock once, then releases it. It is its first
// acquisition, hence it is clean.
{
let guard = lock2.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Clean(_));
assert!(lock1.is_dirty().not());
drop(guard);
yield_now().await;
}
for _ in 0..3 {
// `lock1` obtains the lock again: the other holder had it in between, so
// the lock is dirty.
{
let guard =
lock1.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Dirty(_));
assert!(lock1.is_dirty());
drop(guard);
yield_now().await;
}
// Still dirty, as the flag has not been cleared yet. Clear it now.
{
let guard =
lock1.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Dirty(_));
assert!(lock1.is_dirty());
lock1.clear_dirty();
drop(guard);
yield_now().await;
}
// The flag has been cleared: `lock1` is clean again.
{
let guard =
lock1.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Clean(_));
assert!(lock1.is_dirty().not());
drop(guard);
yield_now().await;
}
// `lock2` takes its turn: `lock1` had the lock in between, so it is dirty.
{
let guard =
lock2.try_lock_once().await?.expect("lock must be obtained successfully");
assert_matches!(guard, CrossProcessLockState::Dirty(_));
assert!(lock2.is_dirty());
lock2.clear_dirty();
drop(guard);
yield_now().await;
}
}
Ok(())
}
}
pub mod memory_store_helper {
use std::collections::{HashMap, hash_map::Entry};
use ruma::time::{Duration, Instant};
use super::{CrossProcessLockGeneration, FIRST_CROSS_PROCESS_LOCK_GENERATION};
/// State of one leased lock, as maintained by [`try_take_leased_lock`].
#[derive(Debug)]
pub struct Lease {
/// Name of the holder owning the lease.
holder: String,
/// Instant after which another holder may take the lease over.
expiration: Instant,
/// Generation of the lease; incremented every time another holder takes
/// the lease over.
generation: CrossProcessLockGeneration,
}
pub fn try_take_leased_lock(
leases: &mut HashMap<String, Lease>,
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> Option<CrossProcessLockGeneration> {
let now = Instant::now();
let expiration = now + Duration::from_millis(lease_duration_ms.into());
match leases.entry(key.to_owned()) {
Entry::Occupied(mut entry) => {
let Lease {
holder: current_holder,
expiration: current_expiration,
generation: current_generation,
} = entry.get_mut();
if current_holder == holder {
*current_expiration = expiration;
Some(*current_generation)
} else {
if *current_expiration < now {
*current_holder = holder.to_owned();
*current_expiration = expiration;
*current_generation += 1;
Some(*current_generation)
} else {
None
}
}
}
Entry::Vacant(entry) => {
entry.insert(Lease {
holder: holder.to_owned(),
expiration: Instant::now() + Duration::from_millis(lease_duration_ms.into()),
generation: FIRST_CROSS_PROCESS_LOCK_GENERATION,
});
Some(FIRST_CROSS_PROCESS_LOCK_GENERATION)
}
}
}
}