irox_threading/
mpmc.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2023 IROX Contributors
3//
4
5//!
6//! Multi-Producer, Multi-Consumer
7//!
8
9use std::collections::VecDeque;
10use std::fmt::{Debug, Formatter};
11use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
12use std::sync::{Arc, Condvar, Mutex};
13
14use log::trace;
15
16use crate::TaskError;
17
///
/// Errors returned by [`Exchanger`] operations.
///
/// `T` is the element type of the exchanger.  It only appears in
/// [`ExchangerError::ExchangerFull`], which hands a rejected element back to the caller.
pub enum ExchangerError<T> {
    /// Error with the underlying executor, e.g. a poisoned internal lock
    /// ([`TaskError::LockingError`]) or a shut-down exchanger
    /// ([`TaskError::ExecutorStoppingError`]).
    TaskError(TaskError),
    /// Exchanger cannot accept an element as it is full.  The rejected element is carried back
    /// so the caller does not lose it.
    ExchangerFull(T),
    /// Exchanger cannot deliver an element as it is empty.
    ExchangerEmpty,
}
28
29impl<T> Debug for ExchangerError<T> {
30    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
31        match self {
32            ExchangerError::TaskError(e) => {
33                write!(f, "TaskError: {e:?}")
34            }
35            ExchangerError::ExchangerFull(_) => {
36                write!(f, "ExchangerFull")
37            }
38            ExchangerError::ExchangerEmpty => {
39                write!(f, "ExchangerEmpty")
40            }
41        }
42    }
43}
44
45impl<T> PartialEq for ExchangerError<T> {
46    fn eq(&self, other: &Self) -> bool {
47        match self {
48            ExchangerError::TaskError(e) => {
49                if let ExchangerError::TaskError(e2) = other {
50                    return e == e2;
51                }
52                false
53            }
54            ExchangerError::ExchangerFull(_) => {
55                matches!(other, ExchangerError::ExchangerFull(_))
56            }
57            ExchangerError::ExchangerEmpty => {
58                matches!(other, ExchangerError::ExchangerEmpty)
59            }
60        }
61    }
62}
63
/// Shared state behind every [`Exchanger`] handle.
///
/// All queue access goes through `mutex`.  The two condition variables are waited on with that
/// mutex so blocked takers and putters can be woken.
struct InnerExchange<T: Send> {
    /// Upper bound on the number of queued elements.
    max_size: usize,
    /// The queued elements, pushed at the back and taken from the front (FIFO).
    mutex: Mutex<VecDeque<T>>,
    /// Signalled when an element becomes available (or on shutdown) to wake a blocked taker.
    take_condition: Condvar,
    /// Signalled when space frees up (or on shutdown) to wake a blocked putter.
    put_condition: Condvar,
    /// Set once the exchanger has been shut down; never reset.
    shutdown: AtomicBool,
    /// Number of takers currently blocked waiting for an element.
    num_waiting_takers: AtomicU16,
    /// Number of putters currently blocked waiting for free space.
    num_waiting_putters: AtomicU16,
}
73
74impl<T: Send> InnerExchange<T> {
75    /// Creates a new InnerExchange.
76    pub fn new(max_size: usize) -> Self {
77        InnerExchange {
78            max_size,
79            mutex: Default::default(),
80            take_condition: Default::default(),
81            put_condition: Default::default(),
82            shutdown: AtomicBool::new(false),
83            num_waiting_takers: AtomicU16::new(0),
84            num_waiting_putters: AtomicU16::new(0),
85        }
86    }
87
88    ///
89    /// Takes and returns an element, blocking if none are available.
90    pub fn take_blocking(&self) -> Result<T, ExchangerError<T>> {
91        let Ok(mut elems) = self.mutex.lock() else {
92            return Err(ExchangerError::TaskError(TaskError::LockingError));
93        };
94        if let Some(e) = elems.pop_front() {
95            trace!("Take_blocking popped one");
96            self.put_condition.notify_one();
97            return Ok(e);
98        }
99        // If there's still a few elements left in the queue, at this point they can be safely
100        // flushed while shutting down, since no new elements can be added if this exchanger
101        // is shutdown.  Prevent the waiter from waiting if we're shutting down.
102        if self.shutdown.load(Ordering::SeqCst) {
103            return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
104        }
105        trace!("Take_blocking waiting for element");
106        self.num_waiting_takers.fetch_add(1, Ordering::SeqCst);
107        let Ok(mut elems) = self.take_condition.wait_while(elems, |e| {
108            e.is_empty() && !self.shutdown.load(Ordering::SeqCst)
109        }) else {
110            return Err(ExchangerError::TaskError(TaskError::LockingError));
111        };
112        self.num_waiting_takers.fetch_sub(1, Ordering::SeqCst);
113
114        let Some(e) = elems.pop_front() else {
115            trace!("Take_blocking woken up for empty exchange");
116            // empty wakeup?  probably stopping.
117            return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
118        };
119        trace!("Take_blocking woken up for new element");
120        self.put_condition.notify_one();
121        Ok(e)
122    }
123
124    ///
125    /// Attempts to take one from this exchanger.  If one is available, it is returned. This
126    /// function will not block, and if it would have blocked, returns
127    /// [`Err(TaskError::ExchangerEmpty)`]
128    pub fn try_take(&self) -> Result<T, ExchangerError<T>> {
129        let Ok(mut elems) = self.mutex.lock() else {
130            return Err(ExchangerError::TaskError(TaskError::LockingError));
131        };
132        if let Some(e) = elems.pop_front() {
133            trace!("Take_blocking popped one");
134            self.put_condition.notify_one();
135            return Ok(e);
136        }
137        // If there's still a few elements left in the queue, at this point they can be safely
138        // flushed while shutting down, since no new elements can be added if this exchanger
139        // is shutdown.  Prevent the waiter from waiting if we're shutting down.
140        if self.shutdown.load(Ordering::SeqCst) {
141            return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
142        }
143        Err(ExchangerError::ExchangerEmpty)
144    }
145
146    ///
147    /// Puts an element into this exchanger to allow exchanges to occur.  Will block indefinitely
148    /// until there's enough space to put one in.
149    pub fn put_blocking(&self, elem: T) -> Result<(), ExchangerError<T>> {
150        // prevent any new elements from being put in if we're shutting down.
151        if self.shutdown.load(Ordering::SeqCst) {
152            return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
153        }
154
155        let Ok(mut elems) = self.mutex.lock() else {
156            return Err(ExchangerError::TaskError(TaskError::LockingError));
157        };
158        if elems.len() < self.max_size {
159            trace!("Put_blocking added one");
160            elems.push_back(elem);
161            self.take_condition.notify_one();
162            return Ok(());
163        }
164        trace!("Put_blocking full, waiting for empty spot");
165        // full, wait until a spot opens.
166        self.num_waiting_putters.fetch_add(1, Ordering::SeqCst);
167        let Ok(mut elems) = self.put_condition.wait_while(elems, |e| {
168            e.len() >= self.max_size && !self.shutdown.load(Ordering::SeqCst)
169        }) else {
170            return Err(ExchangerError::TaskError(TaskError::LockingError));
171        };
172        self.num_waiting_putters.fetch_sub(1, Ordering::SeqCst);
173
174        if elems.len() == self.max_size {
175            trace!("Put_blocking woken up for full, cannot add new element");
176            return Err(ExchangerError::ExchangerFull(elem));
177        };
178
179        trace!("Put_blocking woken up for free space");
180        elems.push_back(elem);
181        self.take_condition.notify_one();
182        Ok(())
183    }
184
185    ///
186    /// Attempts to put a new element into the exchanger, returning [`Ok`] if it was successful.
187    /// This function will not block, if the call would have blocked, returns
188    /// [`Err(ExchangerError::ExchangerFull(T))`] so you can have the element back.
189    pub fn try_put(&self, elem: T) -> Result<(), ExchangerError<T>> {
190        // prevent any new elements from being put in if we're shutting down.
191        if self.shutdown.load(Ordering::SeqCst) {
192            return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
193        }
194
195        let Ok(mut elems) = self.mutex.lock() else {
196            return Err(ExchangerError::TaskError(TaskError::LockingError));
197        };
198        if elems.len() < self.max_size {
199            trace!("try_put added one");
200            elems.push_back(elem);
201            self.take_condition.notify_one();
202            return Ok(());
203        }
204        Err(ExchangerError::ExchangerFull(elem))
205    }
206
207    ///
208    /// Stops this exchanger, waking up all waiters and
209    pub fn shutdown(&self) {
210        self.shutdown.store(true, Ordering::SeqCst);
211        while self.num_waiting_putters.load(Ordering::SeqCst) > 0 {
212            self.put_condition.notify_all();
213        }
214        while self.num_waiting_takers.load(Ordering::SeqCst) > 0 {
215            self.take_condition.notify_all();
216        }
217    }
218}
219
///
/// A thread-safe, shared exchange buffer.  Multiple producers can push elements in, up to a
/// fixed maximum capacity, blocking while it is full.  Multiple consumers can take elements out,
/// blocking while none are available.
///
/// Cloning an `Exchanger` produces another handle onto the same underlying queue.
pub struct Exchanger<T: Send> {
    /// Shared queue state; every clone of this handle points at the same instance.
    exchange: Arc<InnerExchange<T>>,
}
227
228impl<T: Send> Clone for Exchanger<T> {
229    fn clone(&self) -> Self {
230        Exchanger {
231            exchange: self.exchange.clone(),
232        }
233    }
234}
235
236impl<T: Send> Exchanger<T> {
237    ///
238    /// Creates a new exchanger, that can store at most these elements in the queue.
239    ///
240    /// Note:  Putting a value of `0`/zero for max_size implies that this exchanger will do no work
241    /// and exchange no items.  A recommended minimum value of `1` should be used instead.
242    pub fn new(max_size: usize) -> Self {
243        Exchanger {
244            exchange: Arc::new(InnerExchange::new(max_size)),
245        }
246    }
247
248    ///
249    /// Push a new element into the exchanger, blocking until space is available.
250    pub fn push(&self, elem: T) -> Result<(), ExchangerError<T>> {
251        self.exchange.put_blocking(elem)
252    }
253
254    pub fn try_push(&self, elem: T) -> Result<(), ExchangerError<T>> {
255        self.exchange.try_put(elem)
256    }
257
258    ///
259    /// Take a new element from the exchanger, blocking until one is available.
260    pub fn take(&self) -> Result<T, ExchangerError<T>> {
261        self.exchange.take_blocking()
262    }
263
264    pub fn try_take(&self) -> Result<T, ExchangerError<T>> {
265        self.exchange.try_take()
266    }
267
268    ///
269    /// Shuts down this exchanger, preventing new pushes.  Any objects already pushed will be
270    /// permitted to be taken, and once empty, takers will receive a
271    /// [`TaskError::ExecutorStoppingError`]
272    pub fn shutdown(&self) {
273        self.exchange.shutdown();
274    }
275}
276
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use std::thread::JoinHandle;
    use std::time::Duration;

    use log::{error, info, Level};

    use crate::{Exchanger, ExchangerError, TaskError};

    /// The error a shut-down (and, for takers, drained) exchanger reports.
    fn stopping_error() -> ExchangerError<u32> {
        ExchangerError::TaskError(TaskError::ExecutorStoppingError)
    }

    #[test]
    pub fn test_single_sender_receiver() -> Result<(), ExchangerError<u32>> {
        // irox_log::init_console_level(Level::Trace);
        let (err_sender, err_receiver) = std::sync::mpsc::channel();
        let err_sender2 = err_sender.clone();
        let exch1 = Exchanger::<u32>::new(10);
        let exch2 = exch1.clone();
        let genthrd = std::thread::Builder::new()
            .name("Sender".to_string())
            .spawn(move || {
                let mut sent = 0;
                for i in 0..1_000 {
                    if let Err(e) = exch2.push(i) {
                        eprintln!("Error sending exchange: {e:?}");
                        if let Err(e) = err_sender.send((e, i, "send")) {
                            panic!("{e:?}");
                        }
                    }
                    sent += 1;
                }
                println!("Sent {sent}");
            })
            .unwrap();

        let recv_thrd = std::thread::Builder::new()
            .name("Receiver".to_string())
            .spawn(move || {
                let mut recvd = 0;
                for i in 0..1_000 {
                    if let Err(e) = exch1.take() {
                        eprintln!("Error receiving exchange: {e:?}");
                        if let Err(e) = err_sender2.send((e, i, "recv")) {
                            panic!("{e:?}");
                        }
                    }
                    std::thread::sleep(Duration::from_millis(1)); // simulate work.
                    recvd += 1;
                }
                println!("Received {recvd}");
            })
            .unwrap();

        genthrd.join().unwrap();
        recv_thrd.join().unwrap();

        let mut errors: bool = false;
        while let Ok(r) = err_receiver.recv() {
            let (e, i, s) = r;
            eprintln!("Error received {e:?} : {i} : {s}");
            errors = true;
        }

        assert!(!errors);

        Ok(())
    }

    #[test]
    pub fn test_multiple_receivers() {
        irox_log::init_console_level(Level::Info);
        let (err_sender, err_receiver) = std::sync::mpsc::channel();
        let err_sender2 = err_sender.clone();
        let exch1 = Exchanger::<u32>::new(10);
        let exch2 = exch1.clone();
        let exch3 = exch1.clone();
        let genthrd = std::thread::Builder::new()
            .name("Sender".to_string())
            .spawn(move || {
                let mut sent = 0;
                for i in 0..1_000_000 {
                    if let Err(e) = exch2.push(i) {
                        eprintln!("Error sending exchange: {e:?}");
                        if let Err(e) = err_sender.send((e, i, "send")) {
                            panic!("{e:?}");
                        }
                    }
                    sent += 1;
                }
                info!("Sent {sent}");
            })
            .unwrap();

        let recv_count = Arc::new(AtomicU64::new(0));
        let mut receivers: Vec<JoinHandle<()>> = Vec::new();
        for thread_idx in 0..10 {
            let counter = recv_count.clone();
            let err_sender2 = err_sender2.clone();
            let exch1 = exch1.clone();
            let recv_thrd = std::thread::Builder::new()
                .name(format!("Receiver {thread_idx}"))
                .spawn(move || {
                    let counter = counter;

                    let mut recvd = 0;
                    loop {
                        if let Err(e) = exch1.take() {
                            if e == ExchangerError::TaskError(TaskError::ExecutorStoppingError) {
                                // it's a good thing!
                                break;
                            }
                            error!("Error receiving exchange: {e:?}");
                            if let Err(e) = err_sender2.send((e, recvd, "recv")) {
                                panic!("Error sending error: {e:?}");
                            }
                            break;
                        }
                        // std::thread::sleep(Duration::from_millis(1));// simulate work.
                        recvd += 1;
                        counter.fetch_add(1, Ordering::Relaxed);
                    }
                    info!(
                        "Received {recvd} in thread {}",
                        std::thread::current().name().unwrap_or("")
                    );
                })
                .unwrap();
            receivers.push(recv_thrd);
        }
        drop(err_sender2);

        genthrd.join().unwrap();
        info!("Generator thread joined");
        exch3.shutdown();
        info!("Executor shutdown");

        for recv in receivers {
            info!("Waiting on {}", recv.thread().name().unwrap_or(""));
            recv.join().unwrap();
        }

        let mut errors: bool = false;
        while let Ok(r) = err_receiver.recv() {
            let (e, i, s) = r;
            error!("Error received {e:?} : {i} : {s}");
            errors = true;
        }

        assert!(!errors);
        // Elements still queued at shutdown must be drained, so every push is received once.
        assert_eq!(
            recv_count.load(Ordering::SeqCst),
            1_000_000,
            "every pushed element must be received exactly once"
        );
    }

    #[test]
    pub fn test_shutdown_drains_then_stops() {
        let exch = Exchanger::<u32>::new(2);
        exch.push(1).unwrap();
        exch.shutdown();

        // No new pushes once shut down...
        assert_eq!(exch.try_push(2).err(), Some(stopping_error()));
        assert_eq!(exch.push(3).err(), Some(stopping_error()));
        // ...but already-queued elements can still be taken...
        assert_eq!(exch.try_take().ok(), Some(1));
        // ...and once drained, takers are told the exchanger is stopping.
        assert_eq!(exch.try_take().err(), Some(stopping_error()));
        assert_eq!(exch.take().err(), Some(stopping_error()));
    }

    #[test]
    pub fn test_shutdown_wakes_blocked_taker() {
        let exch = Exchanger::<u32>::new(1);
        let taker = exch.clone();
        let handle = std::thread::Builder::new()
            .name("Blocked Taker".to_string())
            .spawn(move || taker.take())
            .unwrap();
        // Give the taker a chance to park on the empty exchanger before shutting down; the
        // outcome is the same either way.
        std::thread::sleep(Duration::from_millis(20));
        exch.shutdown();

        let result = handle.join().unwrap();
        assert_eq!(result.err(), Some(stopping_error()));
    }

    #[test]
    pub fn test_shutdown_wakes_blocked_putter() {
        let exch = Exchanger::<u32>::new(1);
        exch.push(0).unwrap();
        let putter = exch.clone();
        let handle = std::thread::Builder::new()
            .name("Blocked Putter".to_string())
            .spawn(move || putter.push(1))
            .unwrap();
        // Give the putter a chance to park on the full exchanger before shutting down.
        std::thread::sleep(Duration::from_millis(20));
        exch.shutdown();

        assert!(handle.join().unwrap().is_err());
        // The element queued before shutdown is still available, and nothing else is.
        assert_eq!(exch.take().ok(), Some(0));
        assert_eq!(exch.take().err(), Some(stopping_error()));
    }
}