use crate::command_types::Command;
use crate::device_address::DeviceAddress;
use moteus_protocol::CanFdFrame;
use std::sync::{Arc, Mutex};
/// Shared, thread-safe buffer of CAN-FD frames routed to a [`Request`].
///
/// Cloning a `ResponseCollector` is cheap and yields a handle to the *same*
/// underlying buffer (the frames live behind an `Arc<Mutex<..>>`), so one clone can
/// push frames while another reads them.
///
/// If a thread panics while holding the lock, the mutex becomes poisoned. Every
/// method here recovers the inner buffer and keeps operating on it instead of
/// silently dropping pushed frames or reporting an empty buffer. Frames that were
/// already buffered are still valid after such a panic.
#[derive(Clone, Default)]
pub struct ResponseCollector(Arc<Mutex<Vec<CanFdFrame>>>);

impl ResponseCollector {
    /// Creates an empty collector.
    pub fn new() -> Self {
        Self(Arc::new(Mutex::new(Vec::new())))
    }

    /// Locks the shared buffer, recovering it if a previous holder panicked.
    fn guard(&self) -> std::sync::MutexGuard<'_, Vec<CanFdFrame>> {
        self.0
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Appends `frame` to the end of the buffer.
    pub fn push(&self, frame: CanFdFrame) {
        self.guard().push(frame);
    }

    /// Removes and returns every buffered frame in push order, leaving the buffer empty.
    pub fn take(&self) -> Vec<CanFdFrame> {
        std::mem::take(&mut *self.guard())
    }

    /// Returns the number of buffered frames.
    pub fn len(&self) -> usize {
        self.guard().len()
    }

    /// Returns `true` if no frames are buffered.
    pub fn is_empty(&self) -> bool {
        self.guard().is_empty()
    }

    /// Returns a copy of the buffered frames without removing them.
    pub fn peek(&self) -> Vec<CanFdFrame> {
        self.guard().clone()
    }

    /// Discards every buffered frame.
    pub fn clear(&self) {
        self.guard().clear();
    }
}
impl std::fmt::Debug for ResponseCollector {
    /// Prints only the number of buffered frames (`ResponseCollector { len: N }`),
    /// never the frames themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ResponseCollector");
        builder.field("len", &self.len());
        builder.finish()
    }
}
/// Predicate deciding which received CAN-FD frames belong to a [`Request`].
///
/// The default filter is [`FrameFilter::Any`].
#[non_exhaustive]
#[derive(Default, Clone)]
pub enum FrameFilter {
    /// Accept frames whose source-id field (bits 8..=14 of the arbitration id)
    /// equals the given value.
    BySource(u8),
    /// Accept every frame.
    #[default]
    Any,
    /// Accept frames for which the shared predicate returns `true`; cloning the
    /// filter shares the same predicate.
    Custom(Arc<dyn Fn(&CanFdFrame) -> bool + Send + Sync + 'static>),
}
impl FrameFilter {
pub fn by_source(source_id: u8) -> Self {
FrameFilter::BySource(source_id)
}
pub fn any() -> Self {
FrameFilter::Any
}
pub fn custom<F>(f: F) -> Self
where
F: Fn(&CanFdFrame) -> bool + Send + Sync + 'static,
{
FrameFilter::Custom(Arc::new(f))
}
pub fn matches(&self, frame: &CanFdFrame) -> bool {
match self {
FrameFilter::BySource(source) => {
let frame_source = ((frame.arbitration_id >> 8) & 0x7F) as u8;
frame_source == *source
}
FrameFilter::Any => true,
FrameFilter::Custom(f) => f(frame),
}
}
}
impl std::fmt::Debug for FrameFilter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FrameFilter::BySource(id) => write!(f, "BySource({})", id),
FrameFilter::Any => write!(f, "Any"),
FrameFilter::Custom(_) => write!(f, "Custom(...)"),
}
}
}
/// An outbound frame (or a pure listening slot) together with the bookkeeping
/// used to route and collect the frames received in reply to it.
#[non_exhaustive]
#[derive(Clone)]
pub struct Request {
    /// Frame to transmit; `None` for a receive-only request.
    pub frame: Option<CanFdFrame>,
    /// Optional transport channel index (settable via `with_channel`).
    pub channel: Option<usize>,
    /// Predicate selecting which received frames belong to this request.
    pub filter: FrameFilter,
    /// Number of replies this request expects (0 when no reply was requested).
    pub expected_reply_count: u8,
    /// Expected reply size copied from the originating command; 0 for requests
    /// not built from a command.
    pub expected_reply_size: u8,
    /// Optional child-device index; the constructors leave it `None`.
    pub child_device: Option<usize>,
    /// Device address copied from the originating command, if any.
    pub address: Option<DeviceAddress>,
    /// Buffer receiving frames routed to this request; clones of the request share it.
    pub responses: ResponseCollector,
}
impl Request {
pub fn new(frame: CanFdFrame) -> Self {
let dest = (frame.arbitration_id & 0x7F) as u8;
let filter = if dest == 0x7F {
FrameFilter::Any
} else {
FrameFilter::BySource(dest)
};
let expected = if frame.arbitration_id & 0x8000 != 0 {
1
} else {
0
};
Self {
frame: Some(frame),
channel: None,
filter,
expected_reply_count: expected,
expected_reply_size: 0,
child_device: None,
address: None,
responses: ResponseCollector::new(),
}
}
pub fn from_command(cmd: Command) -> Self {
let channel = cmd.channel;
let address = cmd.address.clone();
let expected_reply_size = cmd.expected_reply_size;
if !cmd.reply_required {
let frame = cmd.into_frame();
return Self {
frame: Some(frame),
channel,
filter: FrameFilter::Any,
expected_reply_count: 0,
expected_reply_size,
child_device: None,
address,
responses: ResponseCollector::new(),
};
}
if cmd.raw {
let frame = cmd.into_frame();
return Self {
frame: Some(frame),
channel,
filter: FrameFilter::Any,
expected_reply_count: 1,
expected_reply_size,
child_device: None,
address,
responses: ResponseCollector::new(),
};
}
let dest_id = (cmd.destination as u8) & 0x7F;
let source_id = (cmd.source as u8) & 0x7F;
let prefix = cmd.can_prefix & 0x1FFF;
let reply_filter = cmd.reply_filter.clone();
let frame = cmd.into_frame();
let filter = FrameFilter::custom(move |f| {
if ((f.arbitration_id >> 16) & 0x1FFF) as u16 != prefix {
return false;
}
if dest_id != 0x7F {
let frame_source = ((f.arbitration_id >> 8) & 0x7F) as u8;
if frame_source != dest_id {
return false;
}
}
let frame_dest = (f.arbitration_id & 0x7F) as u8;
if frame_dest != source_id {
return false;
}
if let Some(ref rf) = reply_filter {
if !rf.matches(f) {
return false;
}
}
true
});
Self {
frame: Some(frame),
channel,
filter,
expected_reply_count: 1,
expected_reply_size,
child_device: None,
address,
responses: ResponseCollector::new(),
}
}
pub fn receive_only(filter: FrameFilter) -> Self {
Self {
frame: None,
channel: None,
filter,
expected_reply_count: 1,
expected_reply_size: 0,
child_device: None,
address: None,
responses: ResponseCollector::new(),
}
}
#[must_use]
pub fn with_channel(mut self, idx: usize) -> Self {
self.channel = Some(idx);
self
}
#[must_use]
pub fn with_filter(mut self, f: FrameFilter) -> Self {
self.filter = f;
self
}
#[must_use]
pub fn with_expected_replies(mut self, n: u8) -> Self {
self.expected_reply_count = n;
self
}
pub fn has_frame(&self) -> bool {
self.frame.is_some()
}
pub fn total_expected_replies(requests: &[Request]) -> usize {
requests
.iter()
.map(|r| r.expected_reply_count as usize)
.sum()
}
}
impl std::fmt::Debug for Request {
    /// Summarises the request without printing the frame payload.
    ///
    /// The output shows whether a frame is present, the channel, address and filter,
    /// the expected reply count, and the response collector.
    /// `expected_reply_size` and `child_device` are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Request");
        out.field("has_frame", &self.has_frame());
        out.field("channel", &self.channel);
        out.field("address", &self.address);
        out.field("filter", &self.filter);
        out.field("expected_replies", &self.expected_reply_count);
        out.field("responses", &self.responses);
        out.finish()
    }
}
/// Routes a received `frame` to one of `requests` and reports whether any request claimed it.
///
/// Among requests whose filter accepts the frame, the first one that has collected
/// fewer responses than its `expected_reply_count` receives a clone of it.
///
/// Preferring requests that are still waiting has two effects:
/// - An earlier request that expects no replies, or is already satisfied, cannot
///   swallow a reply that a later request is waiting for. An example is a request
///   for a command that needs no reply (zero expected replies, `Any` filter).
/// - Replies are spread across several requests that share a filter.
///
/// If every matching request is already satisfied, the frame goes to the first
/// matching one, so no frame that some filter accepts is dropped.
///
/// Returns `true` iff at least one request's filter accepted the frame.
pub fn dispatch_frame(frame: &CanFdFrame, requests: &[Request]) -> bool {
    let mut first_satisfied: Option<&Request> = None;
    for req in requests {
        if !req.filter.matches(frame) {
            continue;
        }
        if req.responses.len() < usize::from(req.expected_reply_count) {
            req.responses.push(frame.clone());
            return true;
        }
        // Remember the first matching-but-full request as a fallback destination.
        if first_satisfied.is_none() {
            first_satisfied = Some(req);
        }
    }
    match first_satisfied {
        Some(req) => {
            req.responses.push(frame.clone());
            true
        }
        None => false,
    }
}
// Unit tests for `ResponseCollector`, `FrameFilter`, `Request` construction and
// `dispatch_frame`.
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created collector holds no frames.
    #[test]
    fn test_response_collector_new() {
        let collector = ResponseCollector::new();
        assert!(collector.is_empty());
        assert_eq!(collector.len(), 0);
    }

    // `take` returns frames in push order and leaves the collector empty.
    #[test]
    fn test_response_collector_push_and_take() {
        let collector = ResponseCollector::new();
        let mut frame1 = CanFdFrame::new();
        frame1.arbitration_id = 0x100;
        let mut frame2 = CanFdFrame::new();
        frame2.arbitration_id = 0x200;
        collector.push(frame1);
        collector.push(frame2);
        assert_eq!(collector.len(), 2);
        assert!(!collector.is_empty());
        let frames = collector.take();
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].arbitration_id, 0x100);
        assert_eq!(frames[1].arbitration_id, 0x200);
        assert!(collector.is_empty());
    }

    // Clones share one buffer: a push through either handle is visible through both.
    #[test]
    fn test_response_collector_clone() {
        let collector = ResponseCollector::new();
        let mut frame = CanFdFrame::new();
        frame.arbitration_id = 0x100;
        collector.push(frame);
        let clone = collector.clone();
        assert_eq!(clone.len(), 1);
        let mut frame2 = CanFdFrame::new();
        frame2.arbitration_id = 0x200;
        clone.push(frame2);
        assert_eq!(collector.len(), 2);
        assert_eq!(clone.len(), 2);
    }

    // `peek` copies the frames out without draining the collector.
    #[test]
    fn test_response_collector_peek() {
        let collector = ResponseCollector::new();
        let mut frame = CanFdFrame::new();
        frame.arbitration_id = 0x100;
        collector.push(frame);
        let frames = collector.peek();
        assert_eq!(frames.len(), 1);
        assert_eq!(collector.len(), 1);
    }

    // `clear` discards every buffered frame.
    #[test]
    fn test_response_collector_clear() {
        let collector = ResponseCollector::new();
        collector.push(CanFdFrame::new());
        collector.push(CanFdFrame::new());
        assert_eq!(collector.len(), 2);
        collector.clear();
        assert!(collector.is_empty());
    }

    // `BySource` reads the source id from bits 8..=14 of the arbitration id,
    // so 0x0500 has source 5 and 0x0600 has source 6.
    #[test]
    fn test_frame_filter_by_source() {
        let filter = FrameFilter::BySource(5);
        let mut matching = CanFdFrame::new();
        matching.arbitration_id = 0x0500;
        let mut non_matching = CanFdFrame::new();
        non_matching.arbitration_id = 0x0600;
        assert!(filter.matches(&matching));
        assert!(!filter.matches(&non_matching));
    }

    // `Any` accepts every frame.
    #[test]
    fn test_frame_filter_any() {
        let filter = FrameFilter::Any;
        let mut frame1 = CanFdFrame::new();
        frame1.arbitration_id = 0x100;
        let mut frame2 = CanFdFrame::new();
        frame2.arbitration_id = 0x200;
        assert!(filter.matches(&frame1));
        assert!(filter.matches(&frame2));
    }

    // `Custom` defers to the closure; here it accepts only even arbitration ids.
    #[test]
    fn test_frame_filter_custom() {
        let filter = FrameFilter::custom(|f| f.arbitration_id % 2 == 0);
        let mut even = CanFdFrame::new();
        even.arbitration_id = 0x100;
        let mut odd = CanFdFrame::new();
        odd.arbitration_id = 0x101;
        assert!(filter.matches(&even));
        assert!(!filter.matches(&odd));
    }

    // A frame with the reply bit set expects one reply. Its filter accepts a reply
    // whose source id (0x0500 -> 5) equals the frame's destination.
    // NOTE(review): assumes `calculate_arbitration_id` takes
    // (source, destination, prefix, reply_required) — confirm against moteus_protocol.
    #[test]
    fn test_request_new() {
        let mut frame = CanFdFrame::new();
        frame.arbitration_id = moteus_protocol::calculate_arbitration_id(0, 5, 0, true);
        let request = Request::new(frame);
        assert!(request.has_frame());
        assert!(request.channel.is_none());
        assert_eq!(request.expected_reply_count, 1);
        assert!(request.responses.is_empty());
        let mut response = CanFdFrame::new();
        response.arbitration_id = 0x0500;
        assert!(request.filter.matches(&response));
    }

    // Receive-only requests carry no frame but still expect one reply.
    #[test]
    fn test_request_receive_only() {
        let request = Request::receive_only(FrameFilter::BySource(3));
        assert!(!request.has_frame());
        assert_eq!(request.expected_reply_count, 1);
    }

    // The `with_*` builders override channel, filter and expected reply count.
    #[test]
    fn test_request_builder_pattern() {
        let mut frame = CanFdFrame::new();
        frame.arbitration_id = moteus_protocol::calculate_arbitration_id(0, 1, 0, true);
        let request = Request::new(frame)
            .with_channel(2)
            .with_filter(FrameFilter::Any)
            .with_expected_replies(3);
        assert_eq!(request.channel, Some(2));
        assert_eq!(request.expected_reply_count, 3);
        assert!(matches!(request.filter, FrameFilter::Any));
    }

    // `total_expected_replies` sums `expected_reply_count` over the slice (2 + 3 + 1).
    #[test]
    fn test_total_expected_replies() {
        let requests = vec![
            Request::receive_only(FrameFilter::Any).with_expected_replies(2),
            Request::receive_only(FrameFilter::Any).with_expected_replies(3),
            Request::receive_only(FrameFilter::Any).with_expected_replies(1),
        ];
        assert_eq!(Request::total_expected_replies(&requests), 6);
    }

    // Frames are routed by source id: 0x0100 goes to request 0 and 0x0200 to request 1.
    // 0x0300 (source 3) matches neither, so `dispatch_frame` reports it unclaimed.
    #[test]
    fn test_dispatch_frame() {
        let req1 = Request::receive_only(FrameFilter::BySource(1));
        let req2 = Request::receive_only(FrameFilter::BySource(2));
        let requests = vec![req1.clone(), req2.clone()];
        let mut frame1 = CanFdFrame::new();
        frame1.arbitration_id = 0x0100;
        let mut frame2 = CanFdFrame::new();
        frame2.arbitration_id = 0x0200;
        let mut frame3 = CanFdFrame::new();
        frame3.arbitration_id = 0x0300;
        assert!(dispatch_frame(&frame1, &requests));
        assert!(dispatch_frame(&frame2, &requests));
        assert!(!dispatch_frame(&frame3, &requests));
        assert_eq!(requests[0].responses.len(), 1);
        assert_eq!(requests[1].responses.len(), 1);
    }

    // Debug output for each filter variant; custom closures print as a placeholder.
    #[test]
    fn test_frame_filter_debug() {
        let f1 = FrameFilter::BySource(5);
        let f2 = FrameFilter::Any;
        let f3 = FrameFilter::custom(|_| true);
        assert_eq!(format!("{:?}", f1), "BySource(5)");
        assert_eq!(format!("{:?}", f2), "Any");
        assert_eq!(format!("{:?}", f3), "Custom(...)");
    }

    // Without the reply bit, a request expects zero replies.
    #[test]
    fn test_request_no_reply_expected() {
        let mut frame = CanFdFrame::new();
        frame.arbitration_id = moteus_protocol::calculate_arbitration_id(0, 1, 0, false);
        let request = Request::new(frame);
        assert_eq!(request.expected_reply_count, 0);
    }

    // Pushes from a spawned thread (through a cloned handle) and from this thread
    // are all retained.
    #[test]
    fn test_response_collector_thread_safety() {
        use std::thread;
        let collector = ResponseCollector::new();
        let collector_clone = collector.clone();
        let handle = thread::spawn(move || {
            for i in 0..100 {
                let mut frame = CanFdFrame::new();
                frame.arbitration_id = i;
                collector_clone.push(frame);
            }
        });
        for i in 100..200 {
            let mut frame = CanFdFrame::new();
            frame.arbitration_id = i;
            collector.push(frame);
        }
        handle.join().unwrap();
        assert_eq!(collector.len(), 200);
    }
}