rns_embedded_runtime/
tcp.rs1use rns_embedded_core::{
2 packet::{decode_frame, encode_frame, PacketFrame},
3 transport::{EmbeddedTransport, LinkState, TransportCaps},
4 EmbeddedError, EmbeddedResult,
5};
6use std::io::{ErrorKind, Read, Write};
7use std::net::{SocketAddr, TcpStream};
8
/// Number of bytes in the big-endian `u16` length prefix that precedes every frame on the wire.
const LENGTH_PREFIX_LEN: usize = std::mem::size_of::<u16>();
10
11pub struct TcpEmbeddedTransport {
12 stream: TcpStream,
13 state: LinkState,
14 caps: TransportCaps,
15 recv_buf: Vec<u8>,
16}
17
18impl TcpEmbeddedTransport {
19 pub fn connect(addr: SocketAddr, mtu_hint: u16) -> EmbeddedResult<Self> {
20 let stream = TcpStream::connect(addr).map_err(map_connect_error)?;
21 Self::from_stream(stream, mtu_hint)
22 }
23
24 pub fn from_stream(stream: TcpStream, mtu_hint: u16) -> EmbeddedResult<Self> {
25 stream.set_nonblocking(true).map_err(|_| EmbeddedError::InvalidState)?;
26 Ok(Self {
27 stream,
28 state: LinkState::Up,
29 caps: TransportCaps { mtu_hint, ordered_delivery: true },
30 recv_buf: Vec::new(),
31 })
32 }
33
34 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
35 self.stream.peer_addr()
36 }
37
38 fn next_frame_len(&self) -> EmbeddedResult<Option<usize>> {
39 if self.recv_buf.is_empty() {
40 return Ok(None);
41 }
42 if self.recv_buf.len() < LENGTH_PREFIX_LEN {
43 return Ok(None);
44 }
45 let frame_len = u16::from_be_bytes([self.recv_buf[0], self.recv_buf[1]]);
46 let frame_len = usize::from(frame_len);
47 if frame_len == 0 {
48 return Err(EmbeddedError::InvalidInput);
49 }
50 Ok(Some(LENGTH_PREFIX_LEN + frame_len))
51 }
52
53 fn refill_read_buffer(&mut self) -> EmbeddedResult<()> {
54 let mut scratch = [0_u8; 2048];
55 loop {
56 match self.stream.read(&mut scratch) {
57 Ok(0) => {
58 self.state = LinkState::Down;
59 return Err(EmbeddedError::Disconnected);
60 }
61 Ok(n) => {
62 self.recv_buf.extend_from_slice(&scratch[..n]);
63 if n < scratch.len() {
64 return Ok(());
65 }
66 }
67 Err(err) if err.kind() == ErrorKind::WouldBlock => return Ok(()),
68 Err(_) => {
69 self.state = LinkState::Down;
70 return Err(EmbeddedError::Disconnected);
71 }
72 }
73 }
74 }
75}
76
77impl EmbeddedTransport for TcpEmbeddedTransport {
78 fn link_state(&self) -> LinkState {
79 self.state
80 }
81
82 fn capabilities(&self) -> TransportCaps {
83 self.caps
84 }
85
86 fn send_frame(&mut self, frame: &PacketFrame) -> EmbeddedResult<()> {
87 if self.state != LinkState::Up {
88 return Err(EmbeddedError::Disconnected);
89 }
90 if frame.payload.len() > usize::from(self.caps.mtu_hint) {
91 return Err(EmbeddedError::InvalidArgument);
92 }
93 let encoded = encode_frame(frame)?;
94 let encoded_len =
95 u16::try_from(encoded.len()).map_err(|_| EmbeddedError::InvalidArgument)?;
96 let header = encoded_len.to_be_bytes();
97 self.stream.write_all(&header).map_err(|_| EmbeddedError::Disconnected)?;
98 self.stream.write_all(&encoded).map_err(|_| EmbeddedError::Disconnected)?;
99 self.stream.flush().map_err(|_| EmbeddedError::Disconnected)?;
100 Ok(())
101 }
102
103 fn poll_frame(&mut self) -> EmbeddedResult<Option<PacketFrame>> {
104 if self.state == LinkState::Down {
105 return Err(EmbeddedError::Disconnected);
106 }
107 self.refill_read_buffer()?;
108 let Some(frame_len) = self.next_frame_len()? else {
109 return Ok(None);
110 };
111 if self.recv_buf.len() < frame_len {
112 return Ok(None);
113 }
114
115 let packet_bytes: Vec<u8> =
116 self.recv_buf.drain(..frame_len).skip(LENGTH_PREFIX_LEN).collect();
117 let frame = decode_frame(&packet_bytes)?;
118 Ok(Some(frame))
119 }
120}
121
122fn map_connect_error(err: std::io::Error) -> EmbeddedError {
123 match err.kind() {
124 ErrorKind::WouldBlock | ErrorKind::TimedOut => EmbeddedError::Timeout,
125 ErrorKind::ConnectionRefused
126 | ErrorKind::ConnectionReset
127 | ErrorKind::ConnectionAborted
128 | ErrorKind::NotConnected
129 | ErrorKind::AddrNotAvailable
130 | ErrorKind::BrokenPipe => EmbeddedError::Disconnected,
131 _ => EmbeddedError::InvalidState,
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::TcpEmbeddedTransport;
    use rns_embedded_core::{
        packet::PacketFrame,
        transport::{EmbeddedTransport, LinkState},
    };
    use std::net::TcpListener;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Pause between polls while waiting on the non-blocking transport.
    const POLL_INTERVAL: Duration = Duration::from_millis(5);

    /// Builds a test frame, panicking if `PacketFrame::new` rejects the inputs.
    fn frame(kind: u8, seq: u32, payload: &[u8]) -> PacketFrame {
        PacketFrame::new(kind, seq, payload.to_vec()).expect("frame")
    }

    /// Returns a `(client, server)` pair connected over loopback, both with a
    /// 1024-byte MTU hint.
    fn connected_pair() -> (TcpEmbeddedTransport, TcpEmbeddedTransport) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
        let addr = listener.local_addr().expect("listener addr");

        let acceptor = thread::spawn(move || {
            let (accepted, _peer) = listener.accept().expect("accept");
            TcpEmbeddedTransport::from_stream(accepted, 1024).expect("server transport")
        });

        let client = TcpEmbeddedTransport::connect(addr, 1024).expect("client transport");
        (client, acceptor.join().expect("server join"))
    }

    /// Polls `transport` until a frame arrives. Panics on a poll error or once
    /// `timeout` elapses.
    fn poll_until_frame(transport: &mut TcpEmbeddedTransport, timeout: Duration) -> PacketFrame {
        let deadline = Instant::now() + timeout;
        loop {
            let polled = transport
                .poll_frame()
                .unwrap_or_else(|err| panic!("poll failed: {err:?}"));
            if let Some(received) = polled {
                return received;
            }
            assert!(Instant::now() < deadline, "timed out waiting for frame");
            thread::sleep(POLL_INTERVAL);
        }
    }

    #[test]
    fn tcp_transport_round_trips_frame() {
        let (mut client, mut server) = connected_pair();
        let outbound = frame(0x11, 7, b"announce");
        client.send_frame(&outbound).expect("send");

        let received = poll_until_frame(&mut server, Duration::from_secs(1));
        assert_eq!(received.kind, 0x11);
        assert_eq!(received.sequence, 7);
        assert_eq!(received.payload, b"announce");
    }

    #[test]
    fn tcp_transport_supports_bidirectional_exchange() {
        let (mut client, mut server) = connected_pair();
        client.send_frame(&frame(0x31, 1, b"hello")).expect("send client");
        server.send_frame(&frame(0x32, 2, b"world")).expect("send server");

        let wait = Duration::from_secs(1);
        let rx_server = poll_until_frame(&mut server, wait);
        let rx_client = poll_until_frame(&mut client, wait);
        assert_eq!(rx_server.payload, b"hello");
        assert_eq!(rx_client.payload, b"world");
    }

    #[test]
    fn tcp_transport_tracks_disconnect() {
        let (client, mut server) = connected_pair();
        drop(client);

        let deadline = Instant::now() + Duration::from_secs(1);
        // Keep polling while nothing has happened yet. Stop on the first other
        // outcome and check that it is the expected disconnect.
        let outcome = loop {
            match server.poll_frame() {
                Ok(None) if Instant::now() < deadline => thread::sleep(POLL_INTERVAL),
                other => break other,
            }
        };
        assert!(
            matches!(outcome, Err(rns_embedded_core::EmbeddedError::Disconnected)),
            "unexpected poll result: {outcome:?}"
        );
        assert_eq!(server.link_state(), LinkState::Down);
    }
}