1use std::sync::atomic::{AtomicU32, Ordering};
17use std::sync::{Arc, Mutex, RwLock};
18
19use futures::Stream;
20use futures::stream::unfold;
21use tokio::sync::{broadcast, watch};
22
23use crate::TimeSpec;
24use crate::error::{MotorcortexError, Result};
25use crate::msg::{DataType, GroupStatusMsg};
26use crate::parameter_value::{
27 GetParameterTuple, GetParameterValue, decode_parameter_value,
28};
29
/// Callback registered through `Subscription::notify` and invoked with the
/// subscription after each `update`. Kept behind an `Arc` so `update` can clone it
/// out of its lock and call it without holding that lock.
type Callback = Arc<dyn Fn(&Subscription) + Send + Sync + 'static>;
31
/// Error item yielded by `Subscription::stream` when the consumer fell behind and
/// the bounded broadcast buffer dropped samples it had not received yet.
///
/// The wrapped value is the number of skipped samples, as reported by the
/// broadcast receiver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Missed(pub u64);

impl std::error::Error for Missed {}

impl std::fmt::Display for Missed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Missed(count) = *self;
        write!(f, "stream consumer missed {count} samples")
    }
}
47
48struct GroupLayout {
52 description: GroupStatusMsg,
53 data_types: Vec<u32>,
54}
55
56impl GroupLayout {
57 fn from_group_msg(description: GroupStatusMsg) -> Self {
58 let data_types = description
59 .params
60 .iter()
61 .map(|p| {
62 DataType::try_from(p.info.data_type as i32)
63 .unwrap_or(DataType::Undefined) as u32
64 })
65 .collect();
66 Self {
67 description,
68 data_types,
69 }
70 }
71}
72
/// State shared by every clone of a `Subscription`.
struct SubscriptionInner {
    /// Group id taken from the `GroupStatusMsg`; replaced by `rebind` (release
    /// store) and read by `id()` (acquire load).
    id: AtomicU32,
    /// Group alias captured at construction; `rebind` does not change it.
    alias: String,
    /// The `fdiv` value passed to `new`, returned verbatim by `fdiv()`.
    /// NOTE(review): presumably a frequency divider — confirm against the server API.
    fdiv: u32,
    /// Current decoding layout; swapped wholesale by `rebind`.
    layout: RwLock<GroupLayout>,
    /// Latest raw payload. Starts out empty; `update` replaces it via `send_replace`.
    buffer: watch::Sender<Vec<u8>>,
    /// Optional user callback set by `notify` and invoked at the end of `update`.
    callback: RwLock<Option<Callback>>,
    /// Fan-out sender feeding `stream`; `None` until the first `stream` call creates it.
    broadcast: Mutex<Option<broadcast::Sender<Vec<u8>>>>,
}
97
/// Handle to a group subscription.
///
/// All state (id, layout, latest payload, callback, stream channel) lives behind a
/// shared `Arc`, so clones of a handle observe the same `update`s and `rebind`s.
pub struct Subscription {
    /// Shared, reference-counted subscription state.
    inner: Arc<SubscriptionInner>,
}
103
104impl Subscription {
105 pub(crate) fn new(group_msg: GroupStatusMsg, fdiv: u32) -> Self {
108 let id = AtomicU32::new(group_msg.id);
109 let alias = group_msg.alias.clone();
110 let layout = RwLock::new(GroupLayout::from_group_msg(group_msg));
111 let (buffer, _) = watch::channel(Vec::new());
112 Self {
113 inner: Arc::new(SubscriptionInner {
114 id,
115 alias,
116 fdiv,
117 layout,
118 buffer,
119 callback: RwLock::new(None),
120 broadcast: Mutex::new(None),
121 }),
122 }
123 }
124
125 pub fn id(&self) -> u32 {
130 self.inner.id.load(Ordering::Acquire)
131 }
132
133 pub fn name(&self) -> &str {
136 &self.inner.alias
137 }
138
139 pub fn fdiv(&self) -> u32 {
141 self.inner.fdiv
142 }
143
144 pub fn paths(&self) -> Vec<String> {
149 self.inner
150 .layout
151 .read()
152 .unwrap()
153 .description
154 .params
155 .iter()
156 .map(|p| p.info.path.clone())
157 .collect()
158 }
159
160 pub(crate) fn rebind(&self, new_group: GroupStatusMsg) {
165 let new_id = new_group.id;
166 let new_layout = GroupLayout::from_group_msg(new_group);
167 {
172 let mut guard = self.inner.layout.write().unwrap();
173 *guard = new_layout;
174 }
175 self.inner.id.store(new_id, Ordering::Release);
176 }
177
178 pub fn notify<F>(&self, cb: F)
192 where
193 F: Fn(&Subscription) + Send + Sync + 'static,
194 {
195 *self.inner.callback.write().unwrap() = Some(Arc::new(cb));
196 }
197
198 pub fn read<V>(&self) -> Option<(TimeSpec, V)>
202 where
203 V: GetParameterTuple,
204 {
205 let rx = self.inner.buffer.subscribe();
206 let buffer = rx.borrow().clone();
207 let layout = self.inner.layout.read().unwrap();
208 decode_tuple::<V>(&layout, &buffer)
209 }
210
211 pub fn read_all<V>(&self) -> Option<(TimeSpec, Vec<V>)>
214 where
215 V: GetParameterValue + Default,
216 {
217 let rx = self.inner.buffer.subscribe();
218 let buffer = rx.borrow().clone();
219 let layout = self.inner.layout.read().unwrap();
220 decode_flat::<V>(&layout, &buffer)
221 }
222
223 pub async fn latest<V>(&self) -> Result<(TimeSpec, V)>
238 where
239 V: GetParameterTuple,
240 {
241 let mut rx = self.inner.buffer.subscribe();
242 loop {
243 let buffer = rx.borrow().clone();
246 if !buffer.is_empty() {
247 let layout = self.inner.layout.read().unwrap();
248 return decode_tuple::<V>(&layout, &buffer).ok_or_else(|| {
249 MotorcortexError::Decode(
250 "subscription payload used an unsupported protocol version".into(),
251 )
252 });
253 }
254 rx.changed().await.map_err(|_| {
256 MotorcortexError::Subscription(
257 "subscription watch channel closed before any payload arrived".into(),
258 )
259 })?;
260 }
261 }
262
263 pub fn stream<V>(&self, capacity: usize) -> impl Stream<Item = StreamResult<V>> + use<V>
296 where
297 V: GetParameterTuple + Send + 'static,
298 {
299 let sender = self.ensure_broadcast(capacity);
300 let rx = sender.subscribe();
301 let inner = Arc::clone(&self.inner);
302 unfold(rx, move |mut rx| {
303 let inner = Arc::clone(&inner);
304 async move {
305 loop {
306 match rx.recv().await {
307 Ok(buffer) => {
308 let decoded = {
309 let layout = inner.layout.read().unwrap();
310 decode_tuple::<V>(&layout, &buffer)
311 };
312 match decoded {
313 Some(decoded) => return Some((Ok(decoded), rx)),
314 None => continue,
319 }
320 }
321 Err(broadcast::error::RecvError::Lagged(n)) => {
322 return Some((Err(Missed(n)), rx));
323 }
324 Err(broadcast::error::RecvError::Closed) => return None,
325 }
326 }
327 }
328 })
329 }
330
331 pub(crate) fn update(&self, buffer: Vec<u8>) {
335 self.inner.buffer.send_replace(buffer.clone());
342 if let Some(tx) = self.inner.broadcast.lock().unwrap().as_ref() {
344 let _ = tx.send(buffer);
347 }
348 let cb = self.inner.callback.read().unwrap().clone();
351 if let Some(cb) = cb {
352 cb(self);
353 }
354 }
355
356 fn ensure_broadcast(&self, capacity: usize) -> broadcast::Sender<Vec<u8>> {
357 let mut guard = self.inner.broadcast.lock().unwrap();
358 guard
359 .get_or_insert_with(|| broadcast::channel(capacity).0)
360 .clone()
361 }
362}
363
/// Item type of `Subscription::stream`: a decoded `(timestamp, value)` sample, or
/// `Missed` when the consumer lagged and samples were dropped.
pub type StreamResult<V> = std::result::Result<(TimeSpec, V), Missed>;
366
367impl Clone for Subscription {
368 fn clone(&self) -> Self {
369 Self {
370 inner: Arc::clone(&self.inner),
371 }
372 }
373}
374
375fn decode_tuple<V>(layout: &GroupLayout, buffer: &[u8]) -> Option<(TimeSpec, V)>
376where
377 V: GetParameterTuple,
378{
379 if buffer.is_empty() {
380 return None;
381 }
382 const HEADER_LEN: usize = 4;
383 let protocol_version = buffer[3];
384 if protocol_version != 1 {
385 return None;
386 }
387 let body = &buffer[HEADER_LEN..];
388 let ts = TimeSpec::from_buffer(body)?;
389 const TS_SIZE: usize = size_of::<TimeSpec>();
390 let payload = &body[TS_SIZE..];
391 let iter = layout
392 .description
393 .params
394 .iter()
395 .zip(layout.data_types.iter())
396 .scan(0usize, |cursor, (param, dt)| {
397 let size = param.size as usize;
398 let slice = &payload[*cursor..*cursor + size];
399 *cursor += size;
400 Some((dt, slice))
401 });
402 V::get_parameters(iter).ok().map(|v| (ts, v))
403}
404
405fn decode_flat<V>(layout: &GroupLayout, buffer: &[u8]) -> Option<(TimeSpec, Vec<V>)>
406where
407 V: GetParameterValue + Default,
408{
409 if buffer.is_empty() {
410 return None;
411 }
412 const HEADER_LEN: usize = 4;
413 let protocol_version = buffer[3];
414 if protocol_version != 1 {
415 return None;
416 }
417 let body = &buffer[HEADER_LEN..];
418 let ts = TimeSpec::from_buffer(body)?;
419 const TS_SIZE: usize = size_of::<TimeSpec>();
420 let payload = &body[TS_SIZE..];
421
422 let mut values = Vec::new();
423 let mut cursor = 0usize;
424 for (param, &data_type) in layout.description.params.iter().zip(layout.data_types.iter()) {
425 let size = param.size as usize;
426 let data_size = param.info.data_size as usize;
427 let n = param.info.number_of_elements as usize;
428 let bytes = &payload[cursor..cursor + size];
429 for i in 0..n {
430 let start = i * data_size;
431 let end = start + data_size;
432 values.push(decode_parameter_value::<V>(data_type, &bytes[start..end]));
433 }
434 cursor += size;
435 }
436 Some((ts, values))
437}
438
// Unit tests for `Subscription`. They cover the construction accessors, decoding via
// `read`/`read_all`/`latest`, callbacks, the broadcast-backed stream, and `rebind`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msg::{GroupParameterInfo, ParameterInfo, ParameterType, StatusCode};
    use std::sync::Mutex;

    /// Builds a group parameter descriptor at `path` whose byte `size` is
    /// `data_size * n_elements`; ids, offsets and flags are zeroed.
    fn param(path: &str, dtype: DataType, data_size: u32, n_elements: u32) -> GroupParameterInfo {
        GroupParameterInfo {
            index: 0,
            offset: 0,
            size: data_size * n_elements,
            info: ParameterInfo {
                id: 0,
                data_type: dtype as u32,
                data_size,
                number_of_elements: n_elements,
                flags: 0,
                permissions: 0,
                param_type: ParameterType::Parameter as i32,
                group_id: 0,
                unit: 0,
                path: path.to_string(),
            },
            status: StatusCode::Ok as i32,
        }
    }

    /// Builds a `GroupStatusMsg` with the given id, alias and parameters.
    fn group(id: u32, alias: &str, params: Vec<GroupParameterInfo>) -> GroupStatusMsg {
        GroupStatusMsg {
            header: None,
            id,
            alias: alias.to_string(),
            params,
            status: StatusCode::Ok as i32,
        }
    }

    /// Wraps `body` in a protocol-1 frame. The frame is a 4-byte header whose last
    /// byte is the version (1), then 16 zero bytes standing in for the timestamp,
    /// then `body`.
    fn protocol1(body: &[u8]) -> Vec<u8> {
        let mut buf = vec![0u8, 0, 0, 1];
        // NOTE(review): assumes size_of::<TimeSpec>() == 16 — confirm.
        buf.extend_from_slice(&[0u8; 16]);
        buf.extend_from_slice(body);
        buf
    }

    #[test]
    fn id_and_name_reflect_group_msg() {
        let sub = Subscription::new(group(7, "grp", vec![]), 1);
        assert_eq!(sub.id(), 7);
        assert_eq!(sub.name(), "grp");
    }

    #[test]
    fn clone_is_shared() {
        let sub = Subscription::new(group(1, "g", vec![]), 1);
        let clone = sub.clone();
        // Clones must point at the very same shared state.
        assert!(Arc::ptr_eq(&sub.inner, &clone.inner));
    }

    #[test]
    fn read_returns_none_without_a_payload() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        assert!(sub.read::<f64>().is_none());
        assert!(sub.read_all::<f64>().is_none());
    }

    #[test]
    fn read_decodes_a_single_scalar_payload() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        sub.update(protocol1(&2.5_f64.to_le_bytes()));
        let (_ts, value) = sub.read::<f64>().expect("decode ok");
        assert_eq!(value, 2.5);
    }

    #[test]
    fn read_all_decodes_flattened_array() {
        // One parameter holding a 3-element array of doubles.
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 3)]), 1);
        let mut body = Vec::new();
        body.extend_from_slice(&1.0f64.to_le_bytes());
        body.extend_from_slice(&2.0f64.to_le_bytes());
        body.extend_from_slice(&3.0f64.to_le_bytes());
        sub.update(protocol1(&body));
        let (_ts, values) = sub.read_all::<f64>().expect("decode");
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn update_fires_the_callback() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        let hits = Arc::new(Mutex::new(0u32));
        let counter = Arc::clone(&hits);
        sub.notify(move |_| {
            *counter.lock().unwrap() += 1;
        });
        sub.update(protocol1(&0f64.to_le_bytes()));
        sub.update(protocol1(&0f64.to_le_bytes()));
        // One callback invocation per update.
        assert_eq!(*hits.lock().unwrap(), 2);
    }

    #[test]
    fn non_protocol_1_returns_none() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        // Header version byte is 0 (not 1), followed by 24 zero bytes.
        let mut buf = vec![0u8, 0, 0, 0];
        buf.extend_from_slice(&[0u8; 24]);
        sub.update(buf);
        assert!(sub.read::<f64>().is_none());
        assert!(sub.read_all::<f64>().is_none());
    }

    #[tokio::test]
    async fn latest_resolves_immediately_when_payload_already_present() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        sub.update(protocol1(&7.5f64.to_le_bytes()));
        let (_ts, v) = sub.latest::<f64>().await.expect("decode ok");
        assert_eq!(v, 7.5);
    }

    #[tokio::test]
    async fn latest_waits_for_the_first_payload() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        let sub_for_push = sub.clone();
        tokio::spawn(async move {
            // Delay the first payload so `latest` has to wait for it.
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            sub_for_push.update(protocol1(&42.0f64.to_le_bytes()));
        });
        let (_ts, v) = sub.latest::<f64>().await.expect("decode ok");
        assert_eq!(v, 42.0);
    }

    #[tokio::test]
    async fn stream_delivers_consecutive_payloads() {
        use futures::StreamExt;

        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        let mut stream = Box::pin(sub.stream::<f64>(16));

        // The stream subscribes eagerly, so updates made before polling are queued.
        sub.update(protocol1(&1.0f64.to_le_bytes()));
        sub.update(protocol1(&2.0f64.to_le_bytes()));
        sub.update(protocol1(&3.0f64.to_le_bytes()));

        for expected in [1.0, 2.0, 3.0f64] {
            let item = tokio::time::timeout(std::time::Duration::from_millis(100), stream.next())
                .await
                .expect("stream must yield within 100 ms")
                .expect("stream is not exhausted");
            let (_ts, v) = item.expect("not lagged");
            assert_eq!(v, expected);
        }
    }

    #[tokio::test]
    async fn stream_surfaces_lag_as_err() {
        use futures::StreamExt;

        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        // Capacity 2 with 8 unread updates forces the receiver to lag.
        let mut stream = Box::pin(sub.stream::<f64>(2));

        for i in 0..8 {
            sub.update(protocol1(&(i as f64).to_le_bytes()));
        }

        let mut saw_miss = false;
        // Drain items until the first `Missed` shows up.
        for _ in 0..8 {
            let item = tokio::time::timeout(std::time::Duration::from_millis(100), stream.next())
                .await
                .expect("stream yields")
                .expect("not exhausted");
            if let Err(Missed(n)) = item {
                assert!(n > 0, "Missed's inner count must be positive");
                saw_miss = true;
                break;
            }
        }
        assert!(saw_miss, "expected to observe at least one Missed item");
    }

    #[tokio::test]
    async fn stream_is_not_created_unless_requested() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        // Updates alone must not allocate the broadcast channel.
        sub.update(protocol1(&1.0f64.to_le_bytes()));
        assert!(sub.inner.broadcast.lock().unwrap().is_none());
    }

    #[test]
    fn missed_formats_and_is_error() {
        let m = Missed(7);
        assert_eq!(m, Missed(7));
        assert_eq!(format!("{m}"), "stream consumer missed 7 samples");
        let _: &dyn std::error::Error = &m;
    }

    #[tokio::test]
    async fn latest_errors_on_unsupported_protocol_version() {
        let sub = Subscription::new(group(1, "g", vec![param("x", DataType::Double, 8, 1)]), 1);
        // Header version byte 7 is not supported.
        let mut buf = vec![0u8, 0, 0, 7];
        buf.extend_from_slice(&[0u8; 24]);
        sub.update(buf);
        let err = sub
            .latest::<f64>()
            .await
            .expect_err("unsupported protocol must error");
        assert!(matches!(err, MotorcortexError::Decode(_)));
    }

    #[test]
    fn fdiv_and_paths_reflect_the_constructor_args() {
        let params = vec![
            param("root/a", DataType::Double, 8, 1),
            param("root/b", DataType::Int32, 4, 1),
        ];
        let sub = Subscription::new(group(1, "g", params), 7);
        assert_eq!(sub.fdiv(), 7);
        assert_eq!(sub.paths(), vec!["root/a".to_string(), "root/b".to_string()]);
    }

    #[test]
    fn rebind_swaps_id_and_layout() {
        let sub = Subscription::new(
            group(11, "grp", vec![param("root/a", DataType::Double, 8, 1)]),
            1,
        );
        assert_eq!(sub.id(), 11);
        assert_eq!(sub.paths(), vec!["root/a".to_string()]);

        let new_group = group(
            42,
            "grp",
            vec![
                param("root/x", DataType::Double, 8, 1),
                param("root/y", DataType::Int32, 4, 1),
            ],
        );
        sub.rebind(new_group);
        assert_eq!(sub.id(), 42);
        assert_eq!(sub.paths(), vec!["root/x".to_string(), "root/y".to_string()]);
        // Name and fdiv are fixed at construction and survive a rebind.
        assert_eq!(sub.name(), "grp");
        assert_eq!(sub.fdiv(), 1);
    }

    #[test]
    fn rebind_is_visible_to_outstanding_clones() {
        let sub = Subscription::new(
            group(1, "g", vec![param("root/a", DataType::Double, 8, 1)]),
            1,
        );
        let clone = sub.clone();
        let new_group = group(99, "g", vec![param("root/b", DataType::Double, 8, 1)]);
        sub.rebind(new_group);
        // The clone shares state, so it sees the new id and layout.
        assert_eq!(clone.id(), 99);
        assert_eq!(clone.paths(), vec!["root/b".to_string()]);
    }
}