ergot_base/interface_manager/
std_tcp_client.rs1use std::sync::Arc;
9
10use crate::{
11 Header, NetStack,
12 interface_manager::{
13 ConstInit, InterfaceManager, InterfaceSendError, cobs_stream,
14 std_utils::{
15 ReceiverError, StdQueue,
16 acc::{CobsAccumulator, FeedResult},
17 },
18 },
19 wire_frames::{CommonHeader, de_frame},
20};
21use bbq2::{prod_cons::stream::StreamConsumer, traits::storage::BoxedSlice};
22use log::{debug, error, info, warn};
23use maitake_sync::WaitQueue;
24use mutex::ScopedRawMutex;
25use tokio::{
26 io::{AsyncReadExt, AsyncWriteExt},
27 net::{
28 TcpStream,
29 tcp::{OwnedReadHalf, OwnedWriteHalf},
30 },
31 select,
32};
33
34#[derive(Default)]
35pub struct StdTcpClientIm {
36 inner: Option<StdTcpClientImInner>,
37 seq_no: u16,
38}
39
40struct StdTcpClientImInner {
41 interface: StdTcpTxHdl,
42 net_id: u16,
43 closer: Arc<WaitQueue>,
44}
45
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientError {
    /// An interface is already registered with this stack's interface manager.
    SocketAlreadyActive,
}

impl core::fmt::Display for ClientError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            ClientError::SocketAlreadyActive => {
                f.write_str("a TCP client socket is already active on this interface")
            }
        }
    }
}

// Lets callers propagate this error through `Box<dyn Error>` / error-reporting crates.
impl std::error::Error for ClientError {}
50
51pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
52 stack: &'static NetStack<R, StdTcpClientIm>,
53 skt: OwnedReadHalf,
54 closer: Arc<WaitQueue>,
55}
56
57struct StdTcpTxHdl {
58 skt_tx: cobs_stream::Interface<StdQueue>,
59}
60
61impl StdTcpClientIm {
64 pub const fn new() -> Self {
65 Self {
66 inner: None,
67 seq_no: 0,
68 }
69 }
70}
71
72impl ConstInit for StdTcpClientIm {
73 #[allow(clippy::declare_interior_mutable_const)]
74 const INIT: Self = Self::new();
75}
76
77impl StdTcpClientIm {
78 fn common_send<'a, 'b>(
79 &'b mut self,
80 ihdr: &'a Header,
81 ) -> Result<(&'b mut StdTcpClientImInner, CommonHeader), InterfaceSendError> {
82 let intfc = match self.inner.take() {
83 None => return Err(InterfaceSendError::NoRouteToDest),
84 Some(intfc) if intfc.closer.is_closed() => {
85 drop(intfc);
86 return Err(InterfaceSendError::NoRouteToDest);
87 }
88 Some(intfc) => self.inner.insert(intfc),
89 };
90
91 if intfc.net_id == 0 {
92 return Err(InterfaceSendError::NoRouteToDest);
94 }
95 if ihdr.dst.network_id == intfc.net_id && ihdr.dst.node_id == 2 {
103 return Err(InterfaceSendError::DestinationLocal);
104 }
105
106 let mut hdr = ihdr.clone();
109 hdr.decrement_ttl()?;
110
111 if hdr.src.net_node_any() {
114 hdr.src.network_id = intfc.net_id;
118 hdr.src.node_id = 2;
119 }
120
121 if hdr.dst.port_id == 255 {
124 hdr.dst.network_id = intfc.net_id;
125 hdr.dst.node_id = 1;
126 }
127
128 let seq_no = self.seq_no;
129 self.seq_no = self.seq_no.wrapping_add(1);
130
131 let header = CommonHeader {
132 src: hdr.src,
133 dst: hdr.dst,
134 seq_no,
135 kind: hdr.kind,
136 ttl: hdr.ttl,
137 };
138 if [0, 255].contains(&hdr.dst.port_id) {
139 if ihdr.any_all.is_none() {
140 return Err(InterfaceSendError::AnyPortMissingKey);
141 }
142 }
143
144 Ok((intfc, header))
145 }
146}
147
148impl InterfaceManager for StdTcpClientIm {
149 fn send<T: serde::Serialize>(
150 &mut self,
151 hdr: &Header,
152 data: &T,
153 ) -> Result<(), InterfaceSendError> {
154 let (intfc, header) = self.common_send(hdr)?;
155 let res = intfc
156 .interface
157 .skt_tx
158 .send_ty(&header, hdr.any_all.as_ref(), data);
159
160 match res {
161 Ok(()) => Ok(()),
162 Err(()) => Err(InterfaceSendError::InterfaceFull),
163 }
164 }
165
166 fn send_raw(
167 &mut self,
168 hdr: &Header,
169 raw_hdr: &[u8],
170 data: &[u8],
171 ) -> Result<(), InterfaceSendError> {
172 let (intfc, header) = self.common_send(hdr)?;
173 let res = intfc.interface.skt_tx.send_raw(&header, raw_hdr, data);
174
175 match res {
176 Ok(()) => Ok(()),
177 Err(()) => Err(InterfaceSendError::InterfaceFull),
178 }
179 }
180
181 fn send_err(
182 &mut self,
183 hdr: &Header,
184 err: crate::ProtocolError,
185 ) -> Result<(), InterfaceSendError> {
186 let (intfc, header) = self.common_send(hdr)?;
187 let res = intfc.interface.skt_tx.send_err(&header, err);
188
189 match res {
190 Ok(()) => Ok(()),
191 Err(()) => Err(InterfaceSendError::InterfaceFull),
192 }
193 }
194}
195
196impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
197 pub async fn run(mut self) -> Result<(), ReceiverError> {
198 let res = self.run_inner().await;
199 self.stack.with_interface_manager(|im| {
201 _ = im.inner.take();
202 });
203 res
204 }
205
206 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
207 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
208 let mut raw_buf = [0u8; 4096];
209 let mut net_id = None;
210
211 loop {
212 let rd = self.skt.read(&mut raw_buf);
213 let close = self.closer.wait();
214
215 let ct = select! {
216 r = rd => {
217 match r {
218 Ok(0) | Err(_) => {
219 warn!("recv run closed");
220 return Err(ReceiverError::SocketClosed)
221 },
222 Ok(ct) => ct,
223 }
224 }
225 _c = close => {
226 return Err(ReceiverError::SocketClosed);
227 }
228 };
229
230 let buf = &raw_buf[..ct];
231 let mut window = buf;
232
233 'cobs: while !window.is_empty() {
234 window = match cobs_buf.feed_raw(window) {
235 FeedResult::Consumed => break 'cobs,
236 FeedResult::OverFull(new_wind) => new_wind,
237 FeedResult::DeserError(new_wind) => new_wind,
238 FeedResult::Success { data, remaining } => {
239 if let Some(mut frame) = de_frame(data) {
242 debug!("Got Frame!");
243 let take_net = net_id.is_none()
244 || net_id.is_some_and(|n| {
245 frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
246 });
247 if take_net {
248 self.stack.with_interface_manager(|im| {
249 if let Some(i) = im.inner.as_mut() {
250 i.net_id = frame.hdr.dst.network_id;
252 }
253 });
255 net_id = Some(frame.hdr.dst.network_id);
256 }
257
258 if let Some(net) = net_id.as_ref() {
264 if frame.hdr.src.network_id == 0 {
265 assert_ne!(
266 frame.hdr.src.node_id, 0,
267 "we got a local packet remotely?"
268 );
269 assert_ne!(
270 frame.hdr.src.node_id, 2,
271 "someone is pretending to be us?"
272 );
273
274 frame.hdr.src.network_id = *net;
275 }
276 }
277
278 let hdr = frame.hdr.clone();
284 let hdr: Header = hdr.into();
285 let res = match frame.body {
286 Ok(body) => self.stack.send_raw(&hdr, frame.hdr_raw, body),
287 Err(e) => self.stack.send_err(&hdr, e),
288 };
289 match res {
290 Ok(()) => {}
291 Err(e) => {
292 panic!("recv->send error: {e:?}");
294 }
295 }
296 } else {
297 warn!(
298 "Decode error! Ignoring frame on net_id {}",
299 net_id.unwrap_or(0)
300 );
301 }
302
303 remaining
304 }
305 };
306 }
307 }
308 }
309}
310
311pub fn register_interface<R: ScopedRawMutex>(
314 stack: &'static NetStack<R, StdTcpClientIm>,
315 socket: TcpStream,
316) -> Result<StdTcpRecvHdl<R>, ClientError> {
317 let (rx, tx) = socket.into_split();
318 let closer = Arc::new(WaitQueue::new());
319 stack.with_interface_manager(|im| {
320 if im.inner.is_some() {
321 return Err(ClientError::SocketAlreadyActive);
322 }
323
324 let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
325 let ctx = q.stream_producer();
326 let crx = q.stream_consumer();
327
328 im.inner = Some(StdTcpClientImInner {
329 interface: StdTcpTxHdl {
330 skt_tx: cobs_stream::Interface {
331 mtu: 1024,
332 prod: ctx,
333 },
334 },
335 net_id: 0,
336 closer: closer.clone(),
337 });
338 tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
340 Ok(())
341 })?;
342 Ok(StdTcpRecvHdl {
343 stack,
344 skt: rx,
345 closer,
346 })
347}
348
349async fn tx_worker(mut tx: OwnedWriteHalf, rx: StreamConsumer<StdQueue>, closer: Arc<WaitQueue>) {
350 info!("Started tx_worker");
351 loop {
352 let rxf = rx.wait_read();
353 let clf = closer.wait();
354
355 let frame = select! {
356 r = rxf => r,
357 _c = clf => {
358 break;
359 }
360 };
361
362 let len = frame.len();
363 info!("sending pkt len:{}", len);
364 let res = tx.write_all(&frame).await;
365 frame.release(len);
366 if let Err(e) = res {
367 error!("Err: {e:?}");
368 break;
369 }
370 }
371 warn!("Closing interface");
373}