Skip to main content

rns_embedded_runtime/
tcp.rs

1use rns_embedded_core::{
2    packet::{decode_frame, encode_frame, PacketFrame},
3    transport::{EmbeddedTransport, LinkState, TransportCaps},
4    EmbeddedError, EmbeddedResult,
5};
6use std::io::{ErrorKind, Read, Write};
7use std::net::{SocketAddr, TcpStream};
8
/// Number of bytes in the big-endian `u16` length prefix written before every encoded frame.
const LENGTH_PREFIX_LEN: usize = std::mem::size_of::<u16>();
10
/// [`EmbeddedTransport`] implementation that carries frames over a TCP stream.
///
/// On the wire, every frame is a big-endian `u16` byte count followed by the bytes produced
/// by `encode_frame`.
pub struct TcpEmbeddedTransport {
    /// Connected socket; `from_stream` switches it to non-blocking mode.
    stream: TcpStream,
    /// `Up` on construction; switched to `Down` once the connection is found to be closed or failing.
    state: LinkState,
    /// Advertised capabilities; `mtu_hint` also caps the payload size accepted for sending.
    caps: TransportCaps,
    /// Bytes read from the socket but not yet returned as frames; may end with a partial frame.
    recv_buf: Vec<u8>,
}
17
18impl TcpEmbeddedTransport {
19    pub fn connect(addr: SocketAddr, mtu_hint: u16) -> EmbeddedResult<Self> {
20        let stream = TcpStream::connect(addr).map_err(map_connect_error)?;
21        Self::from_stream(stream, mtu_hint)
22    }
23
24    pub fn from_stream(stream: TcpStream, mtu_hint: u16) -> EmbeddedResult<Self> {
25        stream.set_nonblocking(true).map_err(|_| EmbeddedError::InvalidState)?;
26        Ok(Self {
27            stream,
28            state: LinkState::Up,
29            caps: TransportCaps { mtu_hint, ordered_delivery: true },
30            recv_buf: Vec::new(),
31        })
32    }
33
34    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
35        self.stream.peer_addr()
36    }
37
38    fn next_frame_len(&self) -> EmbeddedResult<Option<usize>> {
39        if self.recv_buf.is_empty() {
40            return Ok(None);
41        }
42        if self.recv_buf.len() < LENGTH_PREFIX_LEN {
43            return Ok(None);
44        }
45        let frame_len = u16::from_be_bytes([self.recv_buf[0], self.recv_buf[1]]);
46        let frame_len = usize::from(frame_len);
47        if frame_len == 0 {
48            return Err(EmbeddedError::InvalidInput);
49        }
50        Ok(Some(LENGTH_PREFIX_LEN + frame_len))
51    }
52
53    fn refill_read_buffer(&mut self) -> EmbeddedResult<()> {
54        let mut scratch = [0_u8; 2048];
55        loop {
56            match self.stream.read(&mut scratch) {
57                Ok(0) => {
58                    self.state = LinkState::Down;
59                    return Err(EmbeddedError::Disconnected);
60                }
61                Ok(n) => {
62                    self.recv_buf.extend_from_slice(&scratch[..n]);
63                    if n < scratch.len() {
64                        return Ok(());
65                    }
66                }
67                Err(err) if err.kind() == ErrorKind::WouldBlock => return Ok(()),
68                Err(_) => {
69                    self.state = LinkState::Down;
70                    return Err(EmbeddedError::Disconnected);
71                }
72            }
73        }
74    }
75}
76
77impl EmbeddedTransport for TcpEmbeddedTransport {
78    fn link_state(&self) -> LinkState {
79        self.state
80    }
81
82    fn capabilities(&self) -> TransportCaps {
83        self.caps
84    }
85
86    fn send_frame(&mut self, frame: &PacketFrame) -> EmbeddedResult<()> {
87        if self.state != LinkState::Up {
88            return Err(EmbeddedError::Disconnected);
89        }
90        if frame.payload.len() > usize::from(self.caps.mtu_hint) {
91            return Err(EmbeddedError::InvalidArgument);
92        }
93        let encoded = encode_frame(frame)?;
94        let encoded_len =
95            u16::try_from(encoded.len()).map_err(|_| EmbeddedError::InvalidArgument)?;
96        let header = encoded_len.to_be_bytes();
97        self.stream.write_all(&header).map_err(|_| EmbeddedError::Disconnected)?;
98        self.stream.write_all(&encoded).map_err(|_| EmbeddedError::Disconnected)?;
99        self.stream.flush().map_err(|_| EmbeddedError::Disconnected)?;
100        Ok(())
101    }
102
103    fn poll_frame(&mut self) -> EmbeddedResult<Option<PacketFrame>> {
104        if self.state == LinkState::Down {
105            return Err(EmbeddedError::Disconnected);
106        }
107        self.refill_read_buffer()?;
108        let Some(frame_len) = self.next_frame_len()? else {
109            return Ok(None);
110        };
111        if self.recv_buf.len() < frame_len {
112            return Ok(None);
113        }
114
115        let packet_bytes: Vec<u8> =
116            self.recv_buf.drain(..frame_len).skip(LENGTH_PREFIX_LEN).collect();
117        let frame = decode_frame(&packet_bytes)?;
118        Ok(Some(frame))
119    }
120}
121
122fn map_connect_error(err: std::io::Error) -> EmbeddedError {
123    match err.kind() {
124        ErrorKind::WouldBlock | ErrorKind::TimedOut => EmbeddedError::Timeout,
125        ErrorKind::ConnectionRefused
126        | ErrorKind::ConnectionReset
127        | ErrorKind::ConnectionAborted
128        | ErrorKind::NotConnected
129        | ErrorKind::AddrNotAvailable
130        | ErrorKind::BrokenPipe => EmbeddedError::Disconnected,
131        _ => EmbeddedError::InvalidState,
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::TcpEmbeddedTransport;
    use rns_embedded_core::{
        packet::PacketFrame,
        transport::{EmbeddedTransport, LinkState},
        EmbeddedError,
    };
    use std::net::TcpListener;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Pause between polls while waiting on a non-blocking transport.
    const POLL_INTERVAL: Duration = Duration::from_millis(5);
    /// Longest time a test waits for an expected event before failing.
    const WAIT: Duration = Duration::from_secs(1);

    /// Builds a frame, failing the test if `PacketFrame::new` rejects the arguments.
    fn frame(kind: u8, seq: u32, payload: &[u8]) -> PacketFrame {
        PacketFrame::new(kind, seq, payload.to_vec()).expect("frame")
    }

    /// Returns `(client, server)` transports joined over a loopback connection.
    fn connected_pair() -> (TcpEmbeddedTransport, TcpEmbeddedTransport) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
        let addr = listener.local_addr().expect("listener addr");

        let acceptor = thread::spawn(move || {
            let (accepted, _peer) = listener.accept().expect("accept");
            TcpEmbeddedTransport::from_stream(accepted, 1024).expect("server transport")
        });

        let client = TcpEmbeddedTransport::connect(addr, 1024).expect("client transport");
        let server = acceptor.join().expect("server join");
        (client, server)
    }

    /// Polls `transport` until it yields a frame, panicking on error or after `timeout`.
    fn poll_until_frame(transport: &mut TcpEmbeddedTransport, timeout: Duration) -> PacketFrame {
        let started = Instant::now();
        loop {
            match transport.poll_frame() {
                Err(err) => panic!("poll failed: {err:?}"),
                Ok(Some(received)) => break received,
                Ok(None) => {
                    assert!(started.elapsed() < timeout, "timed out waiting for frame");
                    thread::sleep(POLL_INTERVAL);
                }
            }
        }
    }

    #[test]
    fn tcp_transport_round_trips_frame() {
        let (mut client, mut server) = connected_pair();
        client.send_frame(&frame(0x11, 7, b"announce")).expect("send");

        let received = poll_until_frame(&mut server, WAIT);
        assert_eq!((received.kind, received.sequence), (0x11, 7));
        assert_eq!(received.payload, b"announce");
    }

    #[test]
    fn tcp_transport_supports_bidirectional_exchange() {
        let (mut client, mut server) = connected_pair();
        client.send_frame(&frame(0x31, 1, b"hello")).expect("send client");
        server.send_frame(&frame(0x32, 2, b"world")).expect("send server");

        let at_server = poll_until_frame(&mut server, WAIT);
        let at_client = poll_until_frame(&mut client, WAIT);
        assert_eq!(at_server.payload, b"hello");
        assert_eq!(at_client.payload, b"world");
    }

    #[test]
    fn tcp_transport_tracks_disconnect() {
        let (client, mut server) = connected_pair();
        drop(client);

        let started = Instant::now();
        loop {
            match server.poll_frame() {
                Err(EmbeddedError::Disconnected) => break,
                Ok(None) if started.elapsed() < WAIT => thread::sleep(POLL_INTERVAL),
                other => panic!("unexpected poll result: {other:?}"),
            }
        }
        assert_eq!(server.link_state(), LinkState::Down);
    }
}