use tokio::sync::mpsc;
use super::types::LiveState;
/// Number of items buffered per status stream when the caller does not choose one.
pub const DEFAULT_STREAM_CAPACITY: usize = 8 * 1024;

/// What a subscriber does when its buffer is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum OverflowPolicy {
    /// Drop samples that do not fit and report how many were lost through
    /// [`StatusStreamItem::Lagged`].
    Lagged,
}

/// Tuning knobs for a status stream subscription.
#[derive(Debug, Clone)]
pub struct StreamOptions {
    /// Maximum number of queued items; a value of `0` is treated as `1`.
    pub capacity: usize,
    /// Behaviour applied once `capacity` items are queued.
    pub on_overflow: OverflowPolicy,
}

impl Default for StreamOptions {
    /// Uses [`DEFAULT_STREAM_CAPACITY`] and [`OverflowPolicy::Lagged`].
    fn default() -> Self {
        let capacity = DEFAULT_STREAM_CAPACITY;
        let on_overflow = OverflowPolicy::Lagged;
        StreamOptions {
            capacity,
            on_overflow,
        }
    }
}
/// One item delivered on a [`StatusStream`].
#[derive(Clone, Debug)]
pub enum StatusStreamItem {
    /// A snapshot of the live drive state.
    Sample(LiveState),
    /// `dropped` samples were discarded because the subscriber's buffer was full.
    Lagged { dropped: u64 },
}
/// Receiving half of a status subscription.
///
/// Items arrive in the order the publisher queued them.
pub struct StatusStream {
    rx: mpsc::Receiver<StatusStreamItem>,
}

impl StatusStream {
    /// Wraps the receiving end of a subscriber channel.
    pub(crate) fn new(rx: mpsc::Receiver<StatusStreamItem>) -> Self {
        StatusStream { rx }
    }

    /// Waits for the next item.
    ///
    /// Forwards directly to the underlying `mpsc::Receiver::recv`, so `None` signals
    /// that no further items will arrive.
    pub async fn recv(&mut self) -> Option<StatusStreamItem> {
        self.rx.recv().await
    }

    /// Returns the next item if one is already queued, without waiting.
    ///
    /// Every `try_recv` error maps to `None`. The caller therefore cannot tell
    /// "nothing queued yet" apart from "sender gone".
    pub fn try_recv(&mut self) -> Option<StatusStreamItem> {
        Result::ok(self.rx.try_recv())
    }
}

impl std::fmt::Debug for StatusStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The channel receiver carries nothing useful to print.
        f.debug_struct("StatusStream").finish_non_exhaustive()
    }
}
/// Sending half of a status subscription, owned by the publisher.
///
/// Pushes never block. When the receiver falls behind, samples are dropped and the
/// number of drops is reported in-band as [`StatusStreamItem::Lagged`].
pub(crate) struct Subscriber {
    tx: mpsc::Sender<StatusStreamItem>,
    /// Samples dropped since the last `Lagged` marker was queued.
    pending_lagged: u64,
    /// Channel capacity actually used (after clamping); only reported by `Debug`.
    capacity: usize,
}
impl std::fmt::Debug for Subscriber {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Subscriber")
            .field("capacity", &self.capacity)
            .field("pending_lagged", &self.pending_lagged)
            .field("closed", &self.tx.is_closed())
            .finish()
    }
}
impl Subscriber {
    /// Creates a connected subscriber / stream pair.
    ///
    /// `opts.capacity` is clamped to at least 1, because a bounded tokio channel
    /// cannot be created with zero capacity.
    pub(crate) fn new(opts: &StreamOptions) -> (Self, StatusStream) {
        let capacity = opts.capacity.max(1);
        let (tx, rx) = mpsc::channel(capacity);
        (
            Self {
                tx,
                pending_lagged: 0,
                capacity,
            },
            StatusStream::new(rx),
        )
    }

    /// Offers `state` to the subscriber without blocking.
    ///
    /// Returns `false` once the [`StatusStream`] has been dropped, so the publisher can
    /// discard this subscriber. Otherwise returns `true`, even when the sample had to be
    /// dropped because the buffer was full.
    ///
    /// Dropped samples are counted and reported as a single
    /// [`StatusStreamItem::Lagged`] marker, queued before the next delivered sample:
    ///
    /// * two or more free slots: the marker and this sample are both queued;
    /// * exactly one free slot: the marker takes it, and its count includes this
    ///   sample, so the next push can deliver a sample again;
    /// * no free slot: this sample is added to the pending count.
    ///
    /// Folding the current sample into the marker matters when the consumer drains at
    /// exactly the producer's rate. If the marker were queued and the sample then
    /// dropped, each push would leave a fresh pending count behind. From then on the
    /// stream would deliver nothing but `Lagged` markers.
    pub(crate) fn push(&mut self, state: &LiveState) -> bool {
        if self.tx.is_closed() {
            return false;
        }
        if self.pending_lagged == 0 {
            return match self.tx.try_send(StatusStreamItem::Sample(state.clone())) {
                Ok(()) => true,
                Err(mpsc::error::TrySendError::Full(_)) => {
                    self.pending_lagged = 1;
                    true
                }
                Err(mpsc::error::TrySendError::Closed(_)) => false,
            };
        }
        // A gap is outstanding: the marker must be queued before any further sample.
        let marker_slot = match self.tx.try_reserve() {
            Ok(permit) => permit,
            Err(mpsc::error::TrySendError::Full(_)) => {
                self.pending_lagged = self.pending_lagged.saturating_add(1);
                return true;
            }
            Err(mpsc::error::TrySendError::Closed(_)) => return false,
        };
        let dropped = self.pending_lagged;
        match self.tx.try_reserve() {
            Ok(sample_slot) => {
                marker_slot.send(StatusStreamItem::Lagged { dropped });
                sample_slot.send(StatusStreamItem::Sample(state.clone()));
            }
            Err(mpsc::error::TrySendError::Full(_)) => {
                // Only one free slot: spend it on the marker and count this sample in it.
                marker_slot.send(StatusStreamItem::Lagged {
                    dropped: dropped.saturating_add(1),
                });
            }
            // Dropping the unused `marker_slot` permit releases its slot.
            Err(mpsc::error::TrySendError::Closed(_)) => return false,
        }
        self.pending_lagged = 0;
        true
    }
}
#[cfg(test)]
mod tests {
    use std::time::Instant;
    use super::*;
    use crate::cia402::types::{Connection, Logic, Measurements};
    use crate::types::MotorMode;
    fn sample(seq: u8) -> LiveState {
        LiveState {
            connection: Connection {
                online: true,
                ..Default::default()
            },
            logic: Some(Logic::Enabled(MotorMode::ProfileVelocity)),
            measurements: Measurements {
                status_word: Some(0x0237),
                position_rev: Some(seq as f32 * 0.01),
                ..Default::default()
            },
            timestamp: Instant::now(),
        }
    }
    fn opts(capacity: usize) -> StreamOptions {
        StreamOptions {
            capacity,
            on_overflow: OverflowPolicy::Lagged,
        }
    }
    /// Panics unless `item` is the sample built by `sample(seq)`.
    #[track_caller]
    fn assert_sample(item: StatusStreamItem, seq: u8) {
        match item {
            StatusStreamItem::Sample(s) => {
                let pos = s.measurements.position_rev.expect("test samples carry a position");
                assert!(
                    (pos - seq as f32 * 0.01).abs() < 1e-6,
                    "expected sample {}, got position {}",
                    seq,
                    pos
                );
            }
            StatusStreamItem::Lagged { dropped } => {
                panic!("expected sample {}, got Lagged {{ dropped: {} }}", seq, dropped)
            }
        }
    }
    /// Panics unless `item` is a `Lagged` marker reporting `expected` drops.
    #[track_caller]
    fn assert_lagged(item: StatusStreamItem, expected: u64) {
        match item {
            StatusStreamItem::Lagged { dropped } => assert_eq!(dropped, expected),
            StatusStreamItem::Sample(_) => {
                panic!("expected Lagged {{ dropped: {} }}, got a sample", expected)
            }
        }
    }
    #[test]
    fn options_default_uses_lagged_policy() {
        let o = StreamOptions::default();
        assert_eq!(o.on_overflow, OverflowPolicy::Lagged);
        assert_eq!(o.capacity, DEFAULT_STREAM_CAPACITY);
    }
    #[tokio::test]
    async fn push_delivers_samples_in_order() {
        let (mut sub, mut stream) = Subscriber::new(&opts(4));
        for i in 0..4 {
            assert!(sub.push(&sample(i)));
        }
        for i in 0..4 {
            assert_sample(stream.recv().await.unwrap(), i);
        }
    }
    #[tokio::test]
    async fn push_overflow_emits_lagged_with_dropped_count() {
        let (mut sub, mut stream) = Subscriber::new(&opts(2));
        assert!(sub.push(&sample(0)));
        assert!(sub.push(&sample(1)));
        assert!(sub.push(&sample(2))); // buffer full: dropped
        assert!(sub.push(&sample(3))); // buffer full: dropped
        assert_sample(stream.recv().await.unwrap(), 0);
        // One free slot: the marker takes it and also accounts for sample 4.
        assert!(sub.push(&sample(4)));
        assert_sample(stream.recv().await.unwrap(), 1);
        assert_lagged(stream.recv().await.unwrap(), 3);
        // The gap has been fully reported, so the next sample goes straight through.
        assert!(sub.push(&sample(5)));
        assert_sample(stream.recv().await.unwrap(), 5);
        assert!(stream.try_recv().is_none());
    }
    #[tokio::test]
    async fn push_queues_marker_then_sample_when_room_for_both() {
        let (mut sub, mut stream) = Subscriber::new(&opts(2));
        assert!(sub.push(&sample(0)));
        assert!(sub.push(&sample(1)));
        assert!(sub.push(&sample(2))); // dropped
        assert_sample(stream.recv().await.unwrap(), 0);
        assert_sample(stream.recv().await.unwrap(), 1);
        assert!(sub.push(&sample(3)));
        assert_lagged(stream.recv().await.unwrap(), 1);
        assert_sample(stream.recv().await.unwrap(), 3);
    }
    #[tokio::test]
    async fn lagged_markers_do_not_starve_samples_at_matched_rate() {
        // Regression: after an overflow, a consumer that drained exactly one item per
        // push used to receive nothing but `Lagged` markers from then on.
        let (mut sub, mut stream) = Subscriber::new(&opts(2));
        assert!(sub.push(&sample(0)));
        assert!(sub.push(&sample(1)));
        let mut samples = 0;
        for i in 2..12 {
            assert!(sub.push(&sample(i)));
            if matches!(stream.recv().await.unwrap(), StatusStreamItem::Sample(_)) {
                samples += 1;
            }
        }
        // Ten receives: one is the single `Lagged` marker, the other nine are samples.
        assert_eq!(samples, 9);
    }
    #[tokio::test]
    async fn capacity_one_alternates_marker_and_sample() {
        let (mut sub, mut stream) = Subscriber::new(&opts(1));
        assert!(sub.push(&sample(0)));
        assert!(sub.push(&sample(1))); // dropped
        assert_sample(stream.recv().await.unwrap(), 0);
        assert!(sub.push(&sample(2))); // marker only; it covers samples 1 and 2
        assert_lagged(stream.recv().await.unwrap(), 2);
        assert!(sub.push(&sample(3)));
        assert_sample(stream.recv().await.unwrap(), 3);
    }
    #[tokio::test]
    async fn push_returns_false_when_subscriber_dropped() {
        let (mut sub, stream) = Subscriber::new(&opts(4));
        drop(stream);
        assert!(!sub.push(&sample(0)));
    }
}