1use std::collections::VecDeque;
10use std::fmt::{Debug, Formatter};
11use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
12use std::sync::{Arc, Condvar, Mutex};
13
14use log::trace;
15
16use crate::TaskError;
17
/// Error returned by [`Exchanger`] operations.
///
/// `T` is the element type of the exchanger; `ExchangerFull` hands the element
/// that could not be stored back to the caller.
pub enum ExchangerError<T> {
    /// A lower-level failure: a poisoned internal lock (`TaskError::LockingError`)
    /// or an exchanger that has been shut down (`TaskError::ExecutorStoppingError`).
    TaskError(TaskError),
    /// The exchanger had no room; carries the rejected element.
    ExchangerFull(T),
    /// A non-blocking take found no element queued.
    ExchangerEmpty,
}
28
29impl<T> Debug for ExchangerError<T> {
30 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
31 match self {
32 ExchangerError::TaskError(e) => {
33 write!(f, "TaskError: {e:?}")
34 }
35 ExchangerError::ExchangerFull(_) => {
36 write!(f, "ExchangerFull")
37 }
38 ExchangerError::ExchangerEmpty => {
39 write!(f, "ExchangerEmpty")
40 }
41 }
42 }
43}
44
45impl<T> PartialEq for ExchangerError<T> {
46 fn eq(&self, other: &Self) -> bool {
47 match self {
48 ExchangerError::TaskError(e) => {
49 if let ExchangerError::TaskError(e2) = other {
50 return e == e2;
51 }
52 false
53 }
54 ExchangerError::ExchangerFull(_) => {
55 matches!(other, ExchangerError::ExchangerFull(_))
56 }
57 ExchangerError::ExchangerEmpty => {
58 matches!(other, ExchangerError::ExchangerEmpty)
59 }
60 }
61 }
62}
63
/// Shared state behind every `Exchanger` handle: a bounded FIFO queue guarded by
/// a mutex, plus the condition variables used to park blocked takers and putters.
struct InnerExchange<T: Send> {
    /// Maximum number of elements the queue may hold.
    max_size: usize,
    /// The queued elements; pushed at the back, taken from the front.
    mutex: Mutex<VecDeque<T>>,

    /// Notified after an element is pushed, and on shutdown (wakes blocked takers).
    take_condition: Condvar,
    /// Notified after an element is taken, and on shutdown (wakes blocked putters).
    put_condition: Condvar,

    /// Number of threads currently waiting inside a blocking take.
    num_waiting_takers: AtomicU16,
    /// Number of threads currently waiting inside a blocking put.
    num_waiting_putters: AtomicU16,

    /// Set once the exchange has been shut down.
    shutdown: AtomicBool,
}
73
74impl<T: Send> InnerExchange<T> {
75 pub fn new(max_size: usize) -> Self {
77 InnerExchange {
78 max_size,
79 mutex: Default::default(),
80 take_condition: Default::default(),
81 put_condition: Default::default(),
82 shutdown: AtomicBool::new(false),
83 num_waiting_takers: AtomicU16::new(0),
84 num_waiting_putters: AtomicU16::new(0),
85 }
86 }
87
88 pub fn take_blocking(&self) -> Result<T, ExchangerError<T>> {
91 let Ok(mut elems) = self.mutex.lock() else {
92 return Err(ExchangerError::TaskError(TaskError::LockingError));
93 };
94 if let Some(e) = elems.pop_front() {
95 trace!("Take_blocking popped one");
96 self.put_condition.notify_one();
97 return Ok(e);
98 }
99 if self.shutdown.load(Ordering::SeqCst) {
103 return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
104 }
105 trace!("Take_blocking waiting for element");
106 self.num_waiting_takers.fetch_add(1, Ordering::SeqCst);
107 let Ok(mut elems) = self.take_condition.wait_while(elems, |e| {
108 e.is_empty() && !self.shutdown.load(Ordering::SeqCst)
109 }) else {
110 return Err(ExchangerError::TaskError(TaskError::LockingError));
111 };
112 self.num_waiting_takers.fetch_sub(1, Ordering::SeqCst);
113
114 let Some(e) = elems.pop_front() else {
115 trace!("Take_blocking woken up for empty exchange");
116 return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
118 };
119 trace!("Take_blocking woken up for new element");
120 self.put_condition.notify_one();
121 Ok(e)
122 }
123
124 pub fn try_take(&self) -> Result<T, ExchangerError<T>> {
129 let Ok(mut elems) = self.mutex.lock() else {
130 return Err(ExchangerError::TaskError(TaskError::LockingError));
131 };
132 if let Some(e) = elems.pop_front() {
133 trace!("Take_blocking popped one");
134 self.put_condition.notify_one();
135 return Ok(e);
136 }
137 if self.shutdown.load(Ordering::SeqCst) {
141 return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
142 }
143 Err(ExchangerError::ExchangerEmpty)
144 }
145
146 pub fn put_blocking(&self, elem: T) -> Result<(), ExchangerError<T>> {
150 if self.shutdown.load(Ordering::SeqCst) {
152 return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
153 }
154
155 let Ok(mut elems) = self.mutex.lock() else {
156 return Err(ExchangerError::TaskError(TaskError::LockingError));
157 };
158 if elems.len() < self.max_size {
159 trace!("Put_blocking added one");
160 elems.push_back(elem);
161 self.take_condition.notify_one();
162 return Ok(());
163 }
164 trace!("Put_blocking full, waiting for empty spot");
165 self.num_waiting_putters.fetch_add(1, Ordering::SeqCst);
167 let Ok(mut elems) = self.put_condition.wait_while(elems, |e| {
168 e.len() >= self.max_size && !self.shutdown.load(Ordering::SeqCst)
169 }) else {
170 return Err(ExchangerError::TaskError(TaskError::LockingError));
171 };
172 self.num_waiting_putters.fetch_sub(1, Ordering::SeqCst);
173
174 if elems.len() == self.max_size {
175 trace!("Put_blocking woken up for full, cannot add new element");
176 return Err(ExchangerError::ExchangerFull(elem));
177 };
178
179 trace!("Put_blocking woken up for free space");
180 elems.push_back(elem);
181 self.take_condition.notify_one();
182 Ok(())
183 }
184
185 pub fn try_put(&self, elem: T) -> Result<(), ExchangerError<T>> {
190 if self.shutdown.load(Ordering::SeqCst) {
192 return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
193 }
194
195 let Ok(mut elems) = self.mutex.lock() else {
196 return Err(ExchangerError::TaskError(TaskError::LockingError));
197 };
198 if elems.len() < self.max_size {
199 trace!("try_put added one");
200 elems.push_back(elem);
201 self.take_condition.notify_one();
202 return Ok(());
203 }
204 Err(ExchangerError::ExchangerFull(elem))
205 }
206
207 pub fn shutdown(&self) {
210 self.shutdown.store(true, Ordering::SeqCst);
211 while self.num_waiting_putters.load(Ordering::SeqCst) > 0 {
212 self.put_condition.notify_all();
213 }
214 while self.num_waiting_takers.load(Ordering::SeqCst) > 0 {
215 self.take_condition.notify_all();
216 }
217 }
218}
219
/// A bounded FIFO queue for handing elements of type `T` between threads.
///
/// Cloning an `Exchanger` yields another handle to the same underlying queue.
pub struct Exchanger<T: Send> {
    /// Shared state; every clone of this handle points at the same `InnerExchange`.
    exchange: Arc<InnerExchange<T>>,
}
227
228impl<T: Send> Clone for Exchanger<T> {
229 fn clone(&self) -> Self {
230 Exchanger {
231 exchange: self.exchange.clone(),
232 }
233 }
234}
235
236impl<T: Send> Exchanger<T> {
237 pub fn new(max_size: usize) -> Self {
243 Exchanger {
244 exchange: Arc::new(InnerExchange::new(max_size)),
245 }
246 }
247
248 pub fn push(&self, elem: T) -> Result<(), ExchangerError<T>> {
251 self.exchange.put_blocking(elem)
252 }
253
254 pub fn try_push(&self, elem: T) -> Result<(), ExchangerError<T>> {
255 self.exchange.try_put(elem)
256 }
257
258 pub fn take(&self) -> Result<T, ExchangerError<T>> {
261 self.exchange.take_blocking()
262 }
263
264 pub fn try_take(&self) -> Result<T, ExchangerError<T>> {
265 self.exchange.try_take()
266 }
267
268 pub fn shutdown(&self) {
273 self.exchange.shutdown();
274 }
275}
276
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use std::thread::JoinHandle;
    use std::time::Duration;

    use log::{error, info, Level};

    use crate::{Exchanger, ExchangerError, TaskError};

    /// One producer and one deliberately slow consumer over a 10-slot exchanger.
    /// The producer must block whenever the queue is full. Every element must
    /// arrive exactly once and, with a single producer and consumer, in FIFO order.
    #[test]
    pub fn test_single_sender_receiver() -> Result<(), ExchangerError<u32>> {
        let (err_sender, err_receiver) = std::sync::mpsc::channel();
        let err_sender2 = err_sender.clone();
        let exch1 = Exchanger::<u32>::new(10);
        let exch2 = exch1.clone();
        let genthrd = std::thread::Builder::new()
            .name("Sender".to_string())
            .spawn(move || {
                let mut sent: u32 = 0;
                for i in 0..1_000 {
                    if let Err(e) = exch2.push(i) {
                        eprintln!("Error sending exchange: {e:?}");
                        if let Err(e) = err_sender.send((e, i, "send")) {
                            panic!("{e:?}");
                        }
                        // Only successful pushes count as sent.
                        continue;
                    }
                    sent += 1;
                }
                println!("Sent {sent}");
                sent
            })
            .unwrap();

        let recv_thrd = std::thread::Builder::new()
            .name("Receiver".to_string())
            .spawn(move || {
                let mut received: Vec<u32> = Vec::with_capacity(1_000);
                for i in 0..1_000 {
                    match exch1.take() {
                        Ok(value) => received.push(value),
                        Err(e) => {
                            eprintln!("Error receiving exchange: {e:?}");
                            if let Err(e) = err_sender2.send((e, i, "recv")) {
                                panic!("{e:?}");
                            }
                        }
                    }
                    // Slow consumer so the producer regularly hits the capacity limit.
                    std::thread::sleep(Duration::from_millis(1));
                }
                println!("Received {}", received.len());
                received
            })
            .unwrap();

        let sent = genthrd.join().unwrap();
        let received = recv_thrd.join().unwrap();

        let mut errors: bool = false;
        while let Ok(r) = err_receiver.recv() {
            let (e, i, s) = r;
            eprintln!("Error received {e:?} : {i} : {s}");
            errors = true;
        }

        assert!(!errors);
        assert_eq!(sent, 1_000);
        assert_eq!(received.len(), 1_000);
        assert!(
            received.iter().copied().eq(0..1_000u32),
            "elements were not received in FIFO order"
        );

        Ok(())
    }

    /// One producer fanning out to ten consumers. After the producer finishes,
    /// the exchanger is shut down. Each consumer drains what is left and then stops
    /// on `ExecutorStoppingError`; in aggregate every element is received exactly once.
    #[test]
    pub fn test_multiple_receivers() {
        irox_log::init_console_level(Level::Info);
        let (err_sender, err_receiver) = std::sync::mpsc::channel();
        let err_sender2 = err_sender.clone();
        let exch1 = Exchanger::<u32>::new(10);
        let exch2 = exch1.clone();
        let exch3 = exch1.clone();
        let genthrd = std::thread::Builder::new()
            .name("Sender".to_string())
            .spawn(move || {
                let mut sent: u64 = 0;
                for i in 0..1_000_000 {
                    if let Err(e) = exch2.push(i) {
                        eprintln!("Error sending exchange: {e:?}");
                        if let Err(e) = err_sender.send((e, i, "send")) {
                            panic!("{e:?}");
                        }
                        // Only successful pushes count as sent.
                        continue;
                    }
                    sent += 1;
                }
                info!("Sent {sent}");
                sent
            })
            .unwrap();

        let recv_count = Arc::new(AtomicU64::new(0));
        let mut receivers: Vec<JoinHandle<()>> = Vec::new();
        for thread_idx in 0..10 {
            let counter = recv_count.clone();
            let err_sender2 = err_sender2.clone();
            let exch1 = exch1.clone();
            let recv_thrd = std::thread::Builder::new()
                .name(format!("Receiver {thread_idx}"))
                .spawn(move || {
                    let mut recvd = 0;
                    loop {
                        if let Err(e) = exch1.take() {
                            // Expected termination: shut down and fully drained.
                            if e == ExchangerError::TaskError(TaskError::ExecutorStoppingError) {
                                break;
                            }
                            error!("Error receiving exchange: {e:?}");
                            if let Err(e) = err_sender2.send((e, recvd, "recv")) {
                                panic!("Error sending error: {e:?}");
                            }
                            break;
                        }
                        recvd += 1;
                        counter.fetch_add(1, Ordering::Relaxed);
                    }
                    info!(
                        "Received {recvd} in thread {}",
                        std::thread::current().name().unwrap_or("")
                    );
                })
                .unwrap();
            receivers.push(recv_thrd);
        }
        // Drop the template sender so the error channel closes once all threads exit.
        drop(err_sender2);

        let sent = genthrd.join().unwrap();
        info!("Generator thread joined");
        exch3.shutdown();
        info!("Executor shutdown");

        for recv in receivers {
            info!("Waiting on {}", recv.thread().name().unwrap_or(""));
            recv.join().unwrap();
        }

        let mut errors: bool = false;
        while let Ok(r) = err_receiver.recv() {
            let (e, i, s) = r;
            error!("Error received {e:?} : {i} : {s}");
            errors = true;
        }

        assert!(!errors);
        assert_eq!(sent, 1_000_000);
        // Joining the receivers orders their Relaxed increments before this load.
        assert_eq!(recv_count.load(Ordering::SeqCst), 1_000_000);
    }
}