zenoh_sync/
mvar.rs

1//
2// Copyright (c) 2023 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14use std::sync::atomic::{AtomicUsize, Ordering};
15
16use tokio::sync::Mutex;
17use zenoh_core::zasynclock;
18
19use crate::Condition;
20
21pub struct Mvar<T> {
22    inner: Mutex<Option<T>>,
23    cond_put: Condition,
24    cond_take: Condition,
25    wait_put: AtomicUsize,
26    wait_take: AtomicUsize,
27}
28
29impl<T> Mvar<T> {
30    pub fn new() -> Mvar<T> {
31        Mvar {
32            inner: Mutex::new(None),
33            cond_put: Condition::new(),
34            cond_take: Condition::new(),
35            wait_put: AtomicUsize::new(0),
36            wait_take: AtomicUsize::new(0),
37        }
38    }
39
40    pub fn has_take_waiting(&self) -> bool {
41        self.wait_take.load(Ordering::Acquire) > 0
42    }
43
44    pub fn has_put_waiting(&self) -> bool {
45        self.wait_put.load(Ordering::Acquire) > 0
46    }
47
48    pub async fn try_take(&self) -> Option<T> {
49        let mut guard = zasynclock!(self.inner);
50        if let Some(inner) = guard.take() {
51            drop(guard);
52            self.cond_put.notify_one();
53            return Some(inner);
54        }
55        None
56    }
57
58    pub async fn take(&self) -> T {
59        loop {
60            let mut guard = zasynclock!(self.inner);
61            if let Some(inner) = guard.take() {
62                self.wait_take.fetch_sub(1, Ordering::AcqRel);
63                drop(guard);
64                self.cond_put.notify_one();
65                return inner;
66            }
67            self.wait_take.fetch_add(1, Ordering::AcqRel);
68            self.cond_take.wait(guard).await;
69        }
70    }
71
72    pub async fn put(&self, inner: T) {
73        loop {
74            let mut guard = zasynclock!(self.inner);
75            if guard.is_some() {
76                self.wait_put.fetch_add(1, Ordering::AcqRel);
77                self.cond_put.wait(guard).await;
78            } else {
79                *guard = Some(inner);
80                self.wait_put.fetch_sub(1, Ordering::AcqRel);
81                drop(guard);
82                self.cond_take.notify_one();
83                return;
84            }
85        }
86    }
87}
88
89impl<T> Default for Mvar<T> {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
#[cfg(test)]
mod tests {
    use zenoh_result::ZResult;

    /// Runs one producer and one consumer task exchanging `count` values
    /// through a shared `Mvar`.
    ///
    /// The test fails if either task does not finish within `TIMEOUT`, if
    /// either task panics, or if the consumer sees values in a different order
    /// from the one in which they were put.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn mvar() -> ZResult<()> {
        use std::{sync::Arc, time::Duration};

        use super::Mvar;

        const TIMEOUT: Duration = Duration::from_secs(60);

        let count: usize = 1_000;
        let mvar: Arc<Mvar<usize>> = Arc::new(Mvar::new());

        let c_mvar = Arc::clone(&mvar);
        let ch = tokio::task::spawn(async move {
            for expected in 0..count {
                let n = c_mvar.take().await;
                // One producer, one consumer, one slot: values must arrive in
                // exactly the order they were put.
                assert_eq!(n, expected, "values must be taken in put order");
                print!("-{n} ");
            }
        });

        let ph = tokio::task::spawn(async move {
            for i in 0..count {
                mvar.put(i).await;
                print!("+{i} ");
            }
        });

        // First `?` reports a timeout; second `?` reports a panicked or
        // cancelled task, which would otherwise be silently discarded.
        tokio::time::timeout(TIMEOUT, ph).await??;
        tokio::time::timeout(TIMEOUT, ch).await??;
        println!();
        Ok(())
    }
}