1use std::sync::atomic::{AtomicUsize, Ordering};
15
16use tokio::sync::Mutex;
17use zenoh_core::zasynclock;
18
19use crate::Condition;
20
21pub struct Mvar<T> {
22 inner: Mutex<Option<T>>,
23 cond_put: Condition,
24 cond_take: Condition,
25 wait_put: AtomicUsize,
26 wait_take: AtomicUsize,
27}
28
29impl<T> Mvar<T> {
30 pub fn new() -> Mvar<T> {
31 Mvar {
32 inner: Mutex::new(None),
33 cond_put: Condition::new(),
34 cond_take: Condition::new(),
35 wait_put: AtomicUsize::new(0),
36 wait_take: AtomicUsize::new(0),
37 }
38 }
39
40 pub fn has_take_waiting(&self) -> bool {
41 self.wait_take.load(Ordering::Acquire) > 0
42 }
43
44 pub fn has_put_waiting(&self) -> bool {
45 self.wait_put.load(Ordering::Acquire) > 0
46 }
47
48 pub async fn try_take(&self) -> Option<T> {
49 let mut guard = zasynclock!(self.inner);
50 if let Some(inner) = guard.take() {
51 drop(guard);
52 self.cond_put.notify_one();
53 return Some(inner);
54 }
55 None
56 }
57
58 pub async fn take(&self) -> T {
59 loop {
60 let mut guard = zasynclock!(self.inner);
61 if let Some(inner) = guard.take() {
62 self.wait_take.fetch_sub(1, Ordering::AcqRel);
63 drop(guard);
64 self.cond_put.notify_one();
65 return inner;
66 }
67 self.wait_take.fetch_add(1, Ordering::AcqRel);
68 self.cond_take.wait(guard).await;
69 }
70 }
71
72 pub async fn put(&self, inner: T) {
73 loop {
74 let mut guard = zasynclock!(self.inner);
75 if guard.is_some() {
76 self.wait_put.fetch_add(1, Ordering::AcqRel);
77 self.cond_put.wait(guard).await;
78 } else {
79 *guard = Some(inner);
80 self.wait_put.fetch_sub(1, Ordering::AcqRel);
81 drop(guard);
82 self.cond_take.notify_one();
83 return;
84 }
85 }
86 }
87}
88
89impl<T> Default for Mvar<T> {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
#[cfg(test)]
mod tests {
    use zenoh_result::ZResult;

    /// Single-producer / single-consumer round trip through an `Mvar`.
    ///
    /// Every value put must be taken exactly once and in order. The test fails if
    /// either task panics or does not finish within `TIMEOUT`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn mvar() -> ZResult<()> {
        use std::{sync::Arc, time::Duration};

        use super::Mvar;

        const TIMEOUT: Duration = Duration::from_secs(60);

        let count: usize = 1_000;
        let mvar: Arc<Mvar<usize>> = Arc::new(Mvar::new());

        let c_mvar = mvar.clone();
        let ch = tokio::task::spawn(async move {
            for i in 0..count {
                let n = c_mvar.take().await;
                // The slot holds one value at a time and there is a single producer,
                // so values must come out in exactly the order they were put.
                assert_eq!(n, i, "values taken out of order");
                print!("-{n} ");
            }
        });

        let ph = tokio::task::spawn(async move {
            for i in 0..count {
                mvar.put(i).await;
                print!("+{i} ");
            }
        });

        // The outer `?` reports a timeout. The inner `?` reports a task panic
        // (JoinError), which would otherwise be silently discarded.
        tokio::time::timeout(TIMEOUT, ph).await??;
        tokio::time::timeout(TIMEOUT, ch).await??;
        println!();
        Ok(())
    }
}