ergot_base/interface_manager/
std_tcp_router.rs1use std::sync::Arc;
20use std::{cell::UnsafeCell, mem::MaybeUninit};
21
22use crate::{
23 Header, Key, NetStack,
24 interface_manager::{
25 ConstInit, InterfaceManager, InterfaceSendError,
26 cobs_stream::{self, Interface},
27 std_utils::{
28 ReceiverError, StdQueue,
29 acc::{CobsAccumulator, FeedResult},
30 },
31 wire_frames::{CommonHeader, de_frame},
32 },
33};
34
35use bbq2::prod_cons::stream::StreamConsumer;
36use bbq2::traits::storage::BoxedSlice;
37use log::{debug, error, info, trace, warn};
38use maitake_sync::WaitQueue;
39use mutex::ScopedRawMutex;
40use tokio::{
41 io::{AsyncReadExt, AsyncWriteExt},
42 net::{
43 TcpStream,
44 tcp::{OwnedReadHalf, OwnedWriteHalf},
45 },
46 select,
47};
48
/// Receive half of a TCP interface, returned by [`register_interface`].
///
/// Call [`StdTcpRecvHdl::run`] to pump frames from the socket into the net stack.
pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
    /// Net stack that decoded frames are forwarded into.
    stack: &'static NetStack<R, StdTcpIm>,
    /// Network ID assigned to this interface by `alloc_intfc`; written into the
    /// source address of incoming frames whose source network_id is 0.
    net_id: u16,
    /// Read half of the TCP socket.
    skt: OwnedReadHalf,
    /// Shared with this interface's tx worker and routing-table entry; closing it
    /// stops both the receive loop and the tx worker.
    closer: Arc<WaitQueue>,
}
61
/// Interface manager that routes between a set of TCP connections, each of which
/// is assigned its own network ID.
pub struct StdTcpIm {
    /// Whether `inner` has been written yet.
    init: bool,
    /// Lazily-initialized routing state; only valid once `init` is true.
    inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
}
66
/// Mutable routing state behind [`StdTcpIm`].
#[derive(Default)]
pub struct StdTcpImInner {
    /// One entry per registered interface. Kept sorted by `net_id` so routes can
    /// be found with `binary_search_by_key`.
    interfaces: Vec<StdTcpTxHdl>,
    /// Sequence number stamped on the next outgoing frame; wraps on overflow.
    seq_no: u16,
    /// Set when a receive loop exits; tells `alloc_intfc` to sweep closed
    /// interfaces out of `interfaces`.
    any_closed: bool,
}
79
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Every usable network ID is currently assigned to a live interface.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("no free network IDs available for a new interface"),
        }
    }
}

impl std::error::Error for Error {}
84
/// Routing-table entry for one TCP interface: where its outgoing frames are queued.
struct StdTcpTxHdl {
    /// Network ID assigned to this interface.
    net_id: u16,
    /// Producer side of the queue drained by this interface's `tx_worker`.
    skt_tx: Interface<StdQueue>,
    /// Shared shutdown signal; `alloc_intfc` collects the entry once `is_closed()`.
    closer: Arc<WaitQueue>,
}
90
91impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
96 pub async fn run(mut self) -> Result<(), ReceiverError> {
97 let res = self.run_inner().await;
98 self.closer.close();
99 self.stack.with_interface_manager(|im| {
101 let inner = im.get_or_init_inner();
102 inner.any_closed = true;
103 });
104 res
105 }
106
107 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
108 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
109 let mut raw_buf = [0u8; 4096];
110
111 loop {
112 let rd = self.skt.read(&mut raw_buf);
113 let close = self.closer.wait();
114
115 let ct = select! {
116 r = rd => {
117 match r {
118 Ok(0) | Err(_) => {
119 warn!("recv run {} closed", self.net_id);
120 return Err(ReceiverError::SocketClosed)
121 },
122 Ok(ct) => ct,
123 }
124 }
125 _c = close => {
126 return Err(ReceiverError::SocketClosed);
127 }
128 };
129
130 let buf = &raw_buf[..ct];
131 let mut window = buf;
132
133 'cobs: while !window.is_empty() {
134 window = match cobs_buf.feed_raw(window) {
135 FeedResult::Consumed => break 'cobs,
136 FeedResult::OverFull(new_wind) => new_wind,
137 FeedResult::DeserError(new_wind) => new_wind,
138 FeedResult::Success { data, remaining } => {
139 if let Some(mut frame) = de_frame(data) {
142 if frame.hdr.src.network_id == 0 {
146 assert_ne!(
147 frame.hdr.src.node_id, 0,
148 "we got a local packet remotely?"
149 );
150 assert_ne!(
151 frame.hdr.src.node_id, 1,
152 "someone is pretending to be us?"
153 );
154
155 frame.hdr.src.network_id = self.net_id;
156 }
157 let hdr = frame.hdr.clone();
163 let hdr: Header = hdr.into();
164
165 let res = match frame.body {
166 Ok(body) => self.stack.send_raw(&hdr, body),
167 Err(e) => self.stack.send_err(&hdr, e),
168 };
169 match res {
170 Ok(()) => {}
171 Err(e) => {
172 warn!("recv->send error: {e:?}");
174 }
175 }
176 } else {
177 warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
178 }
179
180 remaining
181 }
182 };
183 }
184 }
185 }
186}
187
188impl StdTcpIm {
191 const fn new() -> Self {
192 Self {
193 init: false,
194 inner: UnsafeCell::new(MaybeUninit::uninit()),
195 }
196 }
197
198 pub fn get_nets(&mut self) -> Vec<u16> {
199 let inner = self.get_or_init_inner();
200 inner.interfaces.iter().map(|i| i.net_id).collect()
201 }
202
203 fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
204 let inner = self.inner.get_mut();
205 if self.init {
206 unsafe { inner.assume_init_mut() }
207 } else {
208 let imr = inner.write(StdTcpImInner::default());
209 self.init = true;
210 imr
211 }
212 }
213}
214
215impl StdTcpIm {
216 fn common_send<'a, 'b>(
217 &'b mut self,
218 ihdr: &'a Header,
219 ) -> Result<(&'b mut StdTcpTxHdl, CommonHeader, Option<&'a Key>), InterfaceSendError> {
220 assert!(!(ihdr.dst.port_id == 0 && ihdr.key.is_none()));
222
223 let inner = self.get_or_init_inner();
224 let Ok(idx) = inner
228 .interfaces
229 .binary_search_by_key(&ihdr.dst.network_id, |int| int.net_id)
230 else {
231 return Err(InterfaceSendError::NoRouteToDest);
232 };
233
234 let interface = &mut inner.interfaces[idx];
235 if ihdr.dst.network_id == interface.net_id && ihdr.dst.node_id == 1 {
237 return Err(InterfaceSendError::DestinationLocal);
238 }
239
240 let mut hdr = ihdr.clone();
243 hdr.decrement_ttl()?;
244
245 if hdr.src.net_node_any() {
248 hdr.src.network_id = interface.net_id;
252 hdr.src.node_id = 1;
253 }
254
255 let seq_no = inner.seq_no;
256 inner.seq_no = inner.seq_no.wrapping_add(1);
257
258 let header = CommonHeader {
259 src: hdr.src.as_u32(),
260 dst: hdr.dst.as_u32(),
261 seq_no,
262 kind: hdr.kind.0,
263 ttl: hdr.ttl,
264 };
265 let key = if [0, 255].contains(&hdr.dst.port_id) {
266 Some(ihdr.key.as_ref().unwrap())
267 } else {
268 None
269 };
270
271 Ok((interface, header, key))
272 }
273}
274
275impl InterfaceManager for StdTcpIm {
276 fn send<T: serde::Serialize>(
277 &mut self,
278 hdr: &Header,
279 data: &T,
280 ) -> Result<(), InterfaceSendError> {
281 let (intfc, header, key) = self.common_send(hdr)?;
282 let res = intfc.skt_tx.send_ty(&header, key, data);
283
284 match res {
285 Ok(()) => Ok(()),
286 Err(()) => Err(InterfaceSendError::InterfaceFull),
287 }
288 }
289
290 fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
291 let (intfc, header, key) = self.common_send(hdr)?;
292 let res = intfc.skt_tx.send_raw(&header, key, data);
293
294 match res {
295 Ok(()) => Ok(()),
296 Err(()) => Err(InterfaceSendError::InterfaceFull),
297 }
298 }
299
300 fn send_err(
301 &mut self,
302 hdr: &Header,
303 err: crate::ProtocolError,
304 ) -> Result<(), InterfaceSendError> {
305 let (intfc, header, _key) = self.common_send(hdr)?;
306 let res = intfc.skt_tx.send_err(&header, err);
307
308 match res {
309 Ok(()) => Ok(()),
310 Err(()) => Err(InterfaceSendError::InterfaceFull),
311 }
312 }
313}
314
315impl Default for StdTcpIm {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
321impl ConstInit for StdTcpIm {
322 #[allow(clippy::declare_interior_mutable_const)]
323 const INIT: Self = Self::new();
324}
325
// SAFETY: the only interior-mutable field is `inner`. In this file it is only
// accessed through `UnsafeCell::get_mut`, i.e. through `&mut self`, so a shared
// `&StdTcpIm` grants no access to it.
// NOTE(review): re-check this if any `&self` method ever touches `inner`.
unsafe impl Sync for StdTcpIm {}
327
328impl StdTcpImInner {
331 pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
332 let closer = Arc::new(WaitQueue::new());
333 if self.interfaces.is_empty() {
334 let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
336 let ctx = q.stream_producer();
337 let crx = q.stream_consumer();
338
339 let ctx = cobs_stream::Interface {
340 mtu: 1024,
341 prod: ctx,
342 };
343
344 let net_id = 1;
345 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
347 self.interfaces.push(StdTcpTxHdl {
348 net_id,
349 skt_tx: ctx,
350 closer: closer.clone(),
351 });
352 debug!("Alloc'd net_id 1");
353 return Some((net_id, closer));
354 } else if self.interfaces.len() >= 65534 {
355 warn!("Out of netids!");
356 return None;
357 }
358
359 if self.any_closed {
361 self.interfaces.retain(|int| {
362 let closed = int.closer.is_closed();
363 if closed {
364 info!("Collecting interface {}", int.net_id);
365 }
366 !closed
367 });
368 }
369
370 let mut net_id = 1;
371 for intfc in self.interfaces.iter() {
374 if intfc.net_id > net_id {
375 trace!("Found gap: {net_id}");
376 break;
377 }
378 debug_assert!(intfc.net_id == net_id);
379 net_id += 1;
380 }
381 debug_assert!(net_id > 0 && net_id != u16::MAX);
385
386 let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
387 let ctx = q.stream_producer();
388 let crx = q.stream_consumer();
389
390 let ctx = cobs_stream::Interface {
391 mtu: 1024,
392 prod: ctx,
393 };
394
395 debug!("allocated net_id {net_id}");
396
397 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
398 self.interfaces.push(StdTcpTxHdl {
399 net_id,
400 skt_tx: ctx,
401 closer: closer.clone(),
402 });
403 self.interfaces.sort_unstable_by_key(|i| i.net_id);
404 Some((net_id, closer))
405 }
406}
407
408async fn tx_worker(
411 net_id: u16,
412 mut tx: OwnedWriteHalf,
413 rx: StreamConsumer<StdQueue>,
414 closer: Arc<WaitQueue>,
415) {
416 info!("Started tx_worker for net_id {net_id}");
417 loop {
418 let rxf = rx.wait_read();
419 let clf = closer.wait();
420
421 let frame = select! {
422 r = rxf => r,
423 _c = clf => {
424 break;
425 }
426 };
427
428 let len = frame.len();
429 debug!("sending pkt len:{} on net_id {net_id}", len);
430 let res = tx.write_all(&frame).await;
431 frame.release(len);
432 if let Err(e) = res {
433 error!("Err: {e:?}");
434 break;
435 }
436 }
437 warn!("Closing interface {net_id}");
439}
440
441pub fn register_interface<R: ScopedRawMutex>(
442 stack: &'static NetStack<R, StdTcpIm>,
443 socket: TcpStream,
444) -> Result<StdTcpRecvHdl<R>, Error> {
445 let (rx, tx) = socket.into_split();
446 stack.with_interface_manager(|im| {
447 let inner = im.get_or_init_inner();
448 if let Some((addr, closer)) = inner.alloc_intfc(tx) {
449 Ok(StdTcpRecvHdl {
450 stack,
451 net_id: addr,
452 skt: rx,
453 closer,
454 })
455 } else {
456 Err(Error::OutOfNetIds)
457 }
458 })
459}