use std::time::Duration;
use crate::{
error::AUTDInternalError,
firmware::cpu::{check_if_msg_is_processed, RxMessage, TxDatagram},
geometry::Geometry,
};
/// Trait definitions used when the `async-trait` feature is enabled.
///
/// Async methods are written as `async fn` and expanded by
/// `#[async_trait::async_trait]`. This variant also provides
/// `impl Link for Box<dyn Link>`.
#[cfg(feature = "async-trait")]
mod internal {
    use super::*;

    /// Transport that sends `TxDatagram`s and receives `RxMessage`s.
    #[async_trait::async_trait]
    pub trait Link: Send + Sync {
        /// Closes the link.
        async fn close(&mut self) -> Result<(), AUTDInternalError>;
        /// Sends `tx` over the link.
        ///
        /// `send_receive` treats `Ok(false)` as `AUTDInternalError::SendDataFailed`.
        async fn send(&mut self, tx: &TxDatagram) -> Result<bool, AUTDInternalError>;
        /// Receives responses into `rx`.
        ///
        /// `wait_msg_processed` only inspects `rx` when this returns `Ok(true)`.
        async fn receive(&mut self, rx: &mut [RxMessage]) -> Result<bool, AUTDInternalError>;
        /// Returns whether the link is open.
        #[must_use]
        fn is_open(&self) -> bool;
        /// Default timeout that `send_receive` uses when its `timeout` argument is `None`.
        #[must_use]
        fn timeout(&self) -> Duration;
        /// Hook called by `send_receive` before each send, with the outgoing
        /// datagram, the response buffer and the caller-supplied timeout.
        ///
        /// The default implementation does nothing.
        #[inline(always)]
        fn trace(&mut self, _: &TxDatagram, _: &mut [RxMessage], _: Option<Duration>) {}
    }

    /// Builder that opens a `Link`.
    #[async_trait::async_trait]
    pub trait LinkBuilder: Send + Sync {
        /// Concrete link type produced by `open`.
        type L: Link;
        /// Consumes the builder and opens the link for the given `geometry`.
        async fn open(self, geometry: &Geometry) -> Result<Self::L, AUTDInternalError>;
    }

    /// Forwards every method to the boxed link, so a type-erased
    /// `Box<dyn Link>` can be passed wherever an `impl Link` is expected
    /// (e.g. `send_receive`).
    #[async_trait::async_trait]
    impl Link for Box<dyn Link> {
        async fn close(&mut self) -> Result<(), AUTDInternalError> {
            self.as_mut().close().await
        }
        async fn send(&mut self, tx: &TxDatagram) -> Result<bool, AUTDInternalError> {
            self.as_mut().send(tx).await
        }
        async fn receive(&mut self, rx: &mut [RxMessage]) -> Result<bool, AUTDInternalError> {
            self.as_mut().receive(rx).await
        }
        fn is_open(&self) -> bool {
            self.as_ref().is_open()
        }
        fn timeout(&self) -> Duration {
            self.as_ref().timeout()
        }
        // Forward explicitly so the inner link's `trace` override (if any) runs,
        // rather than this impl's default no-op.
        #[inline(always)]
        fn trace(&mut self, tx: &TxDatagram, rx: &mut [RxMessage], timeout: Option<Duration>) {
            self.as_mut().trace(tx, rx, timeout)
        }
    }
}
/// Trait definitions used when the `async-trait` feature is disabled.
///
/// Async methods are declared as ordinary methods returning `impl Future`;
/// implementations may still be written with `async fn`. The returned futures
/// carry no `Send` bound. No `impl Link for Box<dyn Link>` is provided here
/// (NOTE(review): presumably because return-position `impl Trait` in traits
/// prevents `dyn Link` — confirm if a boxed link is ever needed).
#[cfg(not(feature = "async-trait"))]
mod internal {
    use super::*;

    /// Transport that sends `TxDatagram`s and receives `RxMessage`s.
    pub trait Link: Send + Sync {
        /// Closes the link.
        fn close(&mut self) -> impl std::future::Future<Output = Result<(), AUTDInternalError>>;
        /// Sends `tx` over the link.
        ///
        /// `send_receive` treats `Ok(false)` as `AUTDInternalError::SendDataFailed`.
        fn send(
            &mut self,
            tx: &TxDatagram,
        ) -> impl std::future::Future<Output = Result<bool, AUTDInternalError>>;
        /// Receives responses into `rx`.
        ///
        /// `wait_msg_processed` only inspects `rx` when this resolves to `Ok(true)`.
        fn receive(
            &mut self,
            rx: &mut [RxMessage],
        ) -> impl std::future::Future<Output = Result<bool, AUTDInternalError>>;
        /// Returns whether the link is open.
        #[must_use]
        fn is_open(&self) -> bool;
        /// Default timeout that `send_receive` uses when its `timeout` argument is `None`.
        #[must_use]
        fn timeout(&self) -> Duration;
        /// Hook called by `send_receive` before each send, with the outgoing
        /// datagram, the response buffer and the caller-supplied timeout.
        ///
        /// The default implementation does nothing.
        #[inline(always)]
        fn trace(&mut self, _: &TxDatagram, _: &mut [RxMessage], _: Option<Duration>) {}
    }

    /// Builder that opens a `Link`.
    ///
    /// Unlike the `async-trait` variant, this trait has no `Send + Sync` supertraits.
    pub trait LinkBuilder {
        /// Concrete link type produced by `open`.
        type L: Link;
        /// Consumes the builder and opens the link for the given `geometry`.
        fn open(
            self,
            geometry: &Geometry,
        ) -> impl std::future::Future<Output = Result<Self::L, AUTDInternalError>>;
    }
}
#[cfg(feature = "async-trait")]
pub use internal::Link;
#[cfg(feature = "async-trait")]
pub use internal::LinkBuilder;
#[cfg(not(feature = "async-trait"))]
pub use internal::Link;
#[cfg(not(feature = "async-trait"))]
pub use internal::LinkBuilder;
/// Sends `tx` over `link` and waits for the responses in `rx`.
///
/// `Link::trace` is invoked first with the caller-supplied `timeout`. When
/// `timeout` is `None`, the link's own `Link::timeout` is used for the wait.
///
/// # Errors
///
/// - Errors returned by `Link::send` are propagated.
/// - If `Link::send` reports `Ok(false)`, returns
///   `AUTDInternalError::SendDataFailed` without waiting.
/// - Otherwise, returns whatever `wait_msg_processed` yields.
#[tracing::instrument(skip(link, tx, rx, timeout))]
pub async fn send_receive(
    link: &mut impl Link,
    tx: &TxDatagram,
    rx: &mut [RxMessage],
    timeout: Option<Duration>,
) -> Result<(), AUTDInternalError> {
    link.trace(tx, rx, timeout);

    let link_default = link.timeout();
    let effective_timeout = timeout.unwrap_or(link_default);
    tracing::debug!("send with timeout: {:?}", effective_timeout);

    if link.send(tx).await? {
        wait_msg_processed(link, tx, rx, effective_timeout).await
    } else {
        Err(AUTDInternalError::SendDataFailed)
    }
}
/// Polls `link` until `check_if_msg_is_processed` reports every entry as
/// processed, or until `timeout` has elapsed.
///
/// Each attempt calls `Link::receive`. `rx` is only checked when `receive`
/// returns `Ok(true)`. Between attempts the task sleeps for 1 ms.
///
/// A `timeout` of `Duration::ZERO` performs exactly one receive attempt and
/// then returns without sleeping.
///
/// # Errors
///
/// - Any error from `Link::receive` is propagated immediately.
/// - Once the deadline passes, the first error produced by converting an entry
///   of `rx` into `Result<(), AUTDInternalError>` is returned.
/// - Otherwise, `AUTDInternalError::ConfirmResponseFailed` is returned, unless
///   `timeout` is zero, in which case the unconfirmed send counts as success.
async fn wait_msg_processed(
    link: &mut impl Link,
    tx: &TxDatagram,
    rx: &mut [RxMessage],
    timeout: Duration,
) -> Result<(), AUTDInternalError> {
    let start = std::time::Instant::now();
    loop {
        if link.receive(rx).await? && check_if_msg_is_processed(tx, rx).all(std::convert::identity)
        {
            return Ok(());
        }
        // `>=` (not `>`) guarantees a zero timeout never sleeps and retries,
        // even when `elapsed()` reads as exactly zero on a coarse clock.
        if start.elapsed() >= timeout {
            break;
        }
        tokio::time::sleep(Duration::from_millis(1)).await;
    }
    // Report an error carried in the responses before reporting a timeout.
    for r in rx.iter() {
        Result::<(), AUTDInternalError>::from(r)?;
    }
    if timeout == Duration::ZERO {
        Ok(())
    } else {
        Err(AUTDInternalError::ConfirmResponseFailed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scriptable in-memory `Link` for exercising `send_receive` and
    /// `wait_msg_processed`.
    ///
    /// * `send` and `receive` fail with `LinkClosed` while `is_open` is false.
    /// * `receive` fails with `LinkError("too many")` once `recv_cnt` exceeds 10.
    ///   Otherwise it increments `recv_cnt` and rewrites every entry of `rx` as
    ///   `RxMessage::new(recv_cnt as u8, r.data())`.
    /// * While `down` is set, both `send` and `receive` report `Ok(false)`.
    struct MockLink {
        pub is_open: bool,
        pub timeout: Duration,
        pub send_cnt: usize,
        pub recv_cnt: usize,
        pub down: bool,
    }

    impl MockLink {
        /// An open, healthy link with a zero default timeout and cleared counters.
        fn new() -> Self {
            Self {
                is_open: true,
                timeout: Duration::from_millis(0),
                send_cnt: 0,
                recv_cnt: 0,
                down: false,
            }
        }
    }

    // With the `async-trait` feature, `Link` is declared through
    // `#[async_trait::async_trait]`, so implementations must carry the same
    // attribute. Without it, this impl compiles only when the feature is off.
    #[cfg_attr(feature = "async-trait", async_trait::async_trait)]
    impl Link for MockLink {
        async fn close(&mut self) -> Result<(), AUTDInternalError> {
            self.is_open = false;
            Ok(())
        }

        async fn send(&mut self, _: &TxDatagram) -> Result<bool, AUTDInternalError> {
            if !self.is_open {
                return Err(AUTDInternalError::LinkClosed);
            }
            self.send_cnt += 1;
            Ok(!self.down)
        }

        async fn receive(&mut self, rx: &mut [RxMessage]) -> Result<bool, AUTDInternalError> {
            if !self.is_open {
                return Err(AUTDInternalError::LinkClosed);
            }
            if self.recv_cnt > 10 {
                return Err(AUTDInternalError::LinkError("too many".to_owned()));
            }
            self.recv_cnt += 1;
            rx.iter_mut()
                .for_each(|r| *r = RxMessage::new(self.recv_cnt as u8, r.data()));
            Ok(!self.down)
        }

        fn is_open(&self) -> bool {
            self.is_open
        }

        fn timeout(&self) -> Duration {
            self.timeout
        }
    }

    #[tokio::test]
    async fn test_close() {
        let mut link = MockLink::new();
        assert!(link.is_open());
        link.close().await.unwrap();
        assert!(!link.is_open());
    }

    #[tokio::test]
    async fn test_send_receive() {
        let mut link = MockLink::new();
        let tx = TxDatagram::new(0);
        let mut rx = Vec::new();
        assert_eq!(send_receive(&mut link, &tx, &mut rx, None).await, Ok(()));

        link.is_open = false;
        assert_eq!(
            send_receive(&mut link, &tx, &mut rx, None).await,
            Err(AUTDInternalError::LinkClosed)
        );

        link.is_open = true;
        link.down = true;
        assert_eq!(
            send_receive(&mut link, &tx, &mut rx, None).await,
            Err(AUTDInternalError::SendDataFailed)
        );

        link.down = false;
        assert_eq!(
            send_receive(&mut link, &tx, &mut rx, Some(Duration::from_millis(1))).await,
            Ok(())
        );
    }

    #[tokio::test]
    async fn test_wait_msg_processed() {
        let mut link = MockLink::new();
        let mut tx = TxDatagram::new(1);
        tx[0].header.msg_id = 2;
        let mut rx = vec![RxMessage::new(0, 0)];
        assert_eq!(
            wait_msg_processed(&mut link, &tx, &mut rx, Duration::from_millis(10)).await,
            Ok(())
        );

        link.recv_cnt = 0;
        link.is_open = false;
        assert_eq!(
            wait_msg_processed(&mut link, &tx, &mut rx, Duration::from_millis(10)).await,
            Err(AUTDInternalError::LinkClosed)
        );

        link.recv_cnt = 0;
        link.is_open = true;
        link.down = true;
        assert_eq!(
            wait_msg_processed(&mut link, &tx, &mut rx, Duration::from_millis(10)).await,
            Err(AUTDInternalError::ConfirmResponseFailed)
        );

        // A zero timeout never reports a confirmation failure.
        link.recv_cnt = 0;
        link.is_open = true;
        link.down = true;
        assert_eq!(
            wait_msg_processed(&mut link, &tx, &mut rx, Duration::ZERO).await,
            Ok(())
        );

        // msg_id 20 is never reached, so the mock's receive limit trips first.
        link.down = false;
        link.recv_cnt = 0;
        tx[0].header.msg_id = 20;
        assert_eq!(
            wait_msg_processed(&mut link, &tx, &mut rx, Duration::from_secs(10)).await,
            Err(AUTDInternalError::LinkError("too many".to_owned()))
        );
    }
}