mpmc_channel/
lib.rs

1use std::{
2    fmt,
3    ops::{Deref, DerefMut},
4    sync::{Condvar, Mutex, MutexGuard, TryLockError},
5    time::Duration,
6};
7
/// State guarded by the channel's mutex: the user data plus waiter bookkeeping.
#[derive(Default, Debug)]
struct Inner<T> {
    /// How many consumers are currently blocked in `Consume::wait` or
    /// `Consume::wait_timeout`; producers skip notifying when this is zero.
    waiter: u32,
    /// The value shared between producers and consumers.
    data: T,
}
13
14#[derive(Debug, Default)]
15pub struct MPMC<T> {
16    cvar: Condvar,
17    inner: Mutex<Inner<T>>,
18}
19
/// Error returned by `MPMC::try_produce` and `MPMC::try_consume` when the
/// lock is currently held elsewhere and acquiring it would block.
///
/// Its `Display` output is the same as its `Debug` output: `WouldBlock`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WouldBlock;

impl std::error::Error for WouldBlock {}

impl fmt::Display for WouldBlock {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirror `WaitTimeOut`: reuse the derived `Debug` text.
        fmt::Debug::fmt(self, f)
    }
}
21
22impl<T> MPMC<T> {
23    pub const fn new(data: T) -> Self {
24        Self {
25            cvar: Condvar::new(),
26            inner: Mutex::new(Inner { data, waiter: 0 }),
27        }
28    }
29
30    #[inline]
31    pub fn produce(&self) -> Producer<T> {
32        let guard = self.inner.lock().unwrap();
33        Producer {
34            cvar: &self.cvar,
35            guard,
36        }
37    }
38
39    #[inline]
40    pub fn consume(&self) -> Consume<T> {
41        let guard = self.inner.lock().unwrap();
42        Consume {
43            cvar: &self.cvar,
44            guard,
45        }
46    }
47    
48    pub fn try_produce(&self) -> Result<Producer<T>, WouldBlock> {
49        let guard = match self.inner.try_lock() {
50            Ok(val) => val,
51            Err(TryLockError::WouldBlock) => return Err(WouldBlock),
52            Err(err) => panic!("{err}"),
53        };
54        Ok(Producer {
55            cvar: &self.cvar,
56            guard,
57        })
58    }
59
60    pub fn try_consume(&self) -> Result<Consume<T>, WouldBlock> {
61        let guard = match self.inner.try_lock() {
62            Ok(val) => val,
63            Err(TryLockError::WouldBlock) => return Err(WouldBlock),
64            Err(err) => panic!("{err}"),
65        };
66        Ok(Consume {
67            cvar: &self.cvar,
68            guard,
69        })
70    }
71}
72
73pub struct Consume<'a, T> {
74    cvar: &'a Condvar,
75    guard: MutexGuard<'a, Inner<T>>,
76}
77
78pub struct Producer<'a, T> {
79    cvar: &'a Condvar,
80    guard: MutexGuard<'a, Inner<T>>,
81}
82
83impl<'a, T> Consume<'a, T> {
84    pub fn wait(mut self) -> Self {
85        self.guard.waiter += 1;
86        self.guard = self.cvar.wait(self.guard).unwrap();
87        self.guard.waiter -= 1;
88        self
89    }
90
91    pub fn wait_timeout(mut self, dur: Duration) -> Result<Self, WaitTimeOut> {
92        self.guard.waiter += 1;
93        let result = self.cvar.wait_timeout(self.guard, dur).unwrap();
94        self.guard = result.0;
95        self.guard.waiter -= 1;
96        if result.1.timed_out() {
97            return Err(WaitTimeOut);
98        }
99        Ok(self)
100    }
101}
102
103impl<'a, T> Producer<'a, T> {
104    pub fn notify_one(self) {
105        if self.guard.waiter != 0 {
106            drop(self.guard);
107            self.cvar.notify_one();
108        }
109    }
110
111    pub fn notify_all(self) {
112        if self.guard.waiter != 0 {
113            drop(self.guard);
114            self.cvar.notify_all();
115        }
116    }
117}
118
119impl<'a, T> Deref for Consume<'a, T> {
120    type Target = T;
121    #[inline]
122    fn deref(&self) -> &Self::Target {
123        &self.guard.data
124    }
125}
126
127impl<'a, T> DerefMut for Consume<'a, T> {
128    #[inline]
129    fn deref_mut(&mut self) -> &mut Self::Target {
130        &mut self.guard.data
131    }
132}
133
134impl<'a, T> Deref for Producer<'a, T> {
135    type Target = T;
136    #[inline]
137    fn deref(&self) -> &Self::Target {
138        &self.guard.data
139    }
140}
141
142impl<'a, T> DerefMut for Producer<'a, T> {
143    #[inline]
144    fn deref_mut(&mut self) -> &mut Self::Target {
145        &mut self.guard.data
146    }
147}
148
/// Error returned by `Consume::wait_timeout` when the condition-variable
/// wait timed out.
///
/// Its `Display` output is the same as its `Debug` output: `WaitTimeOut`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WaitTimeOut;

impl std::error::Error for WaitTimeOut {}

impl fmt::Display for WaitTimeOut {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reuse the derived `Debug` text so both representations agree.
        fmt::Debug::fmt(self, f)
    }
}