ergot_base/interface_manager/
std_tcp_client.rs1use std::sync::Arc;
9
10use crate::{
11 Header, HeaderSeq, NetStack,
12 interface_manager::std_utils::{OwnedFrame, ser_frame},
13};
14
15use super::{
16 ConstInit, InterfaceManager, InterfaceSendError,
17 std_utils::{
18 ReceiverError,
19 acc::{CobsAccumulator, FeedResult},
20 de_frame,
21 },
22};
23use log::{debug, error, info, warn};
24use maitake_sync::WaitQueue;
25use mutex::ScopedRawMutex;
26use tokio::{
27 io::{AsyncReadExt, AsyncWriteExt},
28 net::{
29 TcpStream,
30 tcp::{OwnedReadHalf, OwnedWriteHalf},
31 },
32 select,
33 sync::mpsc::{Receiver, Sender, channel, error::TrySendError},
34};
35
36#[derive(Default)]
37pub struct StdTcpClientIm {
38 inner: Option<StdTcpClientImInner>,
39 seq_no: u16,
40}
41
42struct StdTcpClientImInner {
43 interface: StdTcpTxHdl,
44 net_id: u16,
45 closer: Arc<WaitQueue>,
46}
47
/// Errors returned when registering a TCP client interface.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientError {
    /// The interface manager already has a registered socket; only one
    /// connection may be active at a time.
    SocketAlreadyActive,
}

impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SocketAlreadyActive => {
                f.write_str("a socket is already active on this interface")
            }
        }
    }
}

// Implementing `Error` lets callers propagate this with `?` into
// `Box<dyn Error>` or similar error-erasing types.
impl std::error::Error for ClientError {}
52
53pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
54 stack: &'static NetStack<R, StdTcpClientIm>,
55 skt: OwnedReadHalf,
56 closer: Arc<WaitQueue>,
57}
58
59struct StdTcpTxHdl {
60 skt_tx: Sender<OwnedFrame>,
61}
62
63impl StdTcpClientIm {
66 pub const fn new() -> Self {
67 Self {
68 inner: None,
69 seq_no: 0,
70 }
71 }
72}
73
74impl ConstInit for StdTcpClientIm {
75 #[allow(clippy::declare_interior_mutable_const)]
76 const INIT: Self = Self::new();
77}
78
79impl InterfaceManager for StdTcpClientIm {
80 fn send<T: serde::Serialize>(
81 &mut self,
82 hdr: &Header,
83 data: &T,
84 ) -> Result<(), InterfaceSendError> {
85 let Some(intfc) = self.inner.as_mut() else {
86 return Err(InterfaceSendError::NoRouteToDest);
87 };
88 if intfc.net_id == 0 {
89 return Err(InterfaceSendError::NoRouteToDest);
91 }
92 if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
100 return Err(InterfaceSendError::DestinationLocal);
101 }
102
103 let mut hdr = hdr.clone();
106 hdr.decrement_ttl()?;
107
108 if hdr.src.net_node_any() {
111 hdr.src.network_id = intfc.net_id;
115 hdr.src.node_id = 2;
116 }
117
118 if hdr.dst.port_id == 255 {
121 hdr.dst.network_id = intfc.net_id;
122 hdr.dst.node_id = 1;
123 }
124
125 let seq_no = self.seq_no;
126 self.seq_no = self.seq_no.wrapping_add(1);
127 let res = intfc.interface.skt_tx.try_send(OwnedFrame {
128 hdr: HeaderSeq {
129 src: hdr.src,
130 dst: hdr.dst,
131 seq_no,
132 key: hdr.key,
133 kind: hdr.kind,
134 ttl: hdr.ttl,
135 },
136 body: Ok(postcard::to_stdvec(data).unwrap()),
137 });
138 match res {
139 Ok(()) => Ok(()),
140 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
141 Err(TrySendError::Closed(_)) => {
142 if let Some(i) = self.inner.take() {
143 i.closer.close();
144 }
145 Err(InterfaceSendError::NoRouteToDest)
146 }
147 }
148 }
149
150 fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
151 let Some(intfc) = self.inner.as_mut() else {
152 return Err(InterfaceSendError::NoRouteToDest);
153 };
154 if intfc.net_id == 0 {
155 return Err(InterfaceSendError::NoRouteToDest);
157 }
158 if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
166 return Err(InterfaceSendError::DestinationLocal);
167 }
168
169 let mut hdr = hdr.clone();
172 hdr.decrement_ttl()?;
173
174 if hdr.src.net_node_any() {
177 hdr.src.network_id = intfc.net_id;
181 hdr.src.node_id = 2;
182 }
183
184 if hdr.dst.port_id == 255 {
187 hdr.dst.network_id = intfc.net_id;
188 hdr.dst.node_id = 1;
189 }
190
191 let seq_no = self.seq_no;
192 self.seq_no = self.seq_no.wrapping_add(1);
193 let res = intfc.interface.skt_tx.try_send(OwnedFrame {
194 hdr: HeaderSeq {
195 src: hdr.src,
196 dst: hdr.dst,
197 seq_no,
198 key: hdr.key,
199 kind: hdr.kind,
200 ttl: hdr.ttl,
201 },
202 body: Ok(data.to_vec()),
203 });
204 match res {
205 Ok(()) => Ok(()),
206 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
207 Err(TrySendError::Closed(_)) => {
208 self.inner.take();
209 Err(InterfaceSendError::NoRouteToDest)
210 }
211 }
212 }
213
214 fn send_err(
215 &mut self,
216 hdr: &Header,
217 err: crate::ProtocolError,
218 ) -> Result<(), InterfaceSendError> {
219 let Some(intfc) = self.inner.as_mut() else {
220 return Err(InterfaceSendError::NoRouteToDest);
221 };
222 if intfc.net_id == 0 {
223 return Err(InterfaceSendError::NoRouteToDest);
225 }
226 if hdr.dst.network_id == intfc.net_id && hdr.dst.node_id == 2 {
234 return Err(InterfaceSendError::DestinationLocal);
235 }
236
237 let mut hdr = hdr.clone();
240 hdr.decrement_ttl()?;
241
242 if hdr.src.net_node_any() {
245 hdr.src.network_id = intfc.net_id;
249 hdr.src.node_id = 2;
250 }
251
252 if hdr.dst.port_id == 255 {
255 hdr.dst.network_id = intfc.net_id;
256 hdr.dst.node_id = 1;
257 }
258
259 let seq_no = self.seq_no;
260 self.seq_no = self.seq_no.wrapping_add(1);
261 let res = intfc.interface.skt_tx.try_send(OwnedFrame {
262 hdr: HeaderSeq {
263 src: hdr.src,
264 dst: hdr.dst,
265 seq_no,
266 key: hdr.key,
267 kind: hdr.kind,
268 ttl: hdr.ttl,
269 },
270 body: Err(err),
271 });
272 match res {
273 Ok(()) => Ok(()),
274 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
275 Err(TrySendError::Closed(_)) => {
276 if let Some(i) = self.inner.take() {
277 i.closer.close();
278 }
279 Err(InterfaceSendError::NoRouteToDest)
280 }
281 }
282 }
283}
284
285impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
286 pub async fn run(mut self) -> Result<(), ReceiverError> {
287 let res = self.run_inner().await;
288 self.stack.with_interface_manager(|im| {
290 _ = im.inner.take();
291 });
292 res
293 }
294
295 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
296 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
297 let mut raw_buf = [0u8; 4096];
298 let mut net_id = None;
299
300 loop {
301 let rd = self.skt.read(&mut raw_buf);
302 let close = self.closer.wait();
303
304 let ct = select! {
305 r = rd => {
306 match r {
307 Ok(0) | Err(_) => {
308 warn!("recv run closed");
309 return Err(ReceiverError::SocketClosed)
310 },
311 Ok(ct) => ct,
312 }
313 }
314 _c = close => {
315 return Err(ReceiverError::SocketClosed);
316 }
317 };
318
319 let buf = &raw_buf[..ct];
320 let mut window = buf;
321
322 'cobs: while !window.is_empty() {
323 window = match cobs_buf.feed_raw(window) {
324 FeedResult::Consumed => break 'cobs,
325 FeedResult::OverFull(new_wind) => new_wind,
326 FeedResult::DeserError(new_wind) => new_wind,
327 FeedResult::Success { data, remaining } => {
328 if let Some(mut frame) = de_frame(data) {
331 debug!("Got Frame!");
332 let take_net = net_id.is_none()
333 || net_id.is_some_and(|n| {
334 frame.hdr.dst.network_id != 0 && n != frame.hdr.dst.network_id
335 });
336 if take_net {
337 self.stack.with_interface_manager(|im| {
338 if let Some(i) = im.inner.as_mut() {
339 i.net_id = frame.hdr.dst.network_id;
341 }
342 });
344 net_id = Some(frame.hdr.dst.network_id);
345 }
346
347 if let Some(net) = net_id.as_ref() {
353 if frame.hdr.src.network_id == 0 {
354 assert_ne!(
355 frame.hdr.src.node_id, 0,
356 "we got a local packet remotely?"
357 );
358 assert_ne!(
359 frame.hdr.src.node_id, 2,
360 "someone is pretending to be us?"
361 );
362
363 frame.hdr.src.network_id = *net;
364 }
365 }
366
367 let hdr = frame.hdr.clone();
373 let hdr: Header = hdr.into();
374 let res = match frame.body {
375 Ok(body) => self.stack.send_raw(&hdr, &body),
376 Err(e) => self.stack.send_err(&hdr, e),
377 };
378 match res {
379 Ok(()) => {}
380 Err(e) => {
381 panic!("recv->send error: {e:?}");
383 }
384 }
385 } else {
386 warn!(
387 "Decode error! Ignoring frame on net_id {}",
388 net_id.unwrap_or(0)
389 );
390 }
391
392 remaining
393 }
394 };
395 }
396 }
397 }
398}
399
400pub fn register_interface<R: ScopedRawMutex>(
403 stack: &'static NetStack<R, StdTcpClientIm>,
404 socket: TcpStream,
405) -> Result<StdTcpRecvHdl<R>, ClientError> {
406 let (rx, tx) = socket.into_split();
407 let (ctx, crx) = channel(64);
408 let closer = Arc::new(WaitQueue::new());
409 stack.with_interface_manager(|im| {
410 if im.inner.is_some() {
411 return Err(ClientError::SocketAlreadyActive);
412 }
413
414 im.inner = Some(StdTcpClientImInner {
415 interface: StdTcpTxHdl { skt_tx: ctx },
416 net_id: 0,
417 closer: closer.clone(),
418 });
419 tokio::task::spawn(tx_worker(tx, crx, closer.clone()));
421 Ok(())
422 })?;
423 Ok(StdTcpRecvHdl {
424 stack,
425 skt: rx,
426 closer,
427 })
428}
429
430async fn tx_worker(mut tx: OwnedWriteHalf, mut rx: Receiver<OwnedFrame>, closer: Arc<WaitQueue>) {
431 info!("Started tx_worker");
432 loop {
433 let rxf = rx.recv();
434 let clf = closer.wait();
435
436 let frame = select! {
437 r = rxf => {
438 if let Some(frame) = r {
439 frame
440 } else {
441 warn!("tx_workerrx closed!");
442 closer.close();
443 break;
444 }
445 }
446 _c = clf => {
447 break;
448 }
449 };
450
451 let msg = ser_frame(frame);
452 info!("sending pkt len:{}", msg.len());
453 let res = tx.write_all(&msg).await;
454 if let Err(e) = res {
455 error!("Err: {e:?}");
456 break;
457 }
458 }
459 warn!("Closing interface");
461}