// embedded-nal-tcpextensions 0.1.2
//
// Extensions to embedded-nal's TCP parts around the more precise use of the stack's buffers.
use embedded_nal::nb;

use crate::TcpExactStack;

/// A wrapper around a [`embedded_nal::TcpClientStack`] that provides [`TcpExactStack`]
///
/// The implementation is comparatively crude: each socket carries a receive buffer and a send
/// buffer of `N` bytes each, and data is copied into them on demand. `N` is therefore also what
/// is announced as both `RECVBUFLEN` and `SENDBUFLEN`.
///
/// Using this is generally not recommended -- TCP stacks usually have the buffers in there
/// somewhere, and "just" need to expose them.
pub struct BufferedStack<ST: embedded_nal::TcpClientStack, const N: usize>(ST);

impl<ST: embedded_nal::TcpClientStack, const N: usize> BufferedStack<ST, N> {
    /// Wraps `wrapped`; every socket created through the wrapper gets its own `N`-byte receive
    /// and send buffers.
    pub fn new(wrapped: ST) -> Self {
        BufferedStack(wrapped)
    }

    /// Attempt sending any content of the buffer, returning Ok only if the buffer is now empty.
    ///
    /// * If the send buffer is already empty, nothing is sent and `Ok(())` is returned.
    /// * If the wrapped stack takes everything, the buffer is cleared and `Ok(())` is returned.
    /// * If it takes only a prefix, that prefix is dropped from the buffer and
    ///   [`nb::Error::WouldBlock`] is returned.
    /// * Errors (including `WouldBlock`) from the wrapped stack are passed through unchanged,
    ///   with the buffer left as it was.
    fn try_flush_sendbuffer(
        &mut self,
        socket: &mut <Self as embedded_nal::TcpClientStack>::TcpSocket,
    ) -> Result<(), embedded_nal::nb::Error<<Self as embedded_nal::TcpClientStack>::Error>> {
        if socket.sendbuf.is_empty() {
            return Ok(());
        }

        match self.0.send(&mut socket.socket, &socket.sendbuf) {
            // Both WouldBlock and actual errors go out here. Actual errors are a bit late
            // (given the send_all already returned successfully), but then again, this could
            // just as well have happened while things are in the OS's buffer.
            Err(e) => Err(e),
            // All flushed, we can go on
            Ok(n) if n == socket.sendbuf.len() => {
                socket.sendbuf.clear();
                Ok(())
            }
            // Partially flushed: keep only the tail the wrapped stack did not take yet.
            Ok(n) => {
                socket.sendbuf.copy_within(n.., 0);
                socket.sendbuf.truncate(socket.sendbuf.len() - n);
                Err(embedded_nal::nb::Error::WouldBlock)
            }
        }
    }
}

/// Socket wrapper for BufferedStack
///
/// Pairs a socket of the wrapped stack with the receive and send buffers (of `N` bytes each)
/// that [`BufferedStack`] uses to implement [`TcpExactStack`] on top of it.
// For the server socket (which accepts), the buffer is useless -- too bad the TCP socket API
// doesn't have types for different roles.
pub struct BufferedSocket<SO, const N: usize> {
    /// The wrapped stack's socket.
    socket: SO,
    /// Bytes read by `receive_exact` that have not been handed out to the caller yet; drained
    /// by `receive`.
    recvbuf: heapless::Vec<u8, N>,
    /// Bytes accepted by `send_all` that the wrapped stack has not taken yet; flushed on later
    /// `send`, `receive` and `close` calls.
    sendbuf: heapless::Vec<u8, N>,
}

impl<ST: embedded_nal::TcpFullStack, const N: usize> embedded_nal::TcpFullStack
    for BufferedStack<ST, N>
{
    /// Binds the wrapped socket; the buffers play no role here.
    fn bind(&mut self, socket: &mut Self::TcpSocket, port: u16) -> Result<(), Self::Error> {
        let inner = &mut socket.socket;
        self.0.bind(inner, port)
    }

    /// Puts the wrapped socket into listening mode; the buffers play no role here.
    fn listen(&mut self, socket: &mut Self::TcpSocket) -> Result<(), Self::Error> {
        let inner = &mut socket.socket;
        self.0.listen(inner)
    }

    /// Accepts a connection on the wrapped listening socket and wraps the new connection with
    /// fresh, empty buffers.
    fn accept(
        &mut self,
        socket: &mut Self::TcpSocket,
    ) -> Result<(Self::TcpSocket, embedded_nal::SocketAddr), embedded_nal::nb::Error<Self::Error>>
    {
        let (connection, remote) = self.0.accept(&mut socket.socket)?;
        let wrapped = BufferedSocket {
            socket: connection,
            recvbuf: Default::default(),
            sendbuf: Default::default(),
        };
        Ok((wrapped, remote))
    }
}

impl<ST: embedded_nal::TcpClientStack, const N: usize> embedded_nal::TcpClientStack
    for BufferedStack<ST, N>
{
    type TcpSocket = BufferedSocket<ST::TcpSocket, N>;
    type Error = ST::Error;

    fn socket(&mut self) -> Result<Self::TcpSocket, Self::Error> {
        Ok(BufferedSocket {
            socket: self.0.socket()?,
            recvbuf: Default::default(),
            sendbuf: Default::default(),
        })
    }
    fn connect(
        &mut self,
        socket: &mut Self::TcpSocket,
        addr: embedded_nal::SocketAddr,
    ) -> Result<(), embedded_nal::nb::Error<Self::Error>> {
        self.0.connect(&mut socket.socket, addr)
    }
    fn is_connected(&mut self, socket: &Self::TcpSocket) -> Result<bool, Self::Error> {
        self.0.is_connected(&socket.socket)
    }
    fn send(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &[u8],
    ) -> Result<usize, embedded_nal::nb::Error<Self::Error>> {
        // First, send out anything that is enqueued
        self.try_flush_sendbuffer(socket)?;

        assert!(socket.sendbuf.is_empty());
        self.0.send(&mut socket.socket, buffer)
    }
    fn receive(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &mut [u8],
    ) -> Result<usize, embedded_nal::nb::Error<Self::Error>> {
        // There is no task that'd flush the buffer out, so we depend on something to make
        // progress. The read is definitely the best candidate to make that.
        match self.try_flush_sendbuffer(socket) {
            Ok(()) => (),
            // Maybe we made progress, maybe not -- but anyway we tried, and that's all that
            // matters in the receive path. No need to stop receiving just because we have a full
            // send buffer.
            Err(nb::Error::WouldBlock) => (),
            // Ensure the error isn't lost. This may not be 100% precise in half-open connections,
            // but I doubt embedded-nal aims to support them. (If we'd want to, we'd need to mark
            // the send buffer as having erred).
            Err(e) => return Err(e),
        };

        match socket.recvbuf.len() {
            // The common case
            0 => self.0.receive(&mut socket.socket, buffer),
            // The easy case (sure we could try to receive more, but it's TCP and prepared to get
            // data piecemeal, so just eat it as it is)
            present if present >= buffer.len() => {
                buffer[..present].copy_from_slice(&socket.recvbuf);
                socket.recvbuf.clear();
                Ok(present)
            }
            // The tricky case: Even when reading this there's still data left over. This only
            // happens if a long and incomplete read_exactly is followed by a short read. Still
            // needs to be implemented...
            present => {
                buffer.copy_from_slice(&socket.recvbuf[..buffer.len()]);
                socket.recvbuf.copy_within(buffer.len().., 0);
                socket.recvbuf.truncate(present - buffer.len());
                Ok(buffer.len())
            }
        }
    }
    fn close(&mut self, mut socket: Self::TcpSocket) -> Result<(), Self::Error> {
        match self.try_flush_sendbuffer(&mut socket) {
            Ok(()) => (),
            // As close can't WouldBlock, it would appear that not having sent some data is
            // considered acceptable in embedded-nal
            Err(nb::Error::WouldBlock) => (),
            // ... and then it's just logical that errors from there are discarded too.
            Err(nb::Error::Other(_)) => (),
        }
        self.0.close(socket.socket)
    }
}

impl<ST: embedded_nal::TcpClientStack, const N: usize> TcpExactStack
    for BufferedStack<ST, N>
{
    const RECVBUFLEN: usize = N;

    const SENDBUFLEN: usize = N;

    fn receive_exact(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &mut [u8],
    ) -> nb::Result<(), Self::Error> {
        let len_start = socket.recvbuf.len();
        let missing = buffer.len().checked_sub(len_start);

        if let Some(missing) = missing {
            if missing > 0 {
                // unsafe: All u8 values are valid.
                //
                // The safe alternative would be `socket.recvbuf.resize_default(buffer.len());`,
                // which needlessly zeroes out text.
                //
                // There are proposals out there on how to do these things more elegantly, but
                // AFAICT they're not done yet (and I can't look it up right now).
                unsafe {
                    socket.recvbuf.set_len(buffer.len());
                }
                // Note: This panics at the bounds check when too much is asked.
                let received = self.0.receive(
                    &mut socket.socket,
                    &mut socket.recvbuf[len_start..buffer.len()],
                )?;
                socket.recvbuf.truncate(len_start + received);
            }
        }

        if socket.recvbuf.len() >= buffer.len() {
            // It *can* be greater than, if receive_exact was incompletely called earlier; receive
            // already handles the back-rotation of any leftovers, and is guaranteed to succeed in
            // this case.
            use embedded_nal::TcpClientStack;
            self.receive(socket, buffer).map(|_| ())
        } else {
            Err(nb::Error::WouldBlock)
        }
    }

    fn send_all(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &[u8],
    ) -> Result<(), embedded_nal::nb::Error<Self::Error>> {
        use embedded_nal::TcpClientStack;

        match self.send(socket, buffer) {
            Err(e) => Err(e),
            Ok(n) if n == buffer.len() => Ok(()),
            Ok(n) => {
                assert!(
                    socket.sendbuf.is_empty(),
                    "Internal post-condition of send() violated"
                );
                socket
                    .sendbuf
                    .extend_from_slice(&buffer[n..])
                    .expect("Send leftovers exceed buffer announced in SENDBUFLEN");

                Err(embedded_nal::nb::Error::WouldBlock)
            }
        }
    }
}