use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use log::trace;
use crate::TaskError;
/// Errors reported by [`Exchanger`] operations.
///
/// `T` is the element type of the exchanger; it only appears in
/// [`ExchangerError::ExchangerFull`], which hands a rejected element back.
pub enum ExchangerError<T> {
    /// A task-level failure. Within this module it is either
    /// `TaskError::LockingError` (the internal mutex is poisoned) or
    /// `TaskError::ExecutorStoppingError` (the exchanger has been shut down).
    TaskError(TaskError),
    /// The queue had no free slot; the element that could not be queued is
    /// returned to the caller so it is not lost.
    ExchangerFull(T),
    /// A non-blocking take found the queue empty.
    ExchangerEmpty,
}
impl<T> Debug for ExchangerError<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ExchangerError::TaskError(e) => {
write!(f, "TaskError: {e:?}")
}
ExchangerError::ExchangerFull(_) => {
write!(f, "ExchangerFull")
}
ExchangerError::ExchangerEmpty => {
write!(f, "ExchangerEmpty")
}
}
}
}
impl<T> PartialEq for ExchangerError<T> {
fn eq(&self, other: &Self) -> bool {
match self {
ExchangerError::TaskError(e) => {
if let ExchangerError::TaskError(e2) = other {
return e == e2;
}
false
}
ExchangerError::ExchangerFull(_) => {
matches!(other, ExchangerError::ExchangerFull(_))
}
ExchangerError::ExchangerEmpty => {
matches!(other, ExchangerError::ExchangerEmpty)
}
}
}
}
/// State shared by every clone of an [`Exchanger`]: a bounded FIFO queue behind
/// a mutex, plus the condition variables used to park blocked takers and putters.
struct InnerExchange<T: Send> {
    /// Queued elements; pushed at the back, popped from the front.
    mutex: Mutex<VecDeque<T>>,
    /// Takers wait here; signalled when an element is pushed and on shutdown.
    take_condition: Condvar,
    /// Putters wait here; signalled when an element is removed and on shutdown.
    put_condition: Condvar,
    /// Set to `true` by `shutdown` and never cleared afterwards.
    shutdown: AtomicBool,
    /// Upper bound on the number of queued elements.
    max_size: usize,
    /// Number of threads currently parked in `take_blocking`.
    num_waiting_takers: AtomicU16,
    /// Number of threads currently parked in `put_blocking`.
    num_waiting_putters: AtomicU16,
}
impl<T: Send> InnerExchange<T> {
pub fn new(max_size: usize) -> Self {
InnerExchange {
max_size,
mutex: Default::default(),
take_condition: Default::default(),
put_condition: Default::default(),
shutdown: AtomicBool::new(false),
num_waiting_takers: AtomicU16::new(0),
num_waiting_putters: AtomicU16::new(0),
}
}
pub fn take_blocking(&self) -> Result<T, ExchangerError<T>> {
let Ok(mut elems) = self.mutex.lock() else {
return Err(ExchangerError::TaskError(TaskError::LockingError));
};
if let Some(e) = elems.pop_front() {
trace!("Take_blocking popped one");
self.put_condition.notify_one();
return Ok(e);
}
if self.shutdown.load(Ordering::SeqCst) {
return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
}
trace!("Take_blocking waiting for element");
self.num_waiting_takers.fetch_add(1, Ordering::SeqCst);
let Ok(mut elems) = self.take_condition.wait_while(elems, |e| {
e.is_empty() && !self.shutdown.load(Ordering::SeqCst)
}) else {
return Err(ExchangerError::TaskError(TaskError::LockingError));
};
self.num_waiting_takers.fetch_sub(1, Ordering::SeqCst);
let Some(e) = elems.pop_front() else {
trace!("Take_blocking woken up for empty exchange");
return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
};
trace!("Take_blocking woken up for new element");
self.put_condition.notify_one();
Ok(e)
}
pub fn try_take(&self) -> Result<T, ExchangerError<T>> {
let Ok(mut elems) = self.mutex.lock() else {
return Err(ExchangerError::TaskError(TaskError::LockingError));
};
if let Some(e) = elems.pop_front() {
trace!("Take_blocking popped one");
self.put_condition.notify_one();
return Ok(e);
}
if self.shutdown.load(Ordering::SeqCst) {
return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
}
Err(ExchangerError::ExchangerEmpty)
}
pub fn put_blocking(&self, elem: T) -> Result<(), ExchangerError<T>> {
if self.shutdown.load(Ordering::SeqCst) {
return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
}
let Ok(mut elems) = self.mutex.lock() else {
return Err(ExchangerError::TaskError(TaskError::LockingError));
};
if elems.len() < self.max_size {
trace!("Put_blocking added one");
elems.push_back(elem);
self.take_condition.notify_one();
return Ok(());
}
trace!("Put_blocking full, waiting for empty spot");
self.num_waiting_putters.fetch_add(1, Ordering::SeqCst);
let Ok(mut elems) = self.put_condition.wait_while(elems, |e| {
e.len() >= self.max_size && !self.shutdown.load(Ordering::SeqCst)
}) else {
return Err(ExchangerError::TaskError(TaskError::LockingError));
};
self.num_waiting_putters.fetch_sub(1, Ordering::SeqCst);
if elems.len() == self.max_size {
trace!("Put_blocking woken up for full, cannot add new element");
return Err(ExchangerError::ExchangerFull(elem));
};
trace!("Put_blocking woken up for free space");
elems.push_back(elem);
self.take_condition.notify_one();
Ok(())
}
pub fn try_put(&self, elem: T) -> Result<(), ExchangerError<T>> {
if self.shutdown.load(Ordering::SeqCst) {
return Err(ExchangerError::TaskError(TaskError::ExecutorStoppingError));
}
let Ok(mut elems) = self.mutex.lock() else {
return Err(ExchangerError::TaskError(TaskError::LockingError));
};
if elems.len() < self.max_size {
trace!("try_put added one");
elems.push_back(elem);
self.take_condition.notify_one();
return Ok(());
}
Err(ExchangerError::ExchangerFull(elem))
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::SeqCst);
while self.num_waiting_putters.load(Ordering::SeqCst) > 0 {
self.put_condition.notify_all();
}
while self.num_waiting_takers.load(Ordering::SeqCst) > 0 {
self.take_condition.notify_all();
}
}
}
/// A bounded FIFO queue for handing elements between threads.
///
/// Cloning is cheap: every clone refers to the same shared queue. This lets any
/// number of producer and consumer threads each hold their own handle.
pub struct Exchanger<T: Send> {
    /// Queue state shared by all clones of this exchanger.
    exchange: Arc<InnerExchange<T>>,
}
impl<T: Send> Clone for Exchanger<T> {
fn clone(&self) -> Self {
Exchanger {
exchange: self.exchange.clone(),
}
}
}
impl<T: Send> Exchanger<T> {
    /// Creates an exchanger whose queue holds at most `max_size` elements.
    pub fn new(max_size: usize) -> Self {
        let inner = InnerExchange::new(max_size);
        Self {
            exchange: Arc::new(inner),
        }
    }

    /// Appends `elem` to the back of the queue, blocking while the queue is full.
    ///
    /// # Errors
    /// - `TaskError(ExecutorStoppingError)` if the exchanger has been shut down.
    /// - `ExchangerFull(elem)` if a shutdown wakes this call while the queue is
    ///   still full; the element is handed back.
    /// - `TaskError(LockingError)` if the internal mutex is poisoned.
    pub fn push(&self, elem: T) -> Result<(), ExchangerError<T>> {
        let inner = &self.exchange;
        inner.put_blocking(elem)
    }

    /// Appends `elem` to the back of the queue without blocking.
    ///
    /// # Errors
    /// - `ExchangerFull(elem)` if the queue is full; the element is handed back.
    /// - `TaskError(ExecutorStoppingError)` if the exchanger has been shut down.
    /// - `TaskError(LockingError)` if the internal mutex is poisoned.
    pub fn try_push(&self, elem: T) -> Result<(), ExchangerError<T>> {
        let inner = &self.exchange;
        inner.try_put(elem)
    }

    /// Removes and returns the oldest queued element, blocking while the queue
    /// is empty.
    ///
    /// Elements queued before a shutdown are still returned.
    ///
    /// # Errors
    /// - `TaskError(ExecutorStoppingError)` once the exchanger is shut down and
    ///   the queue is empty.
    /// - `TaskError(LockingError)` if the internal mutex is poisoned.
    pub fn take(&self) -> Result<T, ExchangerError<T>> {
        let inner = &self.exchange;
        inner.take_blocking()
    }

    /// Removes and returns the oldest queued element without blocking.
    ///
    /// # Errors
    /// - `ExchangerEmpty` if the queue is empty.
    /// - `TaskError(ExecutorStoppingError)` if the queue is empty and the
    ///   exchanger has been shut down.
    /// - `TaskError(LockingError)` if the internal mutex is poisoned.
    pub fn try_take(&self) -> Result<T, ExchangerError<T>> {
        let inner = &self.exchange;
        inner.try_take()
    }

    /// Shuts the exchanger down for every clone and wakes all threads blocked in
    /// [`push`](Self::push) or [`take`](Self::take).
    pub fn shutdown(&self) {
        let inner = &self.exchange;
        inner.shutdown();
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use std::thread::JoinHandle;
    use std::time::Duration;

    use log::{error, info, Level};

    use crate::{Exchanger, ExchangerError, TaskError};

    /// The error every operation reports once the exchanger has been shut down.
    fn stopping() -> ExchangerError<u32> {
        ExchangerError::TaskError(TaskError::ExecutorStoppingError)
    }

    /// One producer and one deliberately slow consumer.
    ///
    /// The producer repeatedly fills the 10-slot queue and must block in `push`.
    /// The consumer must receive every value exactly once, in FIFO order.
    #[test]
    pub fn test_single_sender_receiver() -> Result<(), ExchangerError<u32>> {
        let (err_sender, err_receiver) = std::sync::mpsc::channel();
        let err_sender2 = err_sender.clone();
        let exch1 = Exchanger::<u32>::new(10);
        let exch2 = exch1.clone();
        let genthrd = std::thread::Builder::new()
            .name("Sender".to_string())
            .spawn(move || {
                let mut sent = 0;
                for i in 0..1_000 {
                    if let Err(e) = exch2.push(i) {
                        eprintln!("Error sending exchange: {e:?}");
                        if let Err(e) = err_sender.send((e, i, "send")) {
                            panic!("{e:?}");
                        }
                    }
                    sent += 1;
                }
                println!("Sent {sent}");
            })
            .unwrap();
        let recv_thrd = std::thread::Builder::new()
            .name("Receiver".to_string())
            .spawn(move || {
                // Values are returned and compared after the join rather than
                // asserted here. A panicking receiver would leave the sender
                // blocked on a full queue and hang the test instead of failing it.
                let mut received = Vec::with_capacity(1_000);
                for i in 0..1_000 {
                    match exch1.take() {
                        Ok(value) => received.push(value),
                        Err(e) => {
                            eprintln!("Error receiving exchange: {e:?}");
                            if let Err(e) = err_sender2.send((e, i, "recv")) {
                                panic!("{e:?}");
                            }
                        }
                    }
                    // Slow consumer so the producer exercises the full-queue path.
                    std::thread::sleep(Duration::from_millis(1));
                }
                println!("Received {}", received.len());
                received
            })
            .unwrap();
        genthrd.join().unwrap();
        let received = recv_thrd.join().unwrap();
        let mut errors: bool = false;
        while let Ok(r) = err_receiver.recv() {
            let (e, i, s) = r;
            eprintln!("Error received {e:?} : {i} : {s}");
            errors = true;
        }
        assert!(!errors);
        assert_eq!(received, (0..1_000).collect::<Vec<u32>>());
        Ok(())
    }

    /// One producer and ten consumers.
    ///
    /// After the producer finishes and the exchanger is shut down, every consumer
    /// must exit cleanly. Together, the consumers must have received every sent
    /// element.
    #[test]
    pub fn test_multiple_receivers() {
        irox_log::init_console_level(Level::Info);
        let (err_sender, err_receiver) = std::sync::mpsc::channel();
        let err_sender2 = err_sender.clone();
        let exch1 = Exchanger::<u32>::new(10);
        let exch2 = exch1.clone();
        let exch3 = exch1.clone();
        let genthrd = std::thread::Builder::new()
            .name("Sender".to_string())
            .spawn(move || {
                let mut sent = 0;
                for i in 0..1_000_000 {
                    if let Err(e) = exch2.push(i) {
                        eprintln!("Error sending exchange: {e:?}");
                        if let Err(e) = err_sender.send((e, i, "send")) {
                            panic!("{e:?}");
                        }
                    }
                    sent += 1;
                }
                info!("Sent {sent}");
            })
            .unwrap();
        let recv_count = Arc::new(AtomicU64::new(0));
        let mut receivers: Vec<JoinHandle<()>> = Vec::new();
        for thread_idx in 0..10 {
            let counter = recv_count.clone();
            let err_sender2 = err_sender2.clone();
            let exch1 = exch1.clone();
            let recv_thrd = std::thread::Builder::new()
                .name(format!("Receiver {thread_idx}"))
                .spawn(move || {
                    let mut recvd = 0;
                    loop {
                        if let Err(e) = exch1.take() {
                            // Shutdown with an empty queue is the normal exit.
                            if e == ExchangerError::TaskError(TaskError::ExecutorStoppingError) {
                                break;
                            }
                            error!("Error receiving exchange: {e:?}");
                            if let Err(e) = err_sender2.send((e, recvd, "recv")) {
                                panic!("Error sending error: {e:?}");
                            }
                            break;
                        }
                        recvd += 1;
                        counter.fetch_add(1, Ordering::Relaxed);
                    }
                    info!(
                        "Received {recvd} in thread {}",
                        std::thread::current().name().unwrap_or("")
                    );
                })
                .unwrap();
            receivers.push(recv_thrd);
        }
        drop(err_sender2);
        genthrd.join().unwrap();
        info!("Generator thread joined");
        exch3.shutdown();
        info!("Executor shutdown");
        for recv in receivers {
            info!("Waiting on {}", recv.thread().name().unwrap_or(""));
            recv.join().unwrap();
        }
        let mut errors: bool = false;
        while let Ok(r) = err_receiver.recv() {
            let (e, i, s) = r;
            error!("Error received {e:?} : {i} : {s}");
            errors = true;
        }
        assert!(!errors);
        // Every element pushed before shutdown must have been drained by a receiver.
        assert_eq!(recv_count.load(Ordering::SeqCst), 1_000_000);
    }

    /// Checks the non-blocking operations and the effect of `shutdown`.
    ///
    /// Empty and full queues are reported as `ExchangerEmpty` / `ExchangerFull`,
    /// and `ExchangerFull` hands the element back. After `shutdown`, queued
    /// elements can still be drained, while every new operation reports
    /// `ExecutorStoppingError`.
    #[test]
    pub fn test_try_operations_and_shutdown() {
        let exch = Exchanger::<u32>::new(1);
        assert_eq!(exch.try_take(), Err(ExchangerError::ExchangerEmpty));
        assert_eq!(exch.try_push(1), Ok(()));
        // `ExchangerError`'s `PartialEq` ignores the payload, so match it explicitly.
        match exch.try_push(2) {
            Err(ExchangerError::ExchangerFull(rejected)) => assert_eq!(rejected, 2),
            other => panic!("expected ExchangerFull(2), got {other:?}"),
        }
        assert_eq!(exch.try_take(), Ok(1));
        assert_eq!(exch.try_push(3), Ok(()));
        exch.shutdown();
        assert_eq!(exch.take(), Ok(3));
        assert_eq!(exch.take(), Err(stopping()));
        assert_eq!(exch.try_take(), Err(stopping()));
        assert_eq!(exch.try_push(4), Err(stopping()));
        assert_eq!(exch.push(5), Err(stopping()));
    }

    /// `shutdown` must wake threads already blocked in `take` (empty queue) and
    /// in `push` (full queue), instead of leaving them parked forever.
    #[test]
    pub fn test_shutdown_wakes_blocked_threads() {
        let empty = Exchanger::<u32>::new(1);
        let full = Exchanger::<u32>::new(1);
        assert_eq!(full.try_push(1), Ok(()));
        let taker = {
            let empty = empty.clone();
            std::thread::spawn(move || empty.take())
        };
        let putter = {
            let full = full.clone();
            std::thread::spawn(move || full.push(2))
        };
        // Give both threads time to park. A thread that has not parked yet sees
        // the shutdown flag directly, which is equally valid.
        std::thread::sleep(Duration::from_millis(50));
        empty.shutdown();
        full.shutdown();
        assert_eq!(taker.join().unwrap(), Err(stopping()));
        // A woken putter hands its element back as `ExchangerFull`; one that never
        // blocked is rejected with `ExecutorStoppingError`.
        match putter.join().unwrap() {
            Err(ExchangerError::ExchangerFull(rejected)) => assert_eq!(rejected, 2),
            other => assert_eq!(other, Err(stopping())),
        }
    }
}