ergot_base/interface_manager/
std_tcp_router.rs1use std::sync::Arc;
20use std::{cell::UnsafeCell, mem::MaybeUninit};
21
22use crate::{
23 Header, NetStack,
24 interface_manager::{
25 ConstInit, InterfaceManager, InterfaceSendError,
26 cobs_stream::{self, Interface},
27 std_utils::{
28 ReceiverError, StdQueue,
29 acc::{CobsAccumulator, FeedResult},
30 },
31 wire_frames::{CommonHeader, de_frame},
32 },
33};
34
35use bbq2::prod_cons::stream::StreamConsumer;
36use bbq2::traits::storage::BoxedSlice;
37use log::{debug, error, info, trace, warn};
38use maitake_sync::WaitQueue;
39use mutex::ScopedRawMutex;
40use tokio::{
41 io::{AsyncReadExt, AsyncWriteExt},
42 net::{
43 TcpStream,
44 tcp::{OwnedReadHalf, OwnedWriteHalf},
45 },
46 select,
47};
48
49pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
50 stack: &'static NetStack<R, StdTcpIm>,
51 net_id: u16,
58 skt: OwnedReadHalf,
59 closer: Arc<WaitQueue>,
60}
61
62pub struct StdTcpIm {
63 init: bool,
64 inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
65}
66
67#[derive(Default)]
68pub struct StdTcpImInner {
69 interfaces: Vec<StdTcpTxHdl>,
76 seq_no: u16,
77 any_closed: bool,
78}
79
/// Errors returned by `register_interface`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// No unused network ID is left to assign to a new interface.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("no free network IDs left for a new TCP interface"),
        }
    }
}

impl std::error::Error for Error {}
84
85struct StdTcpTxHdl {
86 net_id: u16,
87 skt_tx: Interface<StdQueue>,
88 closer: Arc<WaitQueue>,
89}
90
91impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
96 pub async fn run(mut self) -> Result<(), ReceiverError> {
97 let res = self.run_inner().await;
98 self.closer.close();
99 self.stack.with_interface_manager(|im| {
101 let inner = im.get_or_init_inner();
102 inner.any_closed = true;
103 });
104 res
105 }
106
107 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
108 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
109 let mut raw_buf = [0u8; 4096];
110
111 loop {
112 let rd = self.skt.read(&mut raw_buf);
113 let close = self.closer.wait();
114
115 let ct = select! {
116 r = rd => {
117 match r {
118 Ok(0) | Err(_) => {
119 warn!("recv run {} closed", self.net_id);
120 return Err(ReceiverError::SocketClosed)
121 },
122 Ok(ct) => ct,
123 }
124 }
125 _c = close => {
126 return Err(ReceiverError::SocketClosed);
127 }
128 };
129
130 let buf = &raw_buf[..ct];
131 let mut window = buf;
132
133 'cobs: while !window.is_empty() {
134 window = match cobs_buf.feed_raw(window) {
135 FeedResult::Consumed => break 'cobs,
136 FeedResult::OverFull(new_wind) => new_wind,
137 FeedResult::DeserError(new_wind) => new_wind,
138 FeedResult::Success { data, remaining } => {
139 if let Some(mut frame) = de_frame(data) {
142 if frame.hdr.src.network_id == 0 {
146 assert_ne!(
147 frame.hdr.src.node_id, 0,
148 "we got a local packet remotely?"
149 );
150 assert_ne!(
151 frame.hdr.src.node_id, 1,
152 "someone is pretending to be us?"
153 );
154
155 frame.hdr.src.network_id = self.net_id;
156 }
157 let hdr = frame.hdr.clone();
163 let hdr: Header = hdr.into();
164
165 let res = match frame.body {
166 Ok(body) => self.stack.send_raw(&hdr, frame.hdr_raw, body),
167 Err(e) => self.stack.send_err(&hdr, e),
168 };
169 match res {
170 Ok(()) => {}
171 Err(e) => {
172 warn!("recv->send error: {e:?}");
174 }
175 }
176 } else {
177 warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
178 }
179
180 remaining
181 }
182 };
183 }
184 }
185 }
186}
187
188impl StdTcpIm {
191 const fn new() -> Self {
192 Self {
193 init: false,
194 inner: UnsafeCell::new(MaybeUninit::uninit()),
195 }
196 }
197
198 pub fn get_nets(&mut self) -> Vec<u16> {
199 let inner = self.get_or_init_inner();
200 inner.interfaces.iter().map(|i| i.net_id).collect()
201 }
202
203 fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
204 let inner = self.inner.get_mut();
205 if self.init {
206 unsafe { inner.assume_init_mut() }
207 } else {
208 let imr = inner.write(StdTcpImInner::default());
209 self.init = true;
210 imr
211 }
212 }
213}
214
215impl StdTcpIm {
216 fn common_send<'a, 'b>(
217 &'b mut self,
218 ihdr: &'a Header,
219 ) -> Result<(&'b mut StdTcpTxHdl, CommonHeader), InterfaceSendError> {
220 assert!(!(ihdr.dst.port_id == 0 && ihdr.any_all.is_none()));
222
223 let inner = self.get_or_init_inner();
224 let Ok(idx) = inner
228 .interfaces
229 .binary_search_by_key(&ihdr.dst.network_id, |int| int.net_id)
230 else {
231 return Err(InterfaceSendError::NoRouteToDest);
232 };
233
234 let interface = &mut inner.interfaces[idx];
235 if ihdr.dst.network_id == interface.net_id && ihdr.dst.node_id == 1 {
237 return Err(InterfaceSendError::DestinationLocal);
238 }
239
240 let mut hdr = ihdr.clone();
243 hdr.decrement_ttl()?;
244
245 if hdr.src.net_node_any() {
248 hdr.src.network_id = interface.net_id;
252 hdr.src.node_id = 1;
253 }
254
255 let seq_no = inner.seq_no;
256 inner.seq_no = inner.seq_no.wrapping_add(1);
257
258 let header = CommonHeader {
259 src: hdr.src.as_u32(),
260 dst: hdr.dst.as_u32(),
261 seq_no,
262 kind: hdr.kind.0,
263 ttl: hdr.ttl,
264 };
265 if [0, 255].contains(&hdr.dst.port_id) {
266 if ihdr.any_all.is_none() {
267 return Err(InterfaceSendError::AnyPortMissingKey);
268 }
269 }
270
271 Ok((interface, header))
272 }
273}
274
275impl InterfaceManager for StdTcpIm {
276 fn send<T: serde::Serialize>(
277 &mut self,
278 hdr: &Header,
279 data: &T,
280 ) -> Result<(), InterfaceSendError> {
281 let (intfc, header) = self.common_send(hdr)?;
282 let res = intfc.skt_tx.send_ty(&header, hdr.any_all.as_ref(), data);
283
284 match res {
285 Ok(()) => Ok(()),
286 Err(()) => Err(InterfaceSendError::InterfaceFull),
287 }
288 }
289
290 fn send_raw(
291 &mut self,
292 hdr: &Header,
293 hdr_raw: &[u8],
294 data: &[u8],
295 ) -> Result<(), InterfaceSendError> {
296 let (intfc, header) = self.common_send(hdr)?;
297 let res = intfc.skt_tx.send_raw(&header, hdr_raw, data);
298
299 match res {
300 Ok(()) => Ok(()),
301 Err(()) => Err(InterfaceSendError::InterfaceFull),
302 }
303 }
304
305 fn send_err(
306 &mut self,
307 hdr: &Header,
308 err: crate::ProtocolError,
309 ) -> Result<(), InterfaceSendError> {
310 let (intfc, header) = self.common_send(hdr)?;
311 let res = intfc.skt_tx.send_err(&header, err);
312
313 match res {
314 Ok(()) => Ok(()),
315 Err(()) => Err(InterfaceSendError::InterfaceFull),
316 }
317 }
318}
319
320impl Default for StdTcpIm {
321 fn default() -> Self {
322 Self::new()
323 }
324}
325
326impl ConstInit for StdTcpIm {
327 #[allow(clippy::declare_interior_mutable_const)]
328 const INIT: Self = Self::new();
329}
330
// SAFETY: `StdTcpIm` contains an `UnsafeCell`, so it is not `Sync` automatically. In
// this file it is only ever accessed through `&mut self` (via
// `NetStack::with_interface_manager`), never through a shared reference.
// NOTE(review): soundness relies on `NetStack` serializing every access to the manager
// (presumably with its `R: ScopedRawMutex`), and on the contained interface handles
// being safe to use from whichever thread holds that lock. Confirm against `NetStack`.
unsafe impl Sync for StdTcpIm {}
332
333impl StdTcpImInner {
336 pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
337 let closer = Arc::new(WaitQueue::new());
338 if self.interfaces.is_empty() {
339 let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
341 let ctx = q.stream_producer();
342 let crx = q.stream_consumer();
343
344 let ctx = cobs_stream::Interface {
345 mtu: 1024,
346 prod: ctx,
347 };
348
349 let net_id = 1;
350 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
352 self.interfaces.push(StdTcpTxHdl {
353 net_id,
354 skt_tx: ctx,
355 closer: closer.clone(),
356 });
357 debug!("Alloc'd net_id 1");
358 return Some((net_id, closer));
359 } else if self.interfaces.len() >= 65534 {
360 warn!("Out of netids!");
361 return None;
362 }
363
364 if self.any_closed {
366 self.interfaces.retain(|int| {
367 let closed = int.closer.is_closed();
368 if closed {
369 info!("Collecting interface {}", int.net_id);
370 }
371 !closed
372 });
373 }
374
375 let mut net_id = 1;
376 for intfc in self.interfaces.iter() {
379 if intfc.net_id > net_id {
380 trace!("Found gap: {net_id}");
381 break;
382 }
383 debug_assert!(intfc.net_id == net_id);
384 net_id += 1;
385 }
386 debug_assert!(net_id > 0 && net_id != u16::MAX);
390
391 let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
392 let ctx = q.stream_producer();
393 let crx = q.stream_consumer();
394
395 let ctx = cobs_stream::Interface {
396 mtu: 1024,
397 prod: ctx,
398 };
399
400 debug!("allocated net_id {net_id}");
401
402 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
403 self.interfaces.push(StdTcpTxHdl {
404 net_id,
405 skt_tx: ctx,
406 closer: closer.clone(),
407 });
408 self.interfaces.sort_unstable_by_key(|i| i.net_id);
409 Some((net_id, closer))
410 }
411}
412
413async fn tx_worker(
416 net_id: u16,
417 mut tx: OwnedWriteHalf,
418 rx: StreamConsumer<StdQueue>,
419 closer: Arc<WaitQueue>,
420) {
421 info!("Started tx_worker for net_id {net_id}");
422 loop {
423 let rxf = rx.wait_read();
424 let clf = closer.wait();
425
426 let frame = select! {
427 r = rxf => r,
428 _c = clf => {
429 break;
430 }
431 };
432
433 let len = frame.len();
434 debug!("sending pkt len:{} on net_id {net_id}", len);
435 let res = tx.write_all(&frame).await;
436 frame.release(len);
437 if let Err(e) = res {
438 error!("Err: {e:?}");
439 break;
440 }
441 }
442 warn!("Closing interface {net_id}");
444}
445
446pub fn register_interface<R: ScopedRawMutex>(
447 stack: &'static NetStack<R, StdTcpIm>,
448 socket: TcpStream,
449) -> Result<StdTcpRecvHdl<R>, Error> {
450 let (rx, tx) = socket.into_split();
451 stack.with_interface_manager(|im| {
452 let inner = im.get_or_init_inner();
453 if let Some((addr, closer)) = inner.alloc_intfc(tx) {
454 Ok(StdTcpRecvHdl {
455 stack,
456 net_id: addr,
457 skt: rx,
458 closer,
459 })
460 } else {
461 Err(Error::OutOfNetIds)
462 }
463 })
464}