zenoh_sync/
signal.rs

1//
2// Copyright (c) 2023 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14use std::sync::{
15    atomic::{AtomicBool, Ordering::*},
16    Arc,
17};
18
19use tokio::sync::Semaphore;
20
21#[derive(Debug, Clone)]
22pub struct Signal {
23    shared: Arc<Inner>,
24}
25
26#[derive(Debug)]
27struct Inner {
28    semaphore: Semaphore,
29    triggered: AtomicBool,
30}
31
32impl Signal {
33    pub fn new() -> Self {
34        Signal {
35            shared: Arc::new(Inner {
36                semaphore: Semaphore::new(0),
37                triggered: AtomicBool::new(false),
38            }),
39        }
40    }
41
42    pub fn trigger(&self) {
43        let result = self
44            .shared
45            .triggered
46            .compare_exchange(false, true, AcqRel, Acquire);
47
48        if result.is_ok() {
49            // The maximum # of permits is defined in tokio doc.
50            // https://docs.rs/tokio/latest/tokio/sync/struct.Semaphore.html#method.add_permits
51            self.shared.semaphore.add_permits(usize::MAX >> 3);
52        }
53    }
54
55    pub fn is_triggered(&self) -> bool {
56        self.shared.triggered.load(Acquire)
57    }
58
59    pub async fn wait(&self) {
60        if !self.is_triggered() {
61            let _ = self.shared.semaphore.acquire().await;
62        }
63    }
64}
65
66impl Default for Signal {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Exercises a waiter that starts before the trigger, a waiter that starts
    /// after it, and a double trigger, and checks that nothing hangs or panics.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn signal_test() {
        let signal = Signal::new();

        // a freshly created signal must not report itself as triggered
        assert!(!signal.is_triggered());

        // spawn publisher
        let r#pub = tokio::task::spawn({
            let signal = signal.clone();

            async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                signal.trigger();
                signal.trigger(); // second trigger should not break
            }
        });

        // spawn subscriber that waits immediately
        let fast_sub = tokio::task::spawn({
            let signal = signal.clone();

            async move {
                signal.wait().await;
            }
        });

        // spawn subscriber that waits after the publisher triggers the signal
        let slow_sub = tokio::task::spawn({
            let signal = signal.clone();

            async move {
                tokio::time::sleep(Duration::from_millis(400)).await;
                signal.wait().await;
            }
        });

        // check that no task halts, in particular the slow subscriber, which
        // only starts waiting after the signal has already fired
        let (pub_result, fast_result, slow_result) = tokio::time::timeout(
            Duration::from_millis(50000),
            futures::future::join3(r#pub, fast_sub, slow_sub),
        )
        .await
        .expect("tasks did not finish before the timeout");

        // `timeout` only catches hangs; a panic inside a spawned task surfaces
        // as a `JoinError`, so every task's outcome is checked explicitly
        pub_result.expect("publisher task failed");
        fast_result.expect("fast subscriber task failed");
        slow_result.expect("slow subscriber task failed");

        // verify that the signal is in the triggered state
        assert!(signal.is_triggered());
    }
}