1use crate::Error;
2use std::collections::{btree_map, BTreeMap};
3use std::sync::atomic;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::{mpsc, Mutex};
7use tokio::task;
8use uuid::Uuid;
9
/// Error text used when a lock id cannot be found in the factory's maps.
const ERR_LOCK_NOT_DEFINED: &str = "Lock not defined";
/// Error text used when a release is attempted with a token that does not
/// match the one recorded for the lock id.
const ERR_INVALID_LOCK_TOKEN: &str = "Invalid lock token";
12
/// Handle to an acquired lock.
///
/// It only carries the sending half of a channel; sending on it asks the
/// holder of the underlying lock to let go (see [`Lock::release`]).
#[derive(Debug, Clone)]
pub struct Lock {
    /// Sending half of the "unlock" channel; `release` sends a unit message on it.
    unlock_trigger: mpsc::Sender<()>,
}
17
18impl Lock {
19 pub async fn release(&self) -> bool {
21 self.unlock_trigger.send(()).await.is_ok()
22 }
23}
24
/// A single asynchronous lock with a time-limited hold and an observable
/// "currently held" flag.
#[derive(Debug, Default)]
pub struct SharedLock {
    /// Async mutex whose guard is kept by a background task for as long as
    /// the lock is considered held.
    lock: Arc<Mutex<()>>,
    /// Set to `true` while the lock is held and back to `false` on release/expiry.
    flag: Arc<atomic::AtomicBool>,
}
30
31impl SharedLock {
32 #[must_use]
33 pub fn new() -> Self {
34 Self::default()
35 }
36 pub async fn acquire(&self, expires: Duration) -> Lock {
37 let lock = self.lock.clone();
38 let (lock_trigger, lock_listener) = triggered::trigger();
39 let (unlock_trigger, mut unlock_listener) = mpsc::channel(1);
40 let flag = self.flag.clone();
41 task::spawn(async move {
42 let _g = lock.lock().await;
44 flag.store(true, atomic::Ordering::SeqCst);
46 lock_trigger.trigger();
47 let _ = tokio::time::timeout(expires, unlock_listener.recv()).await;
49 flag.store(false, atomic::Ordering::SeqCst);
50 });
51 lock_listener.await;
53 Lock { unlock_trigger }
54 }
55 pub fn clone_flag(&self) -> Arc<atomic::AtomicBool> {
56 self.flag.clone()
57 }
58}
59
/// Registry of named [`SharedLock`]s plus the token/handle pair most recently
/// issued for each lock id.
#[derive(Debug, Default)]
pub struct SharedLockFactory {
    /// Registered locks by id, each paired with a handle to that lock's
    /// "currently held" flag so status can be read without locking.
    shared_locks: BTreeMap<String, (Mutex<SharedLock>, Arc<atomic::AtomicBool>)>,
    /// For each lock id, the token and release handle stored by the latest
    /// successful `acquire`.
    locks: Mutex<BTreeMap<String, (Uuid, Lock)>>,
}
65
66impl SharedLockFactory {
67 #[must_use]
68 pub fn new() -> Self {
69 Self::default()
70 }
71 pub fn create(&mut self, lock_id: &str) -> Result<(), Error> {
75 if let btree_map::Entry::Vacant(x) = self.shared_locks.entry(lock_id.to_owned()) {
76 let slock = SharedLock::new();
77 let flag = slock.clone_flag();
78 x.insert((Mutex::new(slock), flag));
79 Ok(())
80 } else {
81 Err(Error::duplicate(format!(
82 "Shared lock {} already exists",
83 lock_id
84 )))
85 }
86 }
87 pub async fn acquire(&self, lock_id: &str, expires: Duration) -> Result<Uuid, Error> {
91 if let Some((v, _)) = self.shared_locks.get(lock_id) {
92 let lock = v.lock().await.acquire(expires).await;
94 let token = Uuid::new_v4();
95 self.locks
96 .lock()
97 .await
98 .insert(lock_id.to_owned(), (token, lock));
99 Ok(token)
100 } else {
101 Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
102 }
103 }
104 pub async fn release(&self, lock_id: &str, token: Option<&Uuid>) -> Result<bool, Error> {
108 if let Some((tok, lock)) = self.locks.lock().await.get(lock_id) {
109 if let Some(t) = token {
110 if tok != t {
111 return Err(Error::not_found(ERR_INVALID_LOCK_TOKEN));
112 }
113 }
114 Ok(lock.release().await)
115 } else {
116 Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
117 }
118 }
119 pub fn status(&self, lock_id: &str) -> Result<bool, Error> {
123 if let Some((_, flag)) = self.shared_locks.get(lock_id) {
124 Ok(flag.load(atomic::Ordering::SeqCst))
125 } else {
126 Err(Error::not_found(ERR_LOCK_NOT_DEFINED))
127 }
128 }
129 pub fn list(&self) -> Vec<(&str, bool)> {
130 let mut result = Vec::new();
131 for (id, (_, flag)) in &self.shared_locks {
132 result.push((id.as_str(), flag.load(atomic::Ordering::SeqCst)));
133 }
134 result
135 }
136}