use core::{
cmp::Ordering,
fmt::Debug,
marker::PhantomData,
pin::Pin,
task::{
Context,
Poll,
},
};
#[allow(unused_imports)]
use log::info;
use futures::{
prelude::*,
ready,
};
use serde::{
de::DeserializeOwned,
Deserialize,
Serialize,
};
use pin_project::pin_project;
use embrio_core::io::{
Read,
Write,
};
use async_codec::{
Framed,
ReadFrameError,
WriteFrameError,
};
use crate::{
codec::PostCardCodec,
net::{
Msg,
Net,
},
};
/// Side of a chain node: towards the left neighbour or towards the right one.
///
/// NOTE(review): not referenced elsewhere in this file; presumably used by setup code
/// elsewhere — confirm. Keep the variant order stable: compact serde formats such as
/// postcard encode enum variants by index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetupDir {
    /// The left-hand side.
    Left,
    /// The right-hand side.
    Right,
}
/// One direction (reader or writer half) of a neighbour link, framed with the
/// postcard codec for `Msg<T>` and a `[u8; 32]` buffer type parameter.
type PostCardLink<IO, T> = Framed<IO, PostCardCodec<Msg<T>>, [u8; 32]>;

/// A [`Net`] node in a linear chain, connected to a left and a right neighbour.
///
/// Outgoing traffic (own sends and forwarded messages) is staged in `current`, one
/// message at a time, before being written to the appropriate link(s).
#[pin_project(project = ChainNetProj)]
pub struct ChainNet<T, LW, LR, RW, RR> {
    /// Frames arriving over the right link.
    #[pin]
    rx_r: PostCardLink<RR, T>,
    /// Frames leaving over the right link.
    #[pin]
    tx_r: PostCardLink<RW, T>,
    /// Frames arriving over the left link.
    #[pin]
    rx_l: PostCardLink<LR, T>,
    /// Frames leaving over the left link.
    #[pin]
    tx_l: PostCardLink<LW, T>,
    /// This node's address: `0` after construction, overwritten by each received broadcast.
    id: i8,
    /// The single outgoing message waiting to be written to one or both links.
    current: Option<Msg<T>>,
    _ph: PhantomData<fn() -> T>,
}
/// Error yielded by a [`ChainNet`], tagged with the side of the I/O that failed.
///
/// `WE` is the writers' I/O error type and `RE` the readers' I/O error type; each is
/// carried inside the corresponding `async_codec` frame error alongside
/// `postcard::Error`.
pub enum ChainError<WE, RE> {
    /// A frame could not be written to a neighbour link.
    Write(WriteFrameError<WE, postcard::Error>),
    /// A frame could not be read from a neighbour link.
    Read(ReadFrameError<RE, postcard::Error>),
}
impl<T, LW, LR, RW, RR> ChainNet<T, LW, LR, RW, RR> {
pub fn new(left: (LW, LR), right: (RW, RR)) -> Self {
Self {
tx_l: PostCardCodec::framed(left.0),
rx_l: PostCardCodec::framed(left.1),
tx_r: PostCardCodec::framed(right.0),
rx_r: PostCardCodec::framed(right.1),
id: 0,
current: None,
_ph: Default::default(),
}
}
}
impl<'a, T, LW, LR, RW, RR> ChainNetProj<'a, T, LW, LR, RW, RR>
where
    T: Serialize + Clone,
    LW: Write,
    RW: Write<Error = LW::Error>,
{
    /// Hands the buffered outgoing message (`current`), if any, to the link(s) it must
    /// travel on.
    ///
    /// Routing:
    /// * `Broadcast`: compares `from` with `to`. If `from > to` the message goes left,
    ///   if `from < to` it goes right, and if they are equal it goes both ways.
    /// * `Direct`: compares this node's `id` with `to`, using the same
    ///   greater/less/equal mapping.
    ///
    /// Before sending, broadcast copies have `to` shifted by `-1` (left) or `+1`
    /// (right) via [`adjust_broadcast`]. Direct messages are sent unchanged.
    ///
    /// Returns `Ready(Ok(()))` immediately when nothing is buffered. It returns
    /// `Pending`, leaving the message buffered, until every required link is ready.
    ///
    /// # Errors
    /// Propagates `poll_ready` / `start_send` errors from the links. A `start_send`
    /// failure happens after the message was taken out of `current`, so that copy is
    /// dropped.
    fn poll_empty_buffer(
        &mut self,
        cx: &mut Context,
    ) -> Poll<Result<(), WriteFrameError<LW::Error, postcard::Error>>> {
        let id = *self.id;
        let direction = match self.current.as_ref() {
            None => return Poll::Ready(Ok(())),
            Some(Msg::Broadcast { to, from, .. }) => from.cmp(to),
            Some(Msg::Direct { to, .. }) => id.cmp(to),
        };
        let send_left = direction != Ordering::Less;
        let send_right = direction != Ordering::Greater;

        // Wait for every required link before taking the message, so that a `Pending`
        // here leaves it buffered for the next poll.
        if send_left {
            ready!(self.tx_l.as_mut().poll_ready(cx))?;
        }
        if send_right {
            ready!(self.tx_r.as_mut().poll_ready(cx))?;
        }

        let msg = self
            .current
            .take()
            .expect("`current` was Some when the direction was computed");
        // Clone only when the message has to go out on both links.
        let (left, right) = match direction {
            Ordering::Greater => (Some(msg), None),
            Ordering::Less => (None, Some(msg)),
            Ordering::Equal => (Some(msg.clone()), Some(msg)),
        };
        if let Some(mut msg) = left {
            adjust_broadcast(&mut msg, -1);
            self.tx_l.as_mut().start_send(msg)?;
        }
        if let Some(mut msg) = right {
            adjust_broadcast(&mut msg, 1);
            self.tx_r.as_mut().start_send(msg)?;
        }
        Poll::Ready(Ok(()))
    }
}
/// Shifts the `to` field of a `Broadcast` message by `adj`; `Direct` messages are left
/// untouched.
///
/// Uses plain `i8` addition, so it panics on overflow in debug builds.
fn adjust_broadcast<T>(msg: &mut Msg<T>, adj: i8) {
    if let Msg::Broadcast { to, .. } = msg {
        *to += adj;
    }
}
impl<'a, T, LW, LR, RW, RR> ChainNetProj<'a, T, LW, LR, RW, RR>
where
    T: DeserializeOwned,
    LR: Read,
    RR: Read<Error = LR::Error>,
{
    /// Polls the two incoming links for the next decoded frame, left link first.
    ///
    /// Returns the first frame (or read error) that either side yields. When the left
    /// link yields an item, the right link is not polled in that call.
    ///
    /// NOTE(review): a continuously busy left link can therefore delay the right one.
    ///
    /// Resolves to `Ready(None)` only when both links report end-of-stream in the same
    /// call. Otherwise it returns `Pending`.
    // The nested Poll/Option/Result return type trips clippy's type_complexity lint.
    #[allow(clippy::type_complexity)]
    fn poll_readers(
        &mut self,
        cx: &mut Context,
    ) -> Poll<Option<Result<Msg<T>, ReadFrameError<LR::Error, postcard::Error>>>> {
        let left_finished = match self.rx_l.as_mut().try_poll_next(cx) {
            Poll::Ready(Some(frame)) => return Poll::Ready(Some(frame)),
            Poll::Ready(None) => true,
            Poll::Pending => false,
        };
        match self.rx_r.as_mut().try_poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(frame)) => Poll::Ready(Some(frame)),
            Poll::Ready(None) if left_finished => Poll::Ready(None),
            Poll::Ready(None) => Poll::Pending,
        }
    }
}
impl<T, LW, LR, RW, RR> Sink<Msg<T>> for ChainNet<T, LW, LR, RW, RR>
where
    LW: Write,
    RW: Write<Error = LW::Error>,
    T: Serialize + Clone,
{
    type Error = WriteFrameError<LW::Error, postcard::Error>;

    /// Ready once any previously accepted message has been handed to its link(s).
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().poll_empty_buffer(cx)
    }

    /// Hands off the buffered message, then flushes the left and right writers.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut this = self.project();
        ready!(this.poll_empty_buffer(cx))?;
        ready!(this.tx_l.as_mut().poll_flush(cx))?;
        ready!(this.tx_r.as_mut().poll_flush(cx))?;
        Poll::Ready(Ok(()))
    }

    /// Stamps `item` with this node's address as sender and buffers it for sending.
    ///
    /// # Panics
    /// Panics if a message is still buffered, i.e. when not preceded by a successful
    /// `poll_ready`.
    fn start_send(self: Pin<&mut Self>, mut item: Msg<T>) -> Result<(), Self::Error> {
        let this = self.project();
        assert!(this.current.is_none());
        item.set_from(*this.id);
        *this.current = Some(item);
        Ok(())
    }

    /// Sends any buffered message, then closes both writers.
    ///
    /// Previously this returned `Ready(Ok(()))` immediately. That dropped a message
    /// accepted by `start_send` but not yet written, and never closed the links.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut this = self.project();
        ready!(this.poll_empty_buffer(cx))?;
        // Poll both links on every call so each can make progress independently
        // (the same approach `futures`' `Fanout` sink uses).
        // NOTE(review): this assumes closing an already-closed `Framed` link is harmless —
        // confirm against the async_codec implementation.
        let left_closed = this.tx_l.as_mut().poll_close(cx)?.is_ready();
        let right_closed = this.tx_r.as_mut().poll_close(cx)?.is_ready();
        if left_closed && right_closed {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}
impl<T, LW, LR, RW, RR> Stream for ChainNet<T, LW, LR, RW, RR>
where
    T: Serialize + DeserializeOwned + Clone + Debug,
    LW: Write,
    LR: Read,
    RW: Write<Error = LW::Error>,
    RR: Read<Error = LR::Error>,
{
    type Item = Result<Msg<T>, ChainError<LW::Error, LR::Error>>;

    /// Yields every received broadcast and every direct message addressed to this node.
    ///
    /// Forwarding along the chain happens as a side effect:
    /// * A received broadcast sets this node's `id` to the broadcast's `to` field and is
    ///   queued for forwarding.
    /// * A direct message for another node is queued for forwarding and not yielded.
    ///
    /// Failures are reported as items:
    /// * Write failures while forwarding are yielded as `Err(ChainError::Write(_))`.
    ///   Previously they silently ended the stream.
    /// * Read failures are yielded as `Err(ChainError::Read(_))`.
    ///
    /// The stream ends only when both incoming links end.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            // Finish forwarding the buffered message before reading another one, so at
            // most one outgoing message is held at a time.
            if let Err(e) = ready!(this.poll_empty_buffer(cx)) {
                return Poll::Ready(Some(Err(ChainError::Write(e))));
            }
            let next = match ready!(this.poll_readers(cx)) {
                None => return Poll::Ready(None),
                Some(Ok(item)) => item,
                Some(Err(e)) => return Poll::Ready(Some(Err(ChainError::Read(e)))),
            };
            match next {
                Msg::Broadcast { to, .. } => {
                    // `adjust_broadcast` shifts `to` by one per hop, so on arrival it
                    // holds this node's position in the chain.
                    *this.id = to;
                    *this.current = Some(next.clone());
                    return Poll::Ready(Some(Ok(next)));
                }
                Msg::Direct { to, .. } if to == *this.id => {
                    return Poll::Ready(Some(Ok(next)));
                }
                Msg::Direct { .. } => {
                    // Not for us: queue it so the next loop iteration forwards it.
                    *this.current = Some(next);
                }
            }
        }
    }
}
impl<T, LW, LR, RW, RR> Net<T> for ChainNet<T, LW, LR, RW, RR>
where
    T: Serialize + DeserializeOwned + Clone + Debug,
    LW: Write,
    LR: Read,
    RW: Write<Error = LW::Error>,
    RR: Read<Error = LR::Error>,
{
    /// This node's current chain address: `0` until a received broadcast assigns one.
    fn addr(&self) -> i8 {
        let Self { id, .. } = self;
        *id
    }
}