ergot_base/interface_manager/
std_tcp_router.rs1use std::sync::Arc;
20use std::{cell::UnsafeCell, mem::MaybeUninit};
21
22use crate::{Header, NetStack, interface_manager::std_utils::ser_frame};
23
24use log::{debug, error, info, trace, warn};
25use maitake_sync::WaitQueue;
26use mutex::ScopedRawMutex;
27use tokio::sync::mpsc::Sender;
28use tokio::{
29 io::{AsyncReadExt, AsyncWriteExt},
30 net::{
31 TcpStream,
32 tcp::{OwnedReadHalf, OwnedWriteHalf},
33 },
34 select,
35 sync::mpsc::{Receiver, channel, error::TrySendError},
36};
37
38use super::{
39 ConstInit, InterfaceManager, InterfaceSendError,
40 std_utils::{
41 OwnedFrame, ReceiverError,
42 acc::{CobsAccumulator, FeedResult},
43 de_frame,
44 },
45};
46
/// Receive-side handle for one TCP interface registered with the router.
///
/// Created by [`register_interface`]; the receive loop is driven by awaiting
/// [`StdTcpRecvHdl::run`].
pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
    /// Stack that decoded frames are forwarded into.
    stack: &'static NetStack<R, StdTcpIm>,
    /// Network id assigned to this interface when it was registered.
    net_id: u16,
    /// Read half of the TCP connection.
    skt: OwnedReadHalf,
    /// Close signal shared with the transmit worker of the same interface;
    /// both sides stop when it is closed.
    closer: Arc<WaitQueue>,
}
59
/// Interface manager that routes frames to TCP interfaces by network id.
///
/// `new` is a `const fn`, so the real state is not built up front: `inner`
/// starts uninitialized and is written on first access.
pub struct StdTcpIm {
    /// `true` once `inner` has been written.
    init: bool,
    /// Lazily created routing state; only valid to read while `init` is `true`.
    inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
}
64
/// Mutable routing state held inside [`StdTcpIm`].
#[derive(Default)]
pub struct StdTcpImInner {
    /// Registered interfaces, kept sorted by `net_id` (the send paths
    /// binary-search this list).
    interfaces: Vec<StdTcpTxHdl>,
    /// Sequence counter offered to `to_headerseq_or_with_seq`; it is bumped
    /// with wrapping arithmetic each time it is consumed.
    seq_no: u16,
    /// Set by a receiver when it shuts down, hinting that `interfaces` may
    /// contain entries whose closer is already closed.
    any_closed: bool,
}
77
/// Errors returned by [`register_interface`].
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    /// Every usable network id is already assigned to an interface.
    OutOfNetIds,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::OutOfNetIds => {
                f.write_str("no free network id is available for a new interface")
            }
        }
    }
}

// Lets callers use `?` into `Box<dyn Error>` / `anyhow`-style error types.
impl std::error::Error for Error {}
82
/// Transmit-side bookkeeping for one registered interface.
struct StdTcpTxHdl {
    /// Network id assigned to the interface.
    net_id: u16,
    /// Queue feeding the interface's `tx_worker` task.
    skt_tx: Sender<OwnedFrame>,
    /// Close signal shared with the receive handle and the `tx_worker`.
    closer: Arc<WaitQueue>,
}
88
89impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
94 pub async fn run(mut self) -> Result<(), ReceiverError> {
95 let res = self.run_inner().await;
96 self.closer.close();
97 self.stack.with_interface_manager(|im| {
99 let inner = im.get_or_init_inner();
100 inner.any_closed = true;
101 });
102 res
103 }
104
105 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
106 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
107 let mut raw_buf = [0u8; 4096];
108
109 loop {
110 let rd = self.skt.read(&mut raw_buf);
111 let close = self.closer.wait();
112
113 let ct = select! {
114 r = rd => {
115 match r {
116 Ok(0) | Err(_) => {
117 warn!("recv run {} closed", self.net_id);
118 return Err(ReceiverError::SocketClosed)
119 },
120 Ok(ct) => ct,
121 }
122 }
123 _c = close => {
124 return Err(ReceiverError::SocketClosed);
125 }
126 };
127
128 let buf = &raw_buf[..ct];
129 let mut window = buf;
130
131 'cobs: while !window.is_empty() {
132 window = match cobs_buf.feed_raw(window) {
133 FeedResult::Consumed => break 'cobs,
134 FeedResult::OverFull(new_wind) => new_wind,
135 FeedResult::DeserError(new_wind) => new_wind,
136 FeedResult::Success { data, remaining } => {
137 if let Some(mut frame) = de_frame(data) {
140 if frame.hdr.src.network_id == 0 {
144 assert_ne!(
145 frame.hdr.src.node_id, 0,
146 "we got a local packet remotely?"
147 );
148 assert_ne!(
149 frame.hdr.src.node_id, 1,
150 "someone is pretending to be us?"
151 );
152
153 frame.hdr.src.network_id = self.net_id;
154 }
155 let hdr = frame.hdr.clone();
161 let hdr: Header = hdr.into();
162
163 let res = match frame.body {
164 Ok(body) => self.stack.send_raw(&hdr, &body),
165 Err(e) => self.stack.send_err(&hdr, e),
166 };
167 match res {
168 Ok(()) => {}
169 Err(e) => {
170 warn!("recv->send error: {e:?}");
172 }
173 }
174 } else {
175 warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
176 }
177
178 remaining
179 }
180 };
181 }
182 }
183 }
184}
185
186impl StdTcpIm {
189 const fn new() -> Self {
190 Self {
191 init: false,
192 inner: UnsafeCell::new(MaybeUninit::uninit()),
193 }
194 }
195
196 pub fn get_nets(&mut self) -> Vec<u16> {
197 let inner = self.get_or_init_inner();
198 inner.interfaces.iter().map(|i| i.net_id).collect()
199 }
200
201 fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
202 let inner = self.inner.get_mut();
203 if self.init {
204 unsafe { inner.assume_init_mut() }
205 } else {
206 let imr = inner.write(StdTcpImInner::default());
207 self.init = true;
208 imr
209 }
210 }
211}
212
213impl InterfaceManager for StdTcpIm {
214 fn send<T: serde::Serialize>(
215 &mut self,
216 hdr: &Header,
217 data: &T,
218 ) -> Result<(), InterfaceSendError> {
219 assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
221
222 let inner = self.get_or_init_inner();
223 let Ok(idx) = inner
226 .interfaces
227 .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
228 else {
229 return Err(InterfaceSendError::NoRouteToDest);
230 };
231
232 let interface = &inner.interfaces[idx];
233 if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
235 return Err(InterfaceSendError::DestinationLocal);
236 }
237
238 let mut hdr = hdr.clone();
241 hdr.decrement_ttl()?;
242
243 if hdr.src.net_node_any() {
246 hdr.src.network_id = interface.net_id;
250 hdr.src.node_id = 1;
251 }
252
253 let res = interface.skt_tx.try_send(OwnedFrame {
254 hdr: hdr.to_headerseq_or_with_seq(|| {
255 let seq_no = inner.seq_no;
256 inner.seq_no = inner.seq_no.wrapping_add(1);
257 seq_no
258 }),
259 body: Ok(postcard::to_stdvec(data).unwrap()),
260 });
261 match res {
262 Ok(()) => Ok(()),
263 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
264 Err(TrySendError::Closed(_)) => {
265 let rem = inner.interfaces.remove(idx);
266 rem.closer.close();
267 Err(InterfaceSendError::NoRouteToDest)
268 }
269 }
270 }
271
272 fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
273 assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
275
276 let inner = self.get_or_init_inner();
277 let Ok(idx) = inner
281 .interfaces
282 .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
283 else {
284 return Err(InterfaceSendError::NoRouteToDest);
285 };
286
287 let interface = &inner.interfaces[idx];
288 if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
290 return Err(InterfaceSendError::DestinationLocal);
291 }
292
293 let mut hdr = hdr.clone();
296 hdr.decrement_ttl()?;
297
298 if hdr.src.net_node_any() {
301 hdr.src.network_id = interface.net_id;
305 hdr.src.node_id = 1;
306 }
307
308 let res = interface.skt_tx.try_send(OwnedFrame {
309 hdr: hdr.to_headerseq_or_with_seq(|| {
310 let seq_no = inner.seq_no;
311 inner.seq_no = inner.seq_no.wrapping_add(1);
312 seq_no
313 }),
314 body: Ok(data.to_vec()),
315 });
316 match res {
317 Ok(()) => Ok(()),
318 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
319 Err(TrySendError::Closed(_)) => {
320 inner.interfaces.remove(idx);
321 Err(InterfaceSendError::NoRouteToDest)
322 }
323 }
324 }
325
326 fn send_err(
327 &mut self,
328 hdr: &Header,
329 err: crate::ProtocolError,
330 ) -> Result<(), InterfaceSendError> {
331 assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
333
334 let inner = self.get_or_init_inner();
335 let Ok(idx) = inner
338 .interfaces
339 .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
340 else {
341 return Err(InterfaceSendError::NoRouteToDest);
342 };
343
344 let interface = &inner.interfaces[idx];
345 if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
347 return Err(InterfaceSendError::DestinationLocal);
348 }
349
350 let mut hdr = hdr.clone();
353 hdr.decrement_ttl()?;
354
355 if hdr.src.net_node_any() {
358 hdr.src.network_id = interface.net_id;
362 hdr.src.node_id = 1;
363 }
364
365 let res = interface.skt_tx.try_send(OwnedFrame {
366 hdr: hdr.to_headerseq_or_with_seq(|| {
367 let seq_no = inner.seq_no;
368 inner.seq_no = inner.seq_no.wrapping_add(1);
369 seq_no
370 }),
371 body: Err(err),
372 });
373 match res {
374 Ok(()) => Ok(()),
375 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
376 Err(TrySendError::Closed(_)) => {
377 let rem = inner.interfaces.remove(idx);
378 rem.closer.close();
379 Err(InterfaceSendError::NoRouteToDest)
380 }
381 }
382 }
383}
384
/// The default manager is the same empty, not-yet-initialized value as `new()`.
impl Default for StdTcpIm {
    fn default() -> Self {
        Self::new()
    }
}
390
/// Provides the constant initial value required by `ConstInit`.
impl ConstInit for StdTcpIm {
    // `StdTcpIm` holds an `UnsafeCell`, which trips this lint on any `const`
    // of the type; the constant is only the starting value produced by `new()`.
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: Self = Self::new();
}
395
// SAFETY: `StdTcpIm` is not automatically `Sync` because it contains an
// `UnsafeCell`. Every accessor in this file takes `&mut self`, so a shared
// reference alone never reaches the cell.
// NOTE(review): presumably the owning `NetStack` serializes access through
// its mutex before handing out `&mut StdTcpIm` — confirm.
unsafe impl Sync for StdTcpIm {}
397
398impl StdTcpImInner {
401 pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
402 let closer = Arc::new(WaitQueue::new());
403 if self.interfaces.is_empty() {
404 let (ctx, crx) = channel(64);
406 let net_id = 1;
407 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
409 self.interfaces.push(StdTcpTxHdl {
410 net_id,
411 skt_tx: ctx,
412 closer: closer.clone(),
413 });
414 debug!("Alloc'd net_id 1");
415 return Some((net_id, closer));
416 } else if self.interfaces.len() >= 65534 {
417 warn!("Out of netids!");
418 return None;
419 }
420
421 if self.any_closed {
423 self.interfaces.retain(|int| {
424 let closed = int.closer.is_closed();
425 if closed {
426 info!("Collecting interface {}", int.net_id);
427 }
428 !closed
429 });
430 }
431
432 let mut net_id = 1;
433 for intfc in self.interfaces.iter() {
436 if intfc.net_id > net_id {
437 trace!("Found gap: {net_id}");
438 break;
439 }
440 debug_assert!(intfc.net_id == net_id);
441 net_id += 1;
442 }
443 debug_assert!(net_id > 0 && net_id != u16::MAX);
447 let (ctx, crx) = channel(64);
448 debug!("allocated net_id {net_id}");
449
450 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
451 self.interfaces.push(StdTcpTxHdl {
452 net_id,
453 skt_tx: ctx,
454 closer: closer.clone(),
455 });
456 self.interfaces.sort_unstable_by_key(|i| i.net_id);
457 Some((net_id, closer))
458 }
459}
460
461async fn tx_worker(
464 net_id: u16,
465 mut tx: OwnedWriteHalf,
466 mut rx: Receiver<OwnedFrame>,
467 closer: Arc<WaitQueue>,
468) {
469 info!("Started tx_worker for net_id {net_id}");
470 loop {
471 let rxf = rx.recv();
472 let clf = closer.wait();
473
474 let frame = select! {
475 r = rxf => {
476 if let Some(frame) = r {
477 frame
478 } else {
479 warn!("tx_worker {net_id} rx closed!");
480 closer.close();
481 break;
482 }
483 }
484 _c = clf => {
485 break;
486 }
487 };
488
489 let msg = ser_frame(frame);
490 debug!("sending pkt len:{} on net_id {net_id}", msg.len());
491 let res = tx.write_all(&msg).await;
492 if let Err(e) = res {
493 error!("Err: {e:?}");
494 break;
495 }
496 }
497 warn!("Closing interface {net_id}");
499}
500
501pub fn register_interface<R: ScopedRawMutex>(
502 stack: &'static NetStack<R, StdTcpIm>,
503 socket: TcpStream,
504) -> Result<StdTcpRecvHdl<R>, Error> {
505 let (rx, tx) = socket.into_split();
506 stack.with_interface_manager(|im| {
507 let inner = im.get_or_init_inner();
508 if let Some((addr, closer)) = inner.alloc_intfc(tx) {
509 Ok(StdTcpRecvHdl {
510 stack,
511 net_id: addr,
512 skt: rx,
513 closer,
514 })
515 } else {
516 Err(Error::OutOfNetIds)
517 }
518 })
519}