1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use crate::Error;
use std::collections::{btree_map, BTreeMap};
use std::sync::atomic;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex};
use tokio::task;
use uuid::Uuid;

/// Error message for a lock id that is not defined (in `release`: that has no recorded acquisition).
const ERR_LOCK_NOT_DEFINED: &str = "Lock not defined";
/// Error message for a release token that does not match the recorded one.
const ERR_INVALID_LOCK_TOKEN: &str = "Invalid lock token";

/// Handle to a lock whose mutex guard is owned by a task spawned in `SharedLock::acquire`.
///
/// That task keeps the lock until it receives a message on `unlock_trigger`, the expiry
/// elapses, or every clone of this handle (and thus every sender) has been dropped.
#[derive(Debug, Clone)]
pub struct Lock {
    /// Sending side of the channel the holding task listens on for an unlock request.
    unlock_trigger: mpsc::Sender<()>,
}

impl Lock {
    /// Asks the task holding the lock to release it.
    ///
    /// Returns `true` if the unlock request was delivered, and `false` if the lock is no
    /// longer held (it expired or was already released) or an unlock request is already
    /// pending.
    ///
    /// `try_send` is used instead of `send().await` on purpose. The channel created by
    /// `SharedLock::acquire` has capacity 1, so a second `send` issued while the first request
    /// is still buffered would wait for the holder to consume it and then succeed. That would
    /// report `true` for a lock that had already been released.
    // Kept `async` so existing `.release().await` call sites keep compiling.
    #[allow(clippy::unused_async)]
    pub async fn release(&self) -> bool {
        self.unlock_trigger.try_send(()).is_ok()
    }
}

/// An async mutex whose guard is held by a spawned task on behalf of the caller,
/// paired with a flag reporting whether it is currently held.
#[derive(Debug, Default)]
pub struct SharedLock {
    /// Mutex kept locked by the holding task for as long as the lock is held.
    lock: Arc<Mutex<()>>,
    /// `true` while a task holds `lock`; handed out to observers via `clone_flag`.
    flag: Arc<atomic::AtomicBool>,
}

impl SharedLock {
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    pub async fn acquire(&self, expires: Duration) -> Lock {
        let lock = self.lock.clone();
        let (lock_trigger, lock_listener) = triggered::trigger();
        let (unlock_trigger, mut unlock_listener) = mpsc::channel(1);
        let flag = self.flag.clone();
        task::spawn(async move {
            // guard moved here
            let _g = lock.lock().await;
            // triggered as soon as the lock is acquired
            flag.store(true, atomic::Ordering::SeqCst);
            lock_trigger.trigger();
            // exited as soon as unlocked or expired or unlock_trigger dropped
            let _ = tokio::time::timeout(expires, unlock_listener.recv()).await;
            flag.store(false, atomic::Ordering::SeqCst);
        });
        // want lock to be acquired
        lock_listener.await;
        Lock { unlock_trigger }
    }
    pub fn clone_flag(&self) -> Arc<atomic::AtomicBool> {
        self.flag.clone()
    }
}

/// Registry of named shared locks, plus the token and handle of each lock's latest acquisition.
#[derive(Debug, Default)]
pub struct SharedLockFactory {
    /// Defined locks by id: the lock behind a mutex that serializes acquirers, and a
    /// clone of its "currently held" flag (readable without taking that mutex).
    shared_locks: BTreeMap<String, (Mutex<SharedLock>, Arc<atomic::AtomicBool>)>,
    /// Most recent acquisition per lock id: its token and the handle used to release it.
    locks: Mutex<BTreeMap<String, (Uuid, Lock)>>,
}

impl SharedLockFactory {
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// # Errors
    ///
    /// Will return `Err` if the lock already exists
    pub fn create(&mut self, lock_id: &str) -> Result<(), Error> {
        if let btree_map::Entry::Vacant(x) = self.shared_locks.entry(lock_id.to_owned()) {
            let slock = SharedLock::new();
            let flag = slock.clone_flag();
            x.insert((Mutex::new(slock), flag));
            Ok(())
        } else {
            Err(Error::duplicate(format!(
                "Shared lock {} already exists",
                lock_id
            )))
        }
    }
    /// # Errors
    ///
    /// Will return `Err` if the lock is not defined
    pub async fn acquire(&self, lock_id: &str, expires: Duration) -> Result<Uuid, Error> {
        if let Some((v, _)) = self.shared_locks.get(lock_id) {
            // wait for the lock and block other futures accessing it
            let lock = v.lock().await.acquire(expires).await;
            let token = Uuid::new_v4();
            self.locks
                .lock()
                .await
                .insert(lock_id.to_owned(), (token, lock));
            Ok(token)
        } else {
            Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
        }
    }
    /// # Errors
    ///
    /// Will return `Err` if the token is invalid, None forcibly releases the lock
    pub async fn release(&self, lock_id: &str, token: Option<&Uuid>) -> Result<bool, Error> {
        if let Some((tok, lock)) = self.locks.lock().await.get(lock_id) {
            if let Some(t) = token {
                if tok != t {
                    return Err(Error::not_found(ERR_INVALID_LOCK_TOKEN));
                }
            }
            Ok(lock.release().await)
        } else {
            Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
        }
    }
    /// # Errors
    ///
    /// Will return `Err` if the lock is not defined
    pub fn status(&self, lock_id: &str) -> Result<bool, Error> {
        if let Some((_, flag)) = self.shared_locks.get(lock_id) {
            Ok(flag.load(atomic::Ordering::SeqCst))
        } else {
            Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
        }
    }
    pub fn list(&self) -> Vec<(&str, bool)> {
        let mut result = Vec::new();
        for (id, (_, flag)) in &self.shared_locks {
            result.push((id.as_str(), flag.load(atomic::Ordering::SeqCst)));
        }
        result
    }
}