use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use futures::Stream;
use futures::stream::unfold;
use tokio::sync::{broadcast, watch};
use crate::TimeSpec;
use crate::error::{MotorcortexError, Result};
use crate::msg::{DataType, GroupStatusMsg};
use crate::parameter_value::{
GetParameterTuple, GetParameterValue, decode_parameter_value,
};
/// Shared, thread-safe callback registered via `Subscription::notify`; `update` invokes it with
/// the subscription that just received a payload.
type Callback = Arc<dyn Fn(&Subscription) + Sync + Send + 'static>;
/// Error item yielded by a subscription stream when its consumer fell behind.
///
/// The wrapped value is the number of samples the consumer skipped, as reported by the
/// underlying broadcast channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Missed(pub u64);

impl std::fmt::Display for Missed {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Missed(count) = *self;
        write!(formatter, "stream consumer missed {count} samples")
    }
}

impl std::error::Error for Missed {}
struct GroupLayout {
description: GroupStatusMsg,
data_types: Vec<u32>,
}
impl GroupLayout {
fn from_group_msg(description: GroupStatusMsg) -> Self {
let data_types = description
.params
.iter()
.map(|p| {
DataType::try_from(p.info.data_type as i32)
.unwrap_or(DataType::Undefined) as u32
})
.collect();
Self {
description,
data_types,
}
}
}
/// State shared by every clone of a [`Subscription`].
struct SubscriptionInner {
    /// Current group id; replaced by `rebind`.
    id: AtomicU32,
    /// Group alias captured at construction; never changed afterwards.
    alias: String,
    /// `fdiv` value supplied at construction (presumably a frequency divider — confirm).
    fdiv: u32,
    /// Parameter layout used to decode payloads; replaced by `rebind`.
    layout: RwLock<GroupLayout>,
    /// Latest raw payload, published by `update` and read by `read`, `read_all` and `latest`.
    buffer: watch::Sender<Vec<u8>>,
    /// Callback registered through `notify` and invoked from `update`.
    callback: RwLock<Option<Callback>>,
    /// Broadcast channel feeding `stream`; created lazily on the first `stream` call.
    broadcast: Mutex<Option<broadcast::Sender<Vec<u8>>>>,
}
/// Handle to a group subscription.
///
/// Cloning a handle is cheap: every clone points at the same shared state, so a payload
/// delivered through one clone is visible through all of them.
pub struct Subscription {
    /// Shared state; see [`SubscriptionInner`].
    inner: Arc<SubscriptionInner>,
}
impl Subscription {
pub(crate) fn new(group_msg: GroupStatusMsg, fdiv: u32) -> Self {
let id = AtomicU32::new(group_msg.id);
let alias = group_msg.alias.clone();
let layout = RwLock::new(GroupLayout::from_group_msg(group_msg));
let (buffer, _) = watch::channel(Vec::new());
Self {
inner: Arc::new(SubscriptionInner {
id,
alias,
fdiv,
layout,
buffer,
callback: RwLock::new(None),
broadcast: Mutex::new(None),
}),
}
}
pub fn id(&self) -> u32 {
self.inner.id.load(Ordering::Acquire)
}
pub fn name(&self) -> &str {
&self.inner.alias
}
pub fn fdiv(&self) -> u32 {
self.inner.fdiv
}
pub fn paths(&self) -> Vec<String> {
self.inner
.layout
.read()
.unwrap()
.description
.params
.iter()
.map(|p| p.info.path.clone())
.collect()
}
pub(crate) fn rebind(&self, new_group: GroupStatusMsg) {
let new_id = new_group.id;
let new_layout = GroupLayout::from_group_msg(new_group);
{
let mut guard = self.inner.layout.write().unwrap();
*guard = new_layout;
}
self.inner.id.store(new_id, Ordering::Release);
}
pub fn notify<F>(&self, cb: F)
where
F: Fn(&Subscription) + Send + Sync + 'static,
{
*self.inner.callback.write().unwrap() = Some(Arc::new(cb));
}
pub fn read<V>(&self) -> Option<(TimeSpec, V)>
where
V: GetParameterTuple,
{
let rx = self.inner.buffer.subscribe();
let buffer = rx.borrow().clone();
let layout = self.inner.layout.read().unwrap();
decode_tuple::<V>(&layout, &buffer)
}
pub fn read_all<V>(&self) -> Option<(TimeSpec, Vec<V>)>
where
V: GetParameterValue + Default,
{
let rx = self.inner.buffer.subscribe();
let buffer = rx.borrow().clone();
let layout = self.inner.layout.read().unwrap();
decode_flat::<V>(&layout, &buffer)
}
pub async fn latest<V>(&self) -> Result<(TimeSpec, V)>
where
V: GetParameterTuple,
{
let mut rx = self.inner.buffer.subscribe();
loop {
let buffer = rx.borrow().clone();
if !buffer.is_empty() {
let layout = self.inner.layout.read().unwrap();
return decode_tuple::<V>(&layout, &buffer).ok_or_else(|| {
MotorcortexError::Decode(
"subscription payload used an unsupported protocol version".into(),
)
});
}
rx.changed().await.map_err(|_| {
MotorcortexError::Subscription(
"subscription watch channel closed before any payload arrived".into(),
)
})?;
}
}
pub fn stream<V>(&self, capacity: usize) -> impl Stream<Item = StreamResult<V>> + use<V>
where
V: GetParameterTuple + Send + 'static,
{
let sender = self.ensure_broadcast(capacity);
let rx = sender.subscribe();
let inner = Arc::clone(&self.inner);
unfold(rx, move |mut rx| {
let inner = Arc::clone(&inner);
async move {
loop {
match rx.recv().await {
Ok(buffer) => {
let decoded = {
let layout = inner.layout.read().unwrap();
decode_tuple::<V>(&layout, &buffer)
};
match decoded {
Some(decoded) => return Some((Ok(decoded), rx)),
None => continue,
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
return Some((Err(Missed(n)), rx));
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
}
})
}
pub(crate) fn update(&self, buffer: Vec<u8>) {
self.inner.buffer.send_replace(buffer.clone());
if let Some(tx) = self.inner.broadcast.lock().unwrap().as_ref() {
let _ = tx.send(buffer);
}
let cb = self.inner.callback.read().unwrap().clone();
if let Some(cb) = cb {
cb(self);
}
}
fn ensure_broadcast(&self, capacity: usize) -> broadcast::Sender<Vec<u8>> {
let mut guard = self.inner.broadcast.lock().unwrap();
guard
.get_or_insert_with(|| broadcast::channel(capacity).0)
.clone()
}
}
/// Item type of [`Subscription::stream`]: a decoded `(timestamp, value)` sample, or [`Missed`]
/// when the consumer fell behind.
pub type StreamResult<V> = core::result::Result<(TimeSpec, V), Missed>;
impl Clone for Subscription {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
/// Splits a raw subscription frame into its timestamp and the parameter bytes that follow it.
///
/// A frame is a 4-byte header whose last byte is the protocol version, followed by the
/// timestamp and then the concatenated parameter values. Returns `None` when the frame is
/// shorter than the header, uses a protocol version other than 1, or the timestamp or payload
/// cannot be extracted from it.
fn split_frame(buffer: &[u8]) -> Option<(TimeSpec, &[u8])> {
    const HEADER_LEN: usize = 4;
    const PROTOCOL_VERSION: u8 = 1;
    const TS_SIZE: usize = size_of::<TimeSpec>();

    // Length check first: a 1-3 byte frame must be rejected, not panic on `buffer[3]`.
    if buffer.len() < HEADER_LEN || buffer[HEADER_LEN - 1] != PROTOCOL_VERSION {
        return None;
    }
    let body = &buffer[HEADER_LEN..];
    let ts = TimeSpec::from_buffer(body)?;
    let payload = body.get(TS_SIZE..)?;
    Some((ts, payload))
}

/// Number of payload bytes `layout` expects (the sum of its parameter sizes), or `None` if
/// that sum overflows `usize`.
fn expected_payload_len(layout: &GroupLayout) -> Option<usize> {
    layout
        .description
        .params
        .iter()
        .try_fold(0usize, |total, param| total.checked_add(param.size as usize))
}

/// Decodes a frame into its timestamp and a tuple `V` of per-parameter values.
///
/// Each parameter of `layout` is handed to `V::get_parameters` as its data-type code and a
/// `param.size`-byte slice, in layout order. Returns `None` if the frame is rejected by
/// `split_frame`, the payload is shorter than the layout requires, or `V::get_parameters`
/// reports an error.
fn decode_tuple<V>(layout: &GroupLayout, buffer: &[u8]) -> Option<(TimeSpec, V)>
where
    V: GetParameterTuple,
{
    let (ts, payload) = split_frame(buffer)?;
    // Validate once up front so the slicing inside `scan` cannot go out of bounds.
    if payload.len() < expected_payload_len(layout)? {
        return None;
    }
    let iter = layout
        .description
        .params
        .iter()
        .zip(layout.data_types.iter())
        .scan(0usize, |cursor, (param, dt)| {
            let start = *cursor;
            *cursor += param.size as usize;
            let slice = &payload[start..*cursor];
            Some((dt, slice))
        });
    V::get_parameters(iter).ok().map(|v| (ts, v))
}

/// Decodes a frame into its timestamp and every element of every parameter, flattened into
/// one `Vec<V>` in layout order.
///
/// Each parameter's `param.size` bytes are split into `number_of_elements` chunks of
/// `data_size` bytes, and each chunk is decoded with `decode_parameter_value`. Returns `None`
/// if the frame is rejected by `split_frame`, a parameter's bytes are missing from the
/// payload, or a parameter's declared elements do not fit in its `size`.
fn decode_flat<V>(layout: &GroupLayout, buffer: &[u8]) -> Option<(TimeSpec, Vec<V>)>
where
    V: GetParameterValue + Default,
{
    let (ts, payload) = split_frame(buffer)?;
    let mut values = Vec::new();
    let mut rest = payload;
    for (param, &data_type) in layout.description.params.iter().zip(&layout.data_types) {
        let size = param.size as usize;
        if rest.len() < size {
            return None;
        }
        let (bytes, tail) = rest.split_at(size);
        rest = tail;

        let data_size = param.info.data_size as usize;
        let count = param.info.number_of_elements as usize;
        // Reject, rather than panic on, layouts whose elements overrun `param.size`.
        let elements = bytes.get(..data_size.checked_mul(count)?)?;
        for i in 0..count {
            let start = i * data_size;
            values.push(decode_parameter_value::<V>(
                data_type,
                &elements[start..start + data_size],
            ));
        }
    }
    Some((ts, values))
}
// Unit tests for `Subscription` and the payload decoders.
#[cfg(test)]
mod tests {
use super::*;
use crate::msg::{GroupParameterInfo, ParameterInfo, ParameterType, StatusCode};
use std::sync::Mutex;
// Builds a parameter entry whose `size` is `data_size * n_elements`. Index, offset, id, flags,
// permissions, group_id and unit are zero; status is `StatusCode::Ok`.
fn param(path: &str, dtype: DataType, data_size: u32, n_elements: u32) -> GroupParameterInfo {
GroupParameterInfo {
index: 0,
offset: 0,
size: data_size * n_elements,
info: ParameterInfo {
id: 0,
data_type: dtype as u32,
data_size,
number_of_elements: n_elements,
flags: 0,
permissions: 0,
param_type: ParameterType::Parameter as i32,
group_id: 0,
unit: 0,
path: path.to_string(),
},
status: StatusCode::Ok as i32,
}
}
// Builds a group message with the given id, alias and parameters, no header, status Ok.
fn group(id: u32, alias: &str, params: Vec<GroupParameterInfo>) -> GroupStatusMsg {
GroupStatusMsg {
header: None,
id,
alias: alias.to_string(),
params,
status: StatusCode::Ok as i32,
}
}
// Wraps `body` in a version-1 frame: a 4-byte header with byte 3 set to 1, then 16 zero bytes
// in the timestamp position, then `body`.
// NOTE(review): assumes size_of::<TimeSpec>() == 16 and that all-zero bytes form a valid
// TimeSpec — TODO confirm.
fn protocol1(body: &[u8]) -> Vec<u8> {
let mut buf = vec![0u8, 0, 0, 1];
buf.extend_from_slice(&[0u8; 16]); buf.extend_from_slice(body);
buf
}
// The id and alias come straight from the group message.
#[test]
fn id_and_name_reflect_group_msg() {
let sub = Subscription::new(group(7, "grp", vec![]), 1);
assert_eq!(sub.id(), 7);
assert_eq!(sub.name(), "grp");
}
// Clones share one `Arc<SubscriptionInner>` rather than copying it.
#[test]
fn clone_is_shared() {
let sub = Subscription::new(group(1, "g", vec![]), 1);
let clone = sub.clone();
assert!(Arc::ptr_eq(&sub.inner, &clone.inner));
}
// Before any `update`, the buffer is empty and both readers return None.
#[test]
fn read_returns_none_without_a_payload() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
assert!(sub.read::<f64>().is_none());
assert!(sub.read_all::<f64>().is_none());
}
// A single little-endian f64 round-trips through `update` and `read`.
#[test]
fn read_decodes_a_single_scalar_payload() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
sub.update(protocol1(&2.5_f64.to_le_bytes()));
let (_ts, value) = sub.read::<f64>().expect("decode ok");
assert_eq!(value, 2.5);
}
// A 3-element f64 array parameter is flattened element by element by `read_all`.
#[test]
fn read_all_decodes_flattened_array() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 3)]), 1);
let mut body = Vec::new();
body.extend_from_slice(&1.0f64.to_le_bytes());
body.extend_from_slice(&2.0f64.to_le_bytes());
body.extend_from_slice(&3.0f64.to_le_bytes());
sub.update(protocol1(&body));
let (_ts, values) = sub.read_all::<f64>().expect("decode");
assert_eq!(values, vec![1.0, 2.0, 3.0]);
}
// The registered callback runs once per `update` call.
#[test]
fn update_fires_the_callback() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
let hits = Arc::new(Mutex::new(0u32));
let counter = Arc::clone(&hits);
sub.notify(move |_| {
*counter.lock().unwrap() += 1;
});
sub.update(protocol1(&0f64.to_le_bytes()));
sub.update(protocol1(&0f64.to_le_bytes()));
assert_eq!(*hits.lock().unwrap(), 2);
}
// Header byte 3 = 0 is not protocol version 1, so both decoders reject the frame.
#[test]
fn non_protocol_1_returns_none() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
let mut buf = vec![0u8, 0, 0, 0]; buf.extend_from_slice(&[0u8; 24]);
sub.update(buf);
assert!(sub.read::<f64>().is_none());
assert!(sub.read_all::<f64>().is_none());
}
// With a payload already in the buffer, `latest` decodes it without waiting.
#[tokio::test]
async fn latest_resolves_immediately_when_payload_already_present() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
sub.update(protocol1(&7.5f64.to_le_bytes()));
let (_ts, v) = sub.latest::<f64>().await.expect("decode ok");
assert_eq!(v, 7.5);
}
// `latest` blocks until another task pushes the first payload.
#[tokio::test]
async fn latest_waits_for_the_first_payload() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
let sub_for_push = sub.clone();
// Push the first payload from another task after a 50 ms delay.
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
sub_for_push.update(protocol1(&42.0f64.to_le_bytes()));
});
let (_ts, v) = sub.latest::<f64>().await.expect("decode ok");
assert_eq!(v, 42.0);
}
// The stream is created before the updates, so its receiver observes all three in order.
#[tokio::test]
async fn stream_delivers_consecutive_payloads() {
use futures::StreamExt;
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
let mut stream = Box::pin(sub.stream::<f64>(16));
sub.update(protocol1(&1.0f64.to_le_bytes()));
sub.update(protocol1(&2.0f64.to_le_bytes()));
sub.update(protocol1(&3.0f64.to_le_bytes()));
for expected in [1.0, 2.0, 3.0f64] {
let item = tokio::time::timeout(std::time::Duration::from_millis(100), stream.next())
.await
.expect("stream must yield within 100 ms")
.expect("stream is not exhausted");
let (_ts, v) = item.expect("not lagged");
assert_eq!(v, expected);
}
}
// A lagging consumer observes `Err(Missed(n))` with a positive `n`.
#[tokio::test]
async fn stream_surfaces_lag_as_err() {
use futures::StreamExt;
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
let mut stream = Box::pin(sub.stream::<f64>(2));
// Eight updates into a capacity-2 channel before the first poll force the receiver to lag.
for i in 0..8 {
sub.update(protocol1(&(i as f64).to_le_bytes()));
}
let mut saw_miss = false;
for _ in 0..8 {
let item = tokio::time::timeout(std::time::Duration::from_millis(100), stream.next())
.await
.expect("stream yields")
.expect("not exhausted");
if let Err(Missed(n)) = item {
assert!(n > 0, "Missed's inner count must be positive");
saw_miss = true;
break;
}
}
assert!(saw_miss, "expected to observe at least one Missed item");
}
// `update` alone must not create the broadcast channel; only `stream` does.
#[tokio::test]
async fn stream_is_not_created_unless_requested() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
sub.update(protocol1(&1.0f64.to_le_bytes()));
assert!(sub.inner.broadcast.lock().unwrap().is_none());
}
// `Missed` compares by value, has the expected Display text and coerces to `dyn Error`.
#[test]
fn missed_formats_and_is_error() {
let m = Missed(7);
assert_eq!(m, Missed(7));
assert_eq!(format!("{m}"), "stream consumer missed 7 samples");
let _: &dyn std::error::Error = &m;
}
// A non-empty payload with protocol version 7 makes `latest` return a Decode error.
#[tokio::test]
async fn latest_errors_on_unsupported_protocol_version() {
let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
let mut buf = vec![0u8, 0, 0, 7]; buf.extend_from_slice(&[0u8; 24]);
sub.update(buf);
let err = sub
.latest::<f64>()
.await
.expect_err("unsupported protocol must error");
assert!(matches!(err, MotorcortexError::Decode(_)));
}
// `fdiv` and the parameter paths (in order) come from the constructor arguments.
#[test]
fn fdiv_and_paths_reflect_the_constructor_args() {
let params = vec![
param("root/a", DataType::Double, 8, 1),
param("root/b", DataType::Int32, 4, 1),
];
let sub = Subscription::new(group(1, "g", params), 7);
assert_eq!(sub.fdiv(), 7);
assert_eq!(sub.paths(), vec!["root/a".to_string(), "root/b".to_string()]);
}
// `rebind` replaces the id and layout but keeps the alias and fdiv.
#[test]
fn rebind_swaps_id_and_layout() {
let sub = Subscription::new(
group(11, "grp", vec![param("root/a", DataType::Double, 8, 1)]),
1,
);
assert_eq!(sub.id(), 11);
assert_eq!(sub.paths(), vec!["root/a".to_string()]);
let new_group = group(
42,
"grp",
vec![
param("root/x", DataType::Double, 8, 1),
param("root/y", DataType::Int32, 4, 1),
],
);
sub.rebind(new_group);
assert_eq!(sub.id(), 42);
assert_eq!(sub.paths(), vec!["root/x".to_string(), "root/y".to_string()]);
assert_eq!(sub.name(), "grp");
assert_eq!(sub.fdiv(), 1);
}
// A rebind through one handle is visible through clones created before it.
#[test]
fn rebind_is_visible_to_outstanding_clones() {
let sub = Subscription::new(
group(1, "g", vec![param("root/a", DataType::Double, 8, 1)]),
1,
);
let clone = sub.clone();
let new_group = group(99, "g", vec![param("root/b", DataType::Double, 8, 1)]);
sub.rebind(new_group);
assert_eq!(clone.id(), 99);
assert_eq!(clone.paths(), vec!["root/b".to_string()]);
}
}