use crate::{
channels::{
BackStage, FrontStage, OverflowPolicy, Rx, RxChannelTimeseries, RxConnectable,
SharedBackStage, SyncResult,
},
core::{Message, TimestampKind},
prelude::{Pop, RetentionPolicy},
};
use core::ops;
use std::{
collections::vec_deque,
fmt,
sync::{Arc, RwLock},
};
/// Receiving end of a double-buffered channel.
///
/// Incoming data lives in `back`, a lock-protected stage that can be shared
/// (a clone of the handle is returned by `on_connect`). The receiver reads from
/// `front`, a stage private to this module that `Rx::sync` fills from `back`.
pub struct DoubleBufferRx<T> {
    /// Shared, lock-protected stage; `Rx::sync` passes `front` to its `sync` method.
    /// NOTE(review): presumably written by the connected transmitter — confirm in `connect`.
    pub(crate) back: SharedBackStage<T>,
    /// Receiver-local stage that `pop`, indexing, `drain`, and `clear` operate on.
    front: FrontStage<T>,
    /// Starts `false` in `new`; set to `true` by `on_connect`.
    pub(crate) is_connected: bool,
}
impl<T> DoubleBufferRx<T> {
    /// Creates an unconnected receiver.
    ///
    /// The back stage is built from `overflow_policy` and `retention_policy`.
    /// The front stage is created with the back stage's reported capacity.
    pub fn new(overflow_policy: OverflowPolicy, retention_policy: RetentionPolicy) -> Self {
        let back = BackStage::new(overflow_policy, retention_policy);
        let capacity = back.capacity();
        Self {
            back: Arc::new(RwLock::new(back)),
            front: FrontStage::new(capacity),
            is_connected: false,
        }
    }

    /// Creates a receiver with `OverflowPolicy::Forget(1)` and `RetentionPolicy::Keep`.
    pub fn new_latest() -> Self {
        Self::new(OverflowPolicy::Forget(1), RetentionPolicy::Keep)
    }

    /// Creates a receiver with `OverflowPolicy::Resize` and `RetentionPolicy::Drop`.
    pub fn new_auto_size() -> Self {
        Self::new(OverflowPolicy::Resize, RetentionPolicy::Drop)
    }

    /// Number of messages currently held in the front stage.
    pub fn front_len(&self) -> usize {
        self.front.len()
    }

    /// Removes every message from the front stage, yielding them through the returned
    /// iterator.
    pub fn pop_all(&mut self) -> vec_deque::Drain<'_, T> {
        self.front.drain(..)
    }

    /// Returns the last element of the front stage (highest index), or `None` if the
    /// front stage is empty.
    pub fn latest(&self) -> Option<&T> {
        // `checked_sub` yields `None` for an empty front stage instead of underflowing.
        self.front
            .len()
            .checked_sub(1)
            .map(|last| &self.front[last])
    }

    /// Returns `true` once the front stage holds at least the number of messages allowed
    /// by a bounded overflow policy (`Reject(n)` / `Forget(n)`).
    ///
    /// Receivers using `OverflowPolicy::Resize` are never reported as full.
    ///
    /// # Panics
    ///
    /// Panics if the back-stage lock is poisoned.
    pub fn is_full(&self) -> bool {
        // Copy the policy out so the read lock is released before inspecting the front stage.
        let policy = *self.back.read().unwrap().overflow_policy();
        match policy {
            OverflowPolicy::Reject(limit) | OverflowPolicy::Forget(limit) => {
                // `>=` rather than `==`: a front stage that somehow exceeds the limit is
                // still full.
                self.front.len() >= limit
            }
            OverflowPolicy::Resize => false,
        }
    }

    /// Discards all messages in the front stage. The back stage is not touched.
    pub fn clear(&mut self) {
        self.front.clear();
    }

    /// Removes the messages in `range` from the front stage, yielding them through the
    /// returned iterator.
    ///
    /// NOTE(review): presumably panics under the same conditions as `VecDeque::drain`
    /// (start > end or end > len) — confirm against `FrontStage`.
    pub fn drain<R>(&mut self, range: R) -> vec_deque::Drain<'_, T>
    where
        R: ops::RangeBounds<usize>,
    {
        self.front.drain(range)
    }
}
impl<T> DoubleBufferRx<Message<T>> {
pub fn as_acq_time_series<'a>(&'a self) -> RxChannelTimeseries<'a, T> {
RxChannelTimeseries {
channel: self,
kind: TimestampKind::Acq,
}
}
pub fn as_pub_time_series<'a>(&'a self) -> RxChannelTimeseries<'a, T> {
RxChannelTimeseries {
channel: self,
kind: TimestampKind::Pub,
}
}
}
impl<T> Pop for DoubleBufferRx<T> {
    type Output = T;

    /// Reports whether the front stage has nothing left to pop.
    fn is_empty(&self) -> bool {
        self.front.is_empty()
    }

    /// Takes the next message out of the front stage.
    ///
    /// # Errors
    ///
    /// Returns `RxRecvError::QueueEmtpy` when the front stage yields nothing.
    fn pop(&mut self) -> Result<T, RxRecvError> {
        match self.front.pop() {
            Some(message) => Ok(message),
            None => Err(RxRecvError::QueueEmtpy),
        }
    }
}
impl<T> ops::Index<usize> for DoubleBufferRx<T> {
    type Output = T;

    /// Borrows the front-stage message at `position`.
    ///
    /// Out-of-range handling is delegated to the front stage's own indexing.
    fn index(&self, position: usize) -> &T {
        &self.front[position]
    }
}

impl<T> ops::IndexMut<usize> for DoubleBufferRx<T> {
    /// Mutably borrows the front-stage message at `position`.
    ///
    /// Out-of-range handling is delegated to the front stage's own indexing.
    fn index_mut(&mut self, position: usize) -> &mut T {
        &mut self.front[position]
    }
}
impl<T: Send + Sync> Rx for DoubleBufferRx<T> {
    /// Returns the connection flag, which `on_connect` sets to `true`.
    fn is_connected(&self) -> bool {
        self.is_connected
    }

    /// Locks the shared back stage for writing and lets it sync into the front stage.
    ///
    /// # Panics
    ///
    /// Panics if the back-stage lock is poisoned.
    fn sync(&mut self) -> SyncResult {
        // The guard borrows only `self.back`, so `self.front` can still be borrowed mutably.
        let mut back_stage = self.back.write().unwrap();
        back_stage.sync(&mut self.front)
    }

    /// Number of messages currently held in the front stage.
    fn len(&self) -> usize {
        self.front.len()
    }
}
/// Error returned when a message cannot be received from an Rx channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RxRecvError {
    /// There was no message available to pop.
    // The misspelled variant name is kept as-is for API compatibility.
    QueueEmtpy,
}

impl fmt::Display for RxRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Lowercase, no trailing punctuation, per Rust error-message conventions;
            // previously this printed the misspelled "QueueEmtpy".
            RxRecvError::QueueEmtpy => f.write_str("queue is empty"),
        }
    }
}

impl std::error::Error for RxRecvError {}
impl<T: Send + Sync> RxConnectable for DoubleBufferRx<T> {
type Message = T;
fn overflow_policy(&self) -> OverflowPolicy {
*self.back.read().unwrap().overflow_policy()
}
fn on_connect(&mut self) -> SharedBackStage<Self::Message> {
self.is_connected = true;
self.back.clone()
}
}
#[cfg(test)]
mod tests {
    use crate::{
        channels::{FlushResult, SyncResult},
        prelude::*,
    };
    use std::sync::mpsc;

    /// Creates a transmitter (`DoubleBufferTx::new(size)`) and a receiver, then connects them.
    ///
    /// The receiver is configured with `OverflowPolicy::Reject(size)` and
    /// `RetentionPolicy::EnforceEmpty`. Panics if `connect` fails.
    fn fixed_channel<T: Clone + Send + Sync>(
        size: usize,
    ) -> (DoubleBufferTx<T>, DoubleBufferRx<T>) {
        let mut tx = DoubleBufferTx::new(size);
        let mut rx =
            DoubleBufferRx::new(OverflowPolicy::Reject(size), RetentionPolicy::EnforceEmpty);
        connect(&mut tx, &mut rx).unwrap();
        (tx, rx)
    }

    /// Sends `NUM_ROUNDS` batches of `NUM_MESSAGES` strings from one thread to another.
    ///
    /// Each round follows this handshake (both signals use `mpsc::sync_channel(1)`):
    /// 1. The producer (`t2`) pushes and flushes a full batch, then signals on `sync_tx`.
    /// 2. The consumer (`t1`) syncs, checks that the whole batch was received, and
    ///    replies on `rep_tx`.
    /// 3. The consumer pops and checks each message in order. Meanwhile the producer is
    ///    already free to push and flush the next batch, so these steps may overlap.
    #[test]
    fn test() {
        const NUM_MESSAGES: usize = 100;
        const NUM_ROUNDS: usize = 100;
        let (mut tx, mut rx) = fixed_channel(NUM_MESSAGES);
        // "data is ready" signal: producer -> consumer.
        let (sync_tx, sync_rx) = mpsc::sync_channel(1);
        // "sync done" reply: consumer -> producer.
        let (rep_tx, rep_rx) = mpsc::sync_channel(1);
        // Consumer thread.
        let t1 = std::thread::spawn(move || {
            for k in 0..NUM_ROUNDS {
                sync_rx.recv().unwrap();
                assert_eq!(
                    rx.sync(),
                    SyncResult {
                        received: NUM_MESSAGES,
                        ..Default::default()
                    }
                );
                // Reply before popping, so the producer's next round can run concurrently
                // with the pops below.
                rep_tx.send(()).unwrap();
                for i in 0..NUM_MESSAGES {
                    // Messages must come out in the order they were pushed.
                    assert_eq!(rx.pop().unwrap(), format!("hello {k} {i}"));
                }
            }
        });
        // Producer thread.
        let t2 = std::thread::spawn(move || {
            for k in 0..NUM_ROUNDS {
                for i in 0..NUM_MESSAGES {
                    tx.push(format!("hello {k} {i}")).unwrap();
                }
                assert_eq!(
                    tx.flush(),
                    FlushResult {
                        available: NUM_MESSAGES,
                        published: NUM_MESSAGES,
                        ..Default::default()
                    }
                );
                sync_tx.send(()).unwrap();
                // Wait until the consumer has synced before pushing the next batch.
                rep_rx.recv().unwrap();
            }
        });
        // If either thread panics, its channel ends are dropped. The other thread's
        // `send`/`recv` then returns `Err` and its `unwrap` panics, so the test fails
        // instead of deadlocking.
        t1.join().unwrap();
        t2.join().unwrap();
    }
}