rns-embedded-runtime 0.5.1

Runtime support layer for embedded Reticulum transports and alloc-backed targets.
Documentation
use rns_embedded_core::{
    packet::{decode_frame, encode_frame, PacketFrame},
    transport::{EmbeddedTransport, LinkState, TransportCaps},
    EmbeddedError, EmbeddedResult,
};
use std::io::{ErrorKind, Read, Write};
use std::net::{SocketAddr, TcpStream};

/// Size in bytes of the big-endian `u16` length prefix that precedes every frame on the wire.
const LENGTH_PREFIX_LEN: usize = std::mem::size_of::<u16>();

/// Reticulum frame transport over a TCP stream.
///
/// Every frame travels as a 2-byte big-endian length prefix followed by the
/// bytes produced by `encode_frame`.
pub struct TcpEmbeddedTransport {
    /// Underlying socket; switched to non-blocking mode by `from_stream`.
    stream: TcpStream,
    /// Bytes received from the socket but not yet consumed as complete frames.
    recv_buf: Vec<u8>,
    /// Link state reported through `link_state`.
    state: LinkState,
    /// Capabilities reported through `capabilities`; `mtu_hint` bounds outgoing payloads.
    caps: TransportCaps,
}

impl TcpEmbeddedTransport {
    /// Opens a TCP connection to `addr` and wraps it in a framed transport.
    ///
    /// `mtu_hint` is the largest payload, in bytes, that `send_frame` accepts.
    ///
    /// # Errors
    ///
    /// Connection failures are translated by `map_connect_error`; errors from
    /// `from_stream` are propagated unchanged.
    pub fn connect(addr: SocketAddr, mtu_hint: u16) -> EmbeddedResult<Self> {
        let stream = TcpStream::connect(addr).map_err(map_connect_error)?;
        Self::from_stream(stream, mtu_hint)
    }

    /// Wraps an already-connected stream.
    ///
    /// The stream is switched to non-blocking mode so that `poll_frame` returns
    /// immediately when no data is available. The transport starts in
    /// `LinkState::Up` with ordered delivery advertised.
    ///
    /// # Errors
    ///
    /// Returns `EmbeddedError::InvalidState` if the socket cannot be made non-blocking.
    pub fn from_stream(stream: TcpStream, mtu_hint: u16) -> EmbeddedResult<Self> {
        stream.set_nonblocking(true).map_err(|_| EmbeddedError::InvalidState)?;
        Ok(Self {
            stream,
            state: LinkState::Up,
            caps: TransportCaps { mtu_hint, ordered_delivery: true },
            recv_buf: Vec::new(),
        })
    }

    /// Returns the address of the remote peer.
    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
        self.stream.peer_addr()
    }

    /// Returns the total wire length (prefix + body) of the frame at the head of `recv_buf`.
    ///
    /// Returns `Ok(None)` while the 2-byte big-endian length prefix has not fully arrived.
    ///
    /// # Errors
    ///
    /// Returns `EmbeddedError::InvalidInput` when the prefix announces a zero-length body.
    fn next_frame_len(&self) -> EmbeddedResult<Option<usize>> {
        // `get` yields `None` for an empty or one-byte buffer, covering both "not yet" cases.
        let Some(prefix) = self.recv_buf.get(..LENGTH_PREFIX_LEN) else {
            return Ok(None);
        };
        let body_len = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
        if body_len == 0 {
            return Err(EmbeddedError::InvalidInput);
        }
        Ok(Some(LENGTH_PREFIX_LEN + body_len))
    }

    /// Appends every byte the socket can deliver right now to `recv_buf`.
    ///
    /// Reading stops at `WouldBlock` or after a short read. A read interrupted by
    /// a signal is retried. On end-of-stream or any other I/O error the link is
    /// marked `LinkState::Down`.
    ///
    /// # Errors
    ///
    /// Returns `EmbeddedError::Disconnected` when the link went down before any
    /// new bytes were read in this call. If bytes were read first, they are kept
    /// and `Ok(())` is returned so they can still be decoded. The `Down` state
    /// remains visible through `link_state`.
    fn refill_read_buffer(&mut self) -> EmbeddedResult<()> {
        let mut scratch = [0_u8; 2048];
        let mut read_any = false;
        loop {
            match self.stream.read(&mut scratch) {
                Ok(0) => break,
                Ok(n) => {
                    self.recv_buf.extend_from_slice(&scratch[..n]);
                    read_any = true;
                    // A short read means the socket buffer is drained; skip the
                    // extra syscall that would only report WouldBlock.
                    if n < scratch.len() {
                        return Ok(());
                    }
                }
                Err(err) if err.kind() == ErrorKind::WouldBlock => return Ok(()),
                // The syscall was interrupted before transferring data; just retry.
                Err(err) if err.kind() == ErrorKind::Interrupted => {}
                Err(_) => break,
            }
        }
        // Only reached on end-of-stream or a fatal read error.
        self.state = LinkState::Down;
        if read_any {
            Ok(())
        } else {
            Err(EmbeddedError::Disconnected)
        }
    }
}

impl EmbeddedTransport for TcpEmbeddedTransport {
    fn link_state(&self) -> LinkState {
        self.state
    }

    fn capabilities(&self) -> TransportCaps {
        self.caps
    }

    /// Sends `frame` as a 2-byte big-endian length prefix followed by the
    /// `encode_frame` output.
    ///
    /// The socket is non-blocking for the benefit of `poll_frame`. A plain
    /// `write_all` on it fails with `WouldBlock` whenever the kernel send buffer
    /// fills, possibly after only part of the frame was written. To avoid that,
    /// the socket is put in blocking mode for the duration of the write. This
    /// call may therefore wait until the peer drains its receive window.
    ///
    /// # Errors
    ///
    /// - `Disconnected` if the link is not `Up`.
    /// - `Disconnected` if the write fails, or if non-blocking mode cannot be
    ///   restored afterwards. The link is then marked `Down`, because a
    ///   partially written frame leaves the stream unframeable.
    /// - `InvalidArgument` if the payload exceeds `mtu_hint`, or the encoded frame
    ///   does not fit in a `u16` length prefix.
    /// - `InvalidState` if the socket cannot be switched to blocking mode.
    ///   Nothing has been written in that case.
    /// - Any error from `encode_frame`.
    fn send_frame(&mut self, frame: &PacketFrame) -> EmbeddedResult<()> {
        if self.state != LinkState::Up {
            return Err(EmbeddedError::Disconnected);
        }
        if frame.payload.len() > usize::from(self.caps.mtu_hint) {
            return Err(EmbeddedError::InvalidArgument);
        }
        let encoded = encode_frame(frame)?;
        let encoded_len =
            u16::try_from(encoded.len()).map_err(|_| EmbeddedError::InvalidArgument)?;

        // Prefix and body go out together in one write.
        let mut wire = Vec::with_capacity(LENGTH_PREFIX_LEN + encoded.len());
        wire.extend_from_slice(&encoded_len.to_be_bytes());
        wire.extend_from_slice(&encoded);

        self.stream.set_nonblocking(false).map_err(|_| EmbeddedError::InvalidState)?;
        let written = self.stream.write_all(&wire).and_then(|()| self.stream.flush());
        let restored = self.stream.set_nonblocking(true);
        if written.is_err() || restored.is_err() {
            // Either part of a frame may already be on the wire, or later reads
            // would block; the transport cannot continue safely.
            self.state = LinkState::Down;
            return Err(EmbeddedError::Disconnected);
        }
        Ok(())
    }

    /// Returns the next complete frame, or `Ok(None)` if none has fully arrived yet.
    ///
    /// Frames already in the receive buffer are delivered before the socket is
    /// read. They stay deliverable after the link goes down, so frames sent
    /// just before the peer closed are not lost.
    ///
    /// # Errors
    ///
    /// - `Disconnected` once the link is down and no complete frame is buffered.
    /// - The error from `next_frame_len` for a malformed length prefix. The
    ///   stream cannot be resynchronised after that, so the buffer is discarded
    ///   and the link is marked `Down` (later polls report `Disconnected`).
    /// - Any error from `decode_frame`. The offending frame has already been
    ///   consumed, so later frames are still delivered.
    fn poll_frame(&mut self) -> EmbeddedResult<Option<PacketFrame>> {
        let mut refilled = false;
        loop {
            let frame_len = match self.next_frame_len() {
                Ok(len) => len,
                Err(err) => {
                    self.recv_buf.clear();
                    self.state = LinkState::Down;
                    return Err(err);
                }
            };
            if let Some(frame_len) = frame_len {
                if self.recv_buf.len() >= frame_len {
                    let packet_bytes: Vec<u8> =
                        self.recv_buf.drain(..frame_len).skip(LENGTH_PREFIX_LEN).collect();
                    let frame = decode_frame(&packet_bytes)?;
                    return Ok(Some(frame));
                }
            }
            // Read from the socket at most once per poll.
            if refilled {
                return Ok(None);
            }
            if self.state == LinkState::Down {
                return Err(EmbeddedError::Disconnected);
            }
            self.refill_read_buffer()?;
            refilled = true;
        }
    }
}

/// Translates a `TcpStream::connect` failure into the embedded error space.
///
/// - Timeouts (`WouldBlock`, `TimedOut`) become `Timeout`.
/// - Refused, reset, aborted or otherwise unreachable connections become `Disconnected`.
/// - Every other kind becomes `InvalidState`.
fn map_connect_error(err: std::io::Error) -> EmbeddedError {
    let kind = err.kind();
    let timed_out = matches!(kind, ErrorKind::WouldBlock | ErrorKind::TimedOut);
    let lost_peer = matches!(
        kind,
        ErrorKind::ConnectionRefused
            | ErrorKind::ConnectionReset
            | ErrorKind::ConnectionAborted
            | ErrorKind::NotConnected
            | ErrorKind::AddrNotAvailable
            | ErrorKind::BrokenPipe
    );
    if timed_out {
        EmbeddedError::Timeout
    } else if lost_peer {
        EmbeddedError::Disconnected
    } else {
        EmbeddedError::InvalidState
    }
}

#[cfg(test)]
mod tests {
    //! Loopback tests exercising the TCP transport end to end.

    use super::TcpEmbeddedTransport;
    use rns_embedded_core::{
        packet::PacketFrame,
        transport::{EmbeddedTransport, LinkState},
    };
    use std::net::TcpListener;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Builds a test frame, panicking if the core crate rejects it.
    fn frame(kind: u8, seq: u32, payload: &[u8]) -> PacketFrame {
        PacketFrame::new(kind, seq, payload.to_vec()).expect("frame")
    }

    /// Returns `(client, server)` transports connected over an ephemeral loopback port.
    fn connected_pair() -> (TcpEmbeddedTransport, TcpEmbeddedTransport) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
        let addr = listener.local_addr().expect("listener addr");

        // accept() blocks, so the server side runs on its own thread while the client connects.
        let server = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept");
            TcpEmbeddedTransport::from_stream(stream, 1024).expect("server transport")
        });

        let client = TcpEmbeddedTransport::connect(addr, 1024).expect("client transport");
        let server = server.join().expect("server join");
        (client, server)
    }

    /// Polls `transport` every 5 ms until a frame arrives, panicking after `timeout`
    /// or on any poll error.
    fn poll_until_frame(transport: &mut TcpEmbeddedTransport, timeout: Duration) -> PacketFrame {
        let deadline = Instant::now() + timeout;
        loop {
            match transport.poll_frame() {
                Ok(Some(frame)) => return frame,
                Ok(None) if Instant::now() < deadline => thread::sleep(Duration::from_millis(5)),
                Ok(None) => panic!("timed out waiting for frame"),
                Err(err) => panic!("poll failed: {err:?}"),
            }
        }
    }

    #[test]
    fn tcp_transport_round_trips_frame() {
        let (mut client, mut server) = connected_pair();
        client.send_frame(&frame(0x11, 7, b"announce")).expect("send");

        let received = poll_until_frame(&mut server, Duration::from_secs(1));
        assert_eq!(received.kind, 0x11);
        assert_eq!(received.sequence, 7);
        assert_eq!(received.payload, b"announce");
    }

    #[test]
    fn tcp_transport_supports_bidirectional_exchange() {
        let (mut client, mut server) = connected_pair();
        client.send_frame(&frame(0x31, 1, b"hello")).expect("send client");
        server.send_frame(&frame(0x32, 2, b"world")).expect("send server");

        let rx_server = poll_until_frame(&mut server, Duration::from_secs(1));
        let rx_client = poll_until_frame(&mut client, Duration::from_secs(1));
        assert_eq!(rx_server.payload, b"hello");
        assert_eq!(rx_client.payload, b"world");
    }

    #[test]
    fn tcp_transport_delivers_back_to_back_frames_in_order() {
        let (mut client, mut server) = connected_pair();
        for seq in 0..3_u32 {
            client.send_frame(&frame(0x21, seq, b"burst")).expect("send");
        }

        // All three frames may coalesce into a single socket read; each poll must
        // still yield exactly one of them, in sending order.
        for expected_seq in 0..3_u32 {
            let received = poll_until_frame(&mut server, Duration::from_secs(1));
            assert_eq!(received.kind, 0x21);
            assert_eq!(received.sequence, expected_seq);
            assert_eq!(received.payload, b"burst");
        }
    }

    #[test]
    fn tcp_transport_tracks_disconnect() {
        let (client, mut server) = connected_pair();
        drop(client);

        let deadline = Instant::now() + Duration::from_secs(1);
        loop {
            match server.poll_frame() {
                Err(rns_embedded_core::EmbeddedError::Disconnected) => break,
                Ok(None) if Instant::now() < deadline => thread::sleep(Duration::from_millis(5)),
                other => panic!("unexpected poll result: {other:?}"),
            }
        }
        assert_eq!(server.link_state(), LinkState::Down);
    }
}