bmart/sync.rs

1use crate::Error;
2use std::collections::{btree_map, BTreeMap};
3use std::sync::atomic;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::{mpsc, Mutex};
7use tokio::task;
8use uuid::Uuid;
9
/// Error message returned when a presented token does not match the current holder's token.
const ERR_INVALID_LOCK_TOKEN: &str = "Invalid lock token";
/// Error message returned when a lock id has no matching entry.
const ERR_LOCK_NOT_DEFINED: &str = "Lock not defined";
12
13#[derive(Debug, Clone)]
14pub struct Lock {
15    unlock_trigger: mpsc::Sender<()>,
16}
17
18impl Lock {
19    /// Returns true if released, false if not locked
20    pub async fn release(&self) -> bool {
21        self.unlock_trigger.send(()).await.is_ok()
22    }
23}
24
25#[derive(Debug, Default)]
26pub struct SharedLock {
27    lock: Arc<Mutex<()>>,
28    flag: Arc<atomic::AtomicBool>,
29}
30
31impl SharedLock {
32    #[must_use]
33    pub fn new() -> Self {
34        Self::default()
35    }
36    pub async fn acquire(&self, expires: Duration) -> Lock {
37        let lock = self.lock.clone();
38        let (lock_trigger, lock_listener) = triggered::trigger();
39        let (unlock_trigger, mut unlock_listener) = mpsc::channel(1);
40        let flag = self.flag.clone();
41        task::spawn(async move {
42            // guard moved here
43            let _g = lock.lock().await;
44            // triggered as soon as the lock is acquired
45            flag.store(true, atomic::Ordering::SeqCst);
46            lock_trigger.trigger();
47            // exited as soon as unlocked or expired or unlock_trigger dropped
48            let _ = tokio::time::timeout(expires, unlock_listener.recv()).await;
49            flag.store(false, atomic::Ordering::SeqCst);
50        });
51        // want lock to be acquired
52        lock_listener.await;
53        Lock { unlock_trigger }
54    }
55    pub fn clone_flag(&self) -> Arc<atomic::AtomicBool> {
56        self.flag.clone()
57    }
58}
59
60#[derive(Debug, Default)]
61pub struct SharedLockFactory {
62    shared_locks: BTreeMap<String, (Mutex<SharedLock>, Arc<atomic::AtomicBool>)>,
63    locks: Mutex<BTreeMap<String, (Uuid, Lock)>>,
64}
65
66impl SharedLockFactory {
67    #[must_use]
68    pub fn new() -> Self {
69        Self::default()
70    }
71    /// # Errors
72    ///
73    /// Will return `Err` if the lock already exists
74    pub fn create(&mut self, lock_id: &str) -> Result<(), Error> {
75        if let btree_map::Entry::Vacant(x) = self.shared_locks.entry(lock_id.to_owned()) {
76            let slock = SharedLock::new();
77            let flag = slock.clone_flag();
78            x.insert((Mutex::new(slock), flag));
79            Ok(())
80        } else {
81            Err(Error::duplicate(format!(
82                "Shared lock {} already exists",
83                lock_id
84            )))
85        }
86    }
87    /// # Errors
88    ///
89    /// Will return `Err` if the lock is not defined
90    pub async fn acquire(&self, lock_id: &str, expires: Duration) -> Result<Uuid, Error> {
91        if let Some((v, _)) = self.shared_locks.get(lock_id) {
92            // wait for the lock and block other futures accessing it
93            let lock = v.lock().await.acquire(expires).await;
94            let token = Uuid::new_v4();
95            self.locks
96                .lock()
97                .await
98                .insert(lock_id.to_owned(), (token, lock));
99            Ok(token)
100        } else {
101            Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
102        }
103    }
104    /// # Errors
105    ///
106    /// Will return `Err` if the token is invalid, None forcibly releases the lock
107    pub async fn release(&self, lock_id: &str, token: Option<&Uuid>) -> Result<bool, Error> {
108        if let Some((tok, lock)) = self.locks.lock().await.get(lock_id) {
109            if let Some(t) = token {
110                if tok != t {
111                    return Err(Error::not_found(ERR_INVALID_LOCK_TOKEN));
112                }
113            }
114            Ok(lock.release().await)
115        } else {
116            Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
117        }
118    }
119    /// # Errors
120    ///
121    /// Will return `Err` if the lock is not defined
122    pub fn status(&self, lock_id: &str) -> Result<bool, Error> {
123        if let Some((_, flag)) = self.shared_locks.get(lock_id) {
124            Ok(flag.load(atomic::Ordering::SeqCst))
125        } else {
126            Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
127        }
128    }
129    pub fn list(&self) -> Vec<(&str, bool)> {
130        let mut result = Vec::new();
131        for (id, (_, flag)) in &self.shared_locks {
132            result.push((id.as_str(), flag.load(atomic::Ordering::SeqCst)));
133        }
134        result
135    }
136}