use crate::{
channels::{BackStage, FlushResult, OverflowPolicy, SharedBackStage, Tx, TxConnectable},
prelude::RetentionPolicy,
};
use std::fmt;
/// Maximum number of receivers a single transmitter accepts: `has_max_connection_count`
/// reports `true` once `connections` reaches this length.
// NOTE(review): presumably bounded by the width of `FlushResult::error_indicator`, which is
// indexed by receiver position in `flush` — confirm before raising this value.
pub const MAX_RECEIVER_COUNT: usize = 64;
/// Transmitting half of the `DoubleBufferTx`/`DoubleBufferRx` channel pair.
///
/// Messages are staged locally by `push`/`push_many` and only handed to the connected
/// receivers when `flush` is called.
pub struct DoubleBufferTx<T> {
    /// Messages staged by `push` that are waiting for the next `flush`.
    outbox: BackStage<T>,
    /// Back stages of the connected receivers, in connection order. `flush` moves messages
    /// into the first entry and clones them into the rest.
    connections: Vec<SharedBackStage<T>>,
}
impl<T> DoubleBufferTx<T> {
pub fn new(capacity: usize) -> Self {
Self {
outbox: BackStage::new(OverflowPolicy::Reject(capacity), RetentionPolicy::Drop),
connections: Vec::new(),
}
}
pub fn new_auto_size() -> Self {
Self {
outbox: BackStage::new(OverflowPolicy::Resize, RetentionPolicy::Drop),
connections: Vec::new(),
}
}
pub fn push(&mut self, value: T) -> Result<(), TxSendError> {
self.outbox.push(value).map_err(|_| TxSendError::QueueFull)
}
pub fn push_many<I: IntoIterator<Item = T>>(&mut self, values: I) -> Result<(), TxSendError> {
for x in values.into_iter() {
self.push(x)?;
}
Ok(())
}
}
impl<V> TxConnectable for DoubleBufferTx<V>
where
    V: Send + Sync + Clone,
{
    type Message = V;

    /// Returns `true` once the number of attached receivers has reached
    /// `MAX_RECEIVER_COUNT`.
    fn has_max_connection_count(&self) -> bool {
        let attached = self.connections.len();
        MAX_RECEIVER_COUNT <= attached
    }

    /// Returns a copy of the overflow policy the outbox was created with.
    fn overflow_policy(&self) -> OverflowPolicy {
        let policy = self.outbox.overflow_policy();
        *policy
    }

    /// Registers a newly connected receiver's back stage. Receivers are kept in connection
    /// order.
    fn on_connect(&mut self, stage: SharedBackStage<Self::Message>) {
        let receivers = &mut self.connections;
        receivers.push(stage);
    }
}
/// Reasons a receiver could not be connected to a transmitter.
#[derive(Debug, thiserror::Error)]
pub enum TxConnectError {
    /// The receiver is already attached to a transmitter.
    #[error("RX cannot be connected to more than one transmitter")]
    ReceiverAlreadyConnected,

    /// The transmitter already has as many receivers as it allows.
    #[error("TX exceeded maximum connection count")]
    MaxConnectionCountExceeded,

    /// The overflow policies of the two sides are incompatible: a `Resize` TX cannot feed a
    /// `Reject` RX.
    #[error(
        "Cannot connect a TX with policy `Resize` to an RX with policy `Reject`.\n\
         Either change the TX policy to `Reject` or the RX policy to `Resize` or `Forget`."
    )]
    PolicyMismatch,
}
impl<T: Send + Sync + Clone> Tx for DoubleBufferTx<T> {
fn flush(&mut self) -> FlushResult {
let mut result = FlushResult::default();
result.available = self.outbox.len();
for (i, rx) in self.connections.iter().enumerate().skip(1) {
let mut q = rx.write().unwrap();
for v in self.outbox.iter() {
if matches!(q.push((*v).clone()), Err(_)) {
result.error_indicator.mark(i);
break;
}
result.cloned += 1;
result.published += 1;
}
}
if let Some(first_rx) = self.connections.get(0) {
let mut q = first_rx.write().unwrap();
for v in self.outbox.drain_all() {
if matches!(q.push(v), Err(_)) {
result.error_indicator.mark(0);
break;
}
result.published += 1;
}
} else {
self.outbox.clear();
}
result
}
fn is_connected(&self) -> bool {
!self.connections.is_empty()
}
fn len(&self) -> usize {
self.outbox.len()
}
}
/// Error returned when a message cannot be staged in a `DoubleBufferTx` outbox.
///
/// Derives `PartialEq`/`Eq` so callers can compare results directly, e.g.
/// `assert_eq!(tx.push(x), Err(TxSendError::QueueFull))`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TxSendError {
    /// The outbox refused the message, for example because a `Reject` capacity limit was
    /// reached.
    QueueFull,
}

impl fmt::Display for TxSendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TxSendError::QueueFull => f.write_str("QueueFull"),
        }
    }
}

impl std::error::Error for TxSendError {}
// Two-thread round-trip test. A producer thread pushes and flushes batches through
// `DoubleBufferTx` while a consumer thread syncs and pops them from `DoubleBufferRx`.
#[cfg(test)]
mod tests {
    use crate::{
        channels::{FlushResult, SyncResult},
        prelude::*,
    };
    use std::sync::mpsc;

    // Builds a connected TX/RX pair where both sides use `Reject(size)` as overflow policy.
    // NOTE(review): `RetentionPolicy::EnforceEmpty` presumably requires the RX to be drained
    // before its next `sync` — confirm against `DoubleBufferRx`.
    fn fixed_channel<T: Clone + Send + Sync>(
        size: usize,
    ) -> (DoubleBufferTx<T>, DoubleBufferRx<T>) {
        let mut tx = DoubleBufferTx::new(size);
        let mut rx =
            DoubleBufferRx::new(OverflowPolicy::Reject(size), RetentionPolicy::EnforceEmpty);
        connect(&mut tx, &mut rx).unwrap();
        (tx, rx)
    }

    #[test]
    fn test() {
        const NUM_MESSAGES: usize = 100;
        const NUM_ROUNDS: usize = 100;
        // Capacity equals the batch size, so every round exactly fills the channel.
        let (mut tx, mut rx) = fixed_channel(NUM_MESSAGES);
        // Handshake channels:
        // - `sync_*`: producer -> consumer, "round k has been flushed".
        // - `rep_*`: consumer -> producer, "round k has been synced".
        let (sync_tx, sync_rx) = mpsc::sync_channel(1);
        let (rep_tx, rep_rx) = mpsc::sync_channel(1);
        // Consumer thread.
        let t1 = std::thread::spawn(move || {
            for k in 0..NUM_ROUNDS {
                // Wait until the producer has flushed round `k`.
                sync_rx.recv().unwrap();
                // The whole batch must become visible in a single sync.
                assert_eq!(
                    rx.sync(),
                    SyncResult {
                        received: NUM_MESSAGES,
                        ..Default::default()
                    }
                );
                // Release the producer *before* popping, so its next push/flush round
                // overlaps with this drain.
                rep_tx.send(()).unwrap();
                // Messages must come out in the order they were pushed.
                for i in 0..NUM_MESSAGES {
                    assert_eq!(rx.pop().unwrap(), format!("hello {k} {i}"));
                }
            }
        });
        // Producer thread.
        let t2 = std::thread::spawn(move || {
            for k in 0..NUM_ROUNDS {
                for i in 0..NUM_MESSAGES {
                    tx.push(format!("hello {k} {i}")).unwrap();
                }
                // With a single receiver every message is moved, not cloned, so `cloned`
                // stays at its default.
                assert_eq!(
                    tx.flush(),
                    FlushResult {
                        available: NUM_MESSAGES,
                        published: NUM_MESSAGES,
                        ..Default::default()
                    }
                );
                // Signal the consumer, then wait until it has synced before the next round.
                sync_tx.send(()).unwrap();
                rep_rx.recv().unwrap();
            }
        });
        // Join both threads so that an assertion failure in either one fails the test.
        t1.join().unwrap();
        t2.join().unwrap();
    }
}