keebrs 0.3.0

Keyboard firmware building blocks
Documentation
use core::{
    cmp::Ordering,
    fmt::Debug,
    marker::PhantomData,
    pin::Pin,
    task::{
        Context,
        Poll,
    },
};

#[allow(unused_imports)]
use log::info;

use futures::{
    prelude::*,
    ready,
};

use serde::{
    de::DeserializeOwned,
    Deserialize,
    Serialize,
};

use pin_project::pin_project;

use embrio_core::io::{
    Read,
    Write,
};

use async_codec::{
    Framed,
    ReadFrameError,
    WriteFrameError,
};

use crate::{
    codec::PostCardCodec,
    net::{
        Msg,
        Net,
    },
};

/// One of the two sides of a board: left or right.
///
/// The type is serializable, so it can be carried inside messages.
// NOTE(review): not referenced anywhere in this part of the file; presumably
// it names the neighbour direction used during chain setup — confirm against
// the callers.
#[derive(Deserialize, Serialize, Eq, PartialEq, Debug, Clone, Copy)]
pub enum SetupDir {
    /// The left-hand side.
    Left,
    /// The right-hand side.
    Right,
}

/// A node in a linear chain of boards, connected to a left and a right
/// neighbour through one framed reader and one framed writer per side.
///
/// Type parameters: `T` is the message payload, `LW`/`LR` are the left
/// writer/reader and `RW`/`RR` the right writer/reader.
///
/// All four links carry postcard-encoded [`Msg<T>`] frames.
// NOTE(review): the `[u8; 32]` parameter of `Framed` is presumably the frame
// buffer, which would cap an encoded message at 32 bytes — confirm against
// `async_codec`.
#[pin_project(project = ChainNetProj)]
pub struct ChainNet<T, LW, LR, RW, RR> {
    /// Incoming frames from the right neighbour.
    #[pin]
    rx_r: Framed<RR, PostCardCodec<Msg<T>>, [u8; 32]>,
    /// Outgoing frames to the right neighbour.
    #[pin]
    tx_r: Framed<RW, PostCardCodec<Msg<T>>, [u8; 32]>,
    /// Incoming frames from the left neighbour.
    #[pin]
    rx_l: Framed<LR, PostCardCodec<Msg<T>>, [u8; 32]>,
    /// Outgoing frames to the left neighbour.
    #[pin]
    tx_l: Framed<LW, PostCardCodec<Msg<T>>, [u8; 32]>,
    /// This node's address in the chain. Starts at 0 and is overwritten with
    /// the `to` field of every broadcast that is received.
    id: i8,
    /// Single-slot buffer holding a message that still has to be handed to
    /// the left and/or right writer (either a locally sent message or one
    /// being passed through).
    current: Option<Msg<T>>,
    /// Ties the payload type `T` to the struct without storing a `T`.
    _ph: PhantomData<fn() -> T>,
}

/// Error type carried by the items of the [`ChainNet`] `Stream`
/// implementation.
///
/// `WE` is the I/O error type of the underlying writers, `RE` that of the
/// underlying readers.
pub enum ChainError<WE, RE> {
    /// A frame-writing error from one of the framed writers.
    Write(WriteFrameError<WE, postcard::Error>),
    /// A frame-reading error from one of the framed readers.
    Read(ReadFrameError<RE, postcard::Error>),
}

impl<T, LW, LR, RW, RR> ChainNet<T, LW, LR, RW, RR> {
    pub fn new(left: (LW, LR), right: (RW, RR)) -> Self {
        Self {
            tx_l: PostCardCodec::framed(left.0),
            rx_l: PostCardCodec::framed(left.1),
            tx_r: PostCardCodec::framed(right.0),
            rx_r: PostCardCodec::framed(right.1),
            id: 0,
            current: None,
            _ph: Default::default(),
        }
    }
}

impl<'a, T, LW, LR, RW, RR> ChainNetProj<'a, T, LW, LR, RW, RR>
where
    T: Serialize + Clone,
    LW: Write,
    RW: Write<Error = LW::Error>,
{
    /// Tries to hand the buffered message (`current`), if any, to the
    /// neighbour writer(s) it has to travel through.
    ///
    /// Returns `Ready(Ok(()))` when the buffer is empty, either because it
    /// already was or because the message was passed to `start_send` on every
    /// writer it needs. Returns `Pending` while a required writer is not
    /// ready; the message then stays in the buffer for the next poll. Errors
    /// from the writers are returned as-is.
    ///
    /// This only *starts* the send; flushing the writers is the caller's job.
    fn poll_empty_buffer(
        &mut self,
        cx: &mut Context,
    ) -> Poll<Result<(), WriteFrameError<LW::Error, postcard::Error>>> {
        let id = *self.id;

        // Where does the destination lie relative to us? Broadcasts compare
        // their origin against their current `to`; direct messages compare
        // our own id against the destination.
        let direction = match self.current.as_ref() {
            None => return Poll::Ready(Ok(())),
            Some(Msg::Broadcast { to, from, .. }) => from.cmp(to),
            Some(Msg::Direct { to, .. }) => id.cmp(to),
        };

        // Greater: the destination is to our left. Less: it is to our right.
        // Equal: the message originates here and goes out both ways.
        let (send_left, send_right) = match direction {
            Ordering::Greater => (true, false),
            Ordering::Less => (false, true),
            Ordering::Equal => (true, true),
        };

        // Make sure every writer we need is ready *before* taking the message
        // out of the buffer, so that a `Pending` result loses nothing.
        if send_left {
            ready!(self.tx_l.as_mut().poll_ready(cx))?;
        }
        if send_right {
            ready!(self.tx_r.as_mut().poll_ready(cx))?;
        }

        let mut msg = match self.current.take() {
            Some(msg) => msg,
            // Unreachable: the buffer was checked above and nothing has
            // touched it since.
            None => return Poll::Ready(Ok(())),
        };

        // A broadcast's `to` is moved one step in the direction of travel;
        // direct messages are forwarded unchanged. The message is cloned only
        // when it really has to go out on both sides.
        match direction {
            Ordering::Greater => {
                adjust_broadcast(&mut msg, -1);
                self.tx_l.as_mut().start_send(msg)?;
            }
            Ordering::Less => {
                adjust_broadcast(&mut msg, 1);
                self.tx_r.as_mut().start_send(msg)?;
            }
            Ordering::Equal => {
                let mut left_msg = msg.clone();
                adjust_broadcast(&mut left_msg, -1);
                self.tx_l.as_mut().start_send(left_msg)?;

                adjust_broadcast(&mut msg, 1);
                self.tx_r.as_mut().start_send(msg)?;
            }
        }

        Poll::Ready(Ok(()))
    }
}

/// Shifts the `to` field of a broadcast message by `adj`.
///
/// Any other kind of message is left untouched.
fn adjust_broadcast<T>(msg: &mut Msg<T>, adj: i8) {
    if let Msg::Broadcast { to, .. } = msg {
        *to += adj;
    }
}

impl<'a, T, LW, LR, RW, RR> ChainNetProj<'a, T, LW, LR, RW, RR>
where
    T: DeserializeOwned,
    LR: Read,
    RR: Read<Error = LR::Error>,
{
    /// Polls both neighbour readers for the next incoming frame.
    ///
    /// The left reader is polled first and wins whenever it has an item
    /// (message or error); the right reader is only polled otherwise.
    /// `Ready(None)` is returned only once *both* readers have ended.
    #[allow(clippy::type_complexity)]
    fn poll_readers(
        &mut self,
        cx: &mut Context,
    ) -> Poll<Option<Result<Msg<T>, ReadFrameError<LR::Error, postcard::Error>>>> {
        let left_finished = match self.rx_l.as_mut().try_poll_next(cx) {
            Poll::Ready(None) => true,
            Poll::Pending => false,
            // Only `Ready(Some(_))` is left: hand it straight to the caller.
            item => return item,
        };

        match self.rx_r.as_mut().try_poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) if left_finished => Poll::Ready(None),
            // The right side has ended but the left side may still produce
            // frames, so keep waiting.
            Poll::Ready(None) => Poll::Pending,
            Poll::Ready(item) => Poll::Ready(item),
        }
    }
}

/// Sending side of the chain node.
///
/// Messages are staged in the single-slot `current` buffer by `start_send`
/// and handed to the left and/or right writer by `poll_empty_buffer`.
impl<T, LW, LR, RW, RR> Sink<Msg<T>> for ChainNet<T, LW, LR, RW, RR>
where
    LW: Write,
    RW: Write<Error = LW::Error>,
    T: Serialize + Clone,
{
    type Error = WriteFrameError<LW::Error, postcard::Error>;

    /// Ready once the single-slot buffer has been emptied.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().poll_empty_buffer(cx)
    }

    /// Empties the buffer, then flushes the left and the right writer.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut this = self.project();

        ready!(this.poll_empty_buffer(cx))?;
        ready!(this.tx_l.poll_flush(cx))?;
        ready!(this.tx_r.poll_flush(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Stamps `item` with this node's id via `set_from` and stages it in the
    /// buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is still occupied, i.e. if the caller did not get
    /// `Ready(Ok(()))` from `poll_ready` first.
    fn start_send(self: Pin<&mut Self>, mut item: Msg<T>) -> Result<(), Self::Error> {
        let this = self.project();
        assert!(
            this.current.is_none(),
            "start_send called while a message is still buffered; call poll_ready first"
        );
        item.set_from(*this.id);
        *this.current = Some(item);
        Ok(())
    }

    /// Drains the buffer and flushes both writers before reporting the sink
    /// as closed, so a message accepted by `start_send` is not silently
    /// dropped.
    ///
    /// The underlying writers are deliberately left open: the `Stream` half
    /// of this node still uses them to forward pass-through messages.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<Msg<T>>>::poll_flush(self, cx)
    }
}

/// Receiving side of the chain node.
///
/// Yields the messages meant for this node and forwards everything else
/// along the chain.
impl<T, LW, LR, RW, RR> Stream for ChainNet<T, LW, LR, RW, RR>
where
    T: Serialize + DeserializeOwned + Clone + Debug,
    LW: Write,
    LR: Read,
    RW: Write<Error = LW::Error>,
    RR: Read<Error = LR::Error>,
{
    type Item = Result<Msg<T>, ChainError<LW::Error, LR::Error>>;

    /// Produces the next message for this node.
    ///
    /// * A broadcast sets this node's id to the broadcast's `to`, is queued
    ///   for forwarding, and is also yielded to the caller.
    /// * A direct message addressed to this node is yielded.
    /// * Any other direct message is queued for forwarding and the loop
    ///   continues with the next frame.
    ///
    /// Read errors are yielded as `ChainError::Read`. A failure to forward a
    /// queued message is yielded as `ChainError::Write` rather than ending
    /// the stream. The stream ends once both readers have ended.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            // The next message from the readers might not be for this board,
            // so make sure we have a place to put passthrough messages.
            if let Err(e) = ready!(this.poll_empty_buffer(cx)) {
                return Poll::Ready(Some(Err(ChainError::Write(e))));
            }

            let next = match ready!(this.poll_readers(cx)) {
                None => return Poll::Ready(None),
                Some(Ok(item)) => item,
                Some(Err(e)) => return Poll::Ready(Some(Err(ChainError::Read(e)))),
            };

            match next {
                Msg::Broadcast { to, .. } => {
                    // The broadcast tells us our position in the chain. Keep
                    // a copy to pass on and deliver the message locally.
                    *this.id = to;
                    *this.current = Some(next.clone());
                    return Poll::Ready(Some(Ok(next)));
                }
                Msg::Direct { to, .. } => {
                    if to == *this.id {
                        return Poll::Ready(Some(Ok(next)));
                    }
                    // Not ours: queue it and go round again so that it gets
                    // forwarded before the next frame is read.
                    *this.current = Some(next);
                }
            }
        }
    }
}

impl<T, LW, LR, RW, RR> Net<T> for ChainNet<T, LW, LR, RW, RR>
where
    T: Serialize + DeserializeOwned + Clone + Debug,
    LW: Write,
    LR: Read,
    RW: Write<Error = LW::Error>,
    RR: Read<Error = LR::Error>,
{
    /// Returns this node's current address in the chain (its `id` field).
    fn addr(&self) -> i8 {
        let Self { id, .. } = self;
        *id
    }
}