maybe_fut/api/sync/barrier.rs

1use crate::maybe_fut_constructor_sync;
2
3/// A barrier enables multiple threads to synchronize the beginning of some computation.
4#[derive(Debug, Unwrap)]
5#[unwrap_types(
6    std(std::sync::Barrier),
7    tokio(tokio::sync::Barrier),
8    tokio_gated("tokio-sync")
9)]
10pub struct Barrier(BarrierInner);
11
/// Backend storage for [`Barrier`]: exactly one variant per supported runtime.
#[derive(Debug)]
enum BarrierInner {
    /// Barrier from the standard library; waiting on it blocks the calling thread.
    Std(std::sync::Barrier),
    /// Barrier from Tokio; only compiled in when the `tokio_sync` cfg is set.
    #[cfg(tokio_sync)]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-sync")))]
    Tokio(tokio::sync::Barrier),
}
22
23impl From<std::sync::Barrier> for Barrier {
24    fn from(barrier: std::sync::Barrier) -> Self {
25        Self(BarrierInner::Std(barrier))
26    }
27}
28
#[cfg(tokio_sync)]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-sync")))]
impl From<tokio::sync::Barrier> for Barrier {
    /// Wraps an existing `tokio::sync::Barrier`, selecting the Tokio backend.
    fn from(tokio_barrier: tokio::sync::Barrier) -> Self {
        let inner = BarrierInner::Tokio(tokio_barrier);
        Barrier(inner)
    }
}
36
37impl Barrier {
38    maybe_fut_constructor_sync!(
39        /// Creates a new barrier that can block a given number of threads.
40        ///
41        /// A barrier will block n-1 threads which call [`Self::wait`] and then wake up all threads at once when the `n`th thread calls [`Self::wait`].
42        new(n: usize) -> Self,
43        std::sync::Barrier::new,
44        tokio::sync::Barrier::new,
45        tokio_sync
46    );
47
48    /// Blocks the current thread until all threads have rendezvoused here.
49    ///
50    /// Barriers are re-usable after all threads have rendezvoused once, and can be used continuously.
51    pub async fn wait(&self) -> BarrierWaitResult {
52        match &self.0 {
53            BarrierInner::Std(barrier) => barrier.wait().into(),
54            #[cfg(tokio_sync)]
55            BarrierInner::Tokio(barrier) => barrier.wait().await.into(),
56        }
57    }
58}
59
60/// Result of a [`Barrier`] [`Barrier::wait`] operation.
61#[derive(Debug)]
62pub struct BarrierWaitResult(InnerBarrierWaitResult);
63
/// Backend storage for [`BarrierWaitResult`], mirroring the variants of the barrier
/// that produced it.
#[derive(Debug)]
enum InnerBarrierWaitResult {
    /// Result produced by a [`std::sync::Barrier`].
    Std(std::sync::BarrierWaitResult),
    /// Result produced by a `tokio::sync::Barrier`; only present with the `tokio_sync` cfg.
    #[cfg(tokio_sync)]
    Tokio(tokio::sync::BarrierWaitResult),
}
73
74impl From<std::sync::BarrierWaitResult> for BarrierWaitResult {
75    fn from(result: std::sync::BarrierWaitResult) -> Self {
76        Self(InnerBarrierWaitResult::Std(result))
77    }
78}
79
// Public and feature-gated: carry the same docs.rs feature badge as the matching
// `From<tokio::sync::Barrier> for Barrier` impl so the requirement is visible in the docs.
#[cfg(tokio_sync)]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-sync")))]
impl From<tokio::sync::BarrierWaitResult> for BarrierWaitResult {
    /// Wraps the outcome of a `tokio::sync::Barrier::wait` call.
    fn from(result: tokio::sync::BarrierWaitResult) -> Self {
        Self(InnerBarrierWaitResult::Tokio(result))
    }
}
86
87impl BarrierWaitResult {
88    /// Returns `true` if this thread is the "leader thread" for the call to [`Barrier::wait`].
89    ///
90    /// Only one thread will have `true` returned from their result, all other threads will have `false` returned.
91    pub fn is_leader(&self) -> bool {
92        match &self.0 {
93            InnerBarrierWaitResult::Std(result) => result.is_leader(),
94            #[cfg(tokio_sync)]
95            InnerBarrierWaitResult::Tokio(result) => result.is_leader(),
96        }
97    }
98}
99
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use super::*;

    #[test]
    fn test_should_create_barrier_sync() {
        let barrier = Barrier::new(1);
        assert!(matches!(barrier.0, BarrierInner::Std(_)));
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_should_create_barrier_async() {
        let barrier = Barrier::new(1);
        assert!(matches!(barrier.0, BarrierInner::Tokio(_)));
    }

    #[test]
    fn test_should_convert_from_std_barrier() {
        let barrier = Barrier::from(std::sync::Barrier::new(1));
        assert!(matches!(barrier.0, BarrierInner::Std(_)));
    }

    #[test]
    fn test_should_create_barrier_wait_result_sync() {
        let barrier = Barrier::new(1);
        let result = crate::SyncRuntime::block_on(barrier.wait());
        assert!(matches!(result.0, InnerBarrierWaitResult::Std(_)));
        // With a single participant, the only waiter must be the leader.
        assert!(result.is_leader());
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_should_create_barrier_wait_result_async() {
        let barrier = Barrier::new(1);
        let result = barrier.wait().await;
        assert!(matches!(result.0, InnerBarrierWaitResult::Tokio(_)));
        // With a single participant, the only waiter must be the leader.
        assert!(result.is_leader());
    }

    #[test]
    fn test_should_elect_single_leader_sync() {
        const PARTICIPANTS: usize = 4;
        // Built outside any Tokio runtime, so this is the std (thread-blocking) backend.
        let barrier = Arc::new(Barrier::new(PARTICIPANTS));

        let handles: Vec<_> = (0..PARTICIPANTS)
            .map(|_| {
                let barrier = Arc::clone(&barrier);
                std::thread::spawn(move || {
                    let result = crate::SyncRuntime::block_on(barrier.wait());
                    result.is_leader()
                })
            })
            .collect();

        let leaders = handles
            .into_iter()
            .map(|handle| handle.join().expect("barrier thread panicked"))
            .filter(|&is_leader| is_leader)
            .count();
        assert_eq!(leaders, 1);
    }

    #[cfg(tokio_sync)]
    #[tokio::test]
    async fn test_should_elect_single_leader_async() {
        let barrier = Barrier::new(3);
        // `join!` polls all three waits concurrently on the current task, so the third
        // arrival releases the other two without needing spawned tasks.
        let (first, second, third) = tokio::join!(barrier.wait(), barrier.wait(), barrier.wait());
        let leaders = [first, second, third]
            .iter()
            .filter(|result| result.is_leader())
            .count();
        assert_eq!(leaders, 1);
    }
}
132}