ergot_base/interface_manager/
std_tcp_router.rs1use std::sync::Arc;
20use std::{cell::UnsafeCell, mem::MaybeUninit};
21
22use crate::{Header, NetStack, interface_manager::std_utils::ser_frame};
23
24use maitake_sync::WaitQueue;
25use mutex::ScopedRawMutex;
26use tokio::sync::mpsc::Sender;
27use tokio::{
28 io::{AsyncReadExt, AsyncWriteExt},
29 net::{
30 TcpStream,
31 tcp::{OwnedReadHalf, OwnedWriteHalf},
32 },
33 select,
34 sync::mpsc::{Receiver, channel, error::TrySendError},
35};
36
37use super::{
38 ConstInit, InterfaceManager, InterfaceSendError,
39 std_utils::{
40 OwnedFrame, ReceiverError,
41 acc::{CobsAccumulator, FeedResult},
42 de_frame,
43 },
44};
45
46pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
47 stack: &'static NetStack<R, StdTcpIm>,
48 net_id: u16,
55 skt: OwnedReadHalf,
56 closer: Arc<WaitQueue>,
57}
58
59pub struct StdTcpIm {
60 init: bool,
61 inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
62}
63
64#[derive(Default)]
65pub struct StdTcpImInner {
66 interfaces: Vec<StdTcpTxHdl>,
73 seq_no: u16,
74 any_closed: bool,
75}
76
/// Errors returned when registering a new TCP interface.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Every usable network id (1..=65534) is already assigned to an interface.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("no free network ids left for a new interface"),
        }
    }
}

impl std::error::Error for Error {}
81
82struct StdTcpTxHdl {
83 net_id: u16,
84 skt_tx: Sender<OwnedFrame>,
85 closer: Arc<WaitQueue>,
86}
87
88impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
93 pub async fn run(mut self) -> Result<(), ReceiverError> {
94 let res = self.run_inner().await;
95 self.closer.close();
96 self.stack.with_interface_manager(|im| {
98 let inner = im.get_or_init_inner();
99 inner.any_closed = true;
100 });
101 res
102 }
103
104 pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
105 let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
106 let mut raw_buf = [0u8; 4096];
107
108 loop {
109 let rd = self.skt.read(&mut raw_buf);
110 let close = self.closer.wait();
111
112 let ct = select! {
113 r = rd => {
114 match r {
115 Ok(0) | Err(_) => {
116 println!("recv run {} closed", self.net_id);
117 return Err(ReceiverError::SocketClosed)
118 },
119 Ok(ct) => ct,
120 }
121 }
122 _c = close => {
123 return Err(ReceiverError::SocketClosed);
124 }
125 };
126
127 let buf = &raw_buf[..ct];
128 let mut window = buf;
129
130 'cobs: while !window.is_empty() {
131 window = match cobs_buf.feed_raw(window) {
132 FeedResult::Consumed => break 'cobs,
133 FeedResult::OverFull(new_wind) => new_wind,
134 FeedResult::DeserError(new_wind) => new_wind,
135 FeedResult::Success { data, remaining } => {
136 if let Some(mut frame) = de_frame(data) {
139 if frame.hdr.src.network_id == 0 {
143 assert_ne!(
144 frame.hdr.src.node_id, 0,
145 "we got a local packet remotely?"
146 );
147 assert_ne!(
148 frame.hdr.src.node_id, 1,
149 "someone is pretending to be us?"
150 );
151
152 frame.hdr.src.network_id = self.net_id;
153 }
154 let res = self.stack.send_raw(frame.hdr.into(), &frame.body);
160 match res {
161 Ok(()) => {}
162 Err(e) => {
163 panic!("recv->send error: {e:?}");
165 }
166 }
167 } else {
168 println!("Decode error! Ignoring frame on net_id {}", self.net_id);
169 }
170
171 remaining
172 }
173 };
174 }
175 }
176 }
177}
178
179impl StdTcpIm {
182 const fn new() -> Self {
183 Self {
184 init: false,
185 inner: UnsafeCell::new(MaybeUninit::uninit()),
186 }
187 }
188
189 pub fn get_nets(&mut self) -> Vec<u16> {
190 let inner = self.get_or_init_inner();
191 inner.interfaces.iter().map(|i| i.net_id).collect()
192 }
193
194 fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
195 let inner = self.inner.get_mut();
196 if self.init {
197 unsafe { inner.assume_init_mut() }
198 } else {
199 let imr = inner.write(StdTcpImInner::default());
200 self.init = true;
201 imr
202 }
203 }
204}
205
206impl InterfaceManager for StdTcpIm {
207 fn send<T: serde::Serialize>(
208 &mut self,
209 mut hdr: Header,
210 data: &T,
211 ) -> Result<(), InterfaceSendError> {
212 assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
214
215 let inner = self.get_or_init_inner();
216 let Ok(idx) = inner
219 .interfaces
220 .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
221 else {
222 return Err(InterfaceSendError::NoRouteToDest);
223 };
224
225 let interface = &inner.interfaces[idx];
226 if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
228 return Err(InterfaceSendError::DestinationLocal);
229 }
230 if hdr.src.net_node_any() {
233 hdr.src.network_id = interface.net_id;
237 hdr.src.node_id = 1;
238 }
239
240 let res = interface.skt_tx.try_send(OwnedFrame {
241 hdr: hdr.to_headerseq_or_with_seq(|| {
242 let seq_no = inner.seq_no;
243 inner.seq_no = inner.seq_no.wrapping_add(1);
244 seq_no
245 }),
246 body: postcard::to_stdvec(data).unwrap(),
247 });
248 match res {
249 Ok(()) => Ok(()),
250 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
251 Err(TrySendError::Closed(_)) => {
252 let rem = inner.interfaces.remove(idx);
253 rem.closer.close();
254 Err(InterfaceSendError::NoRouteToDest)
255 }
256 }
257 }
258
259 fn send_raw(&mut self, mut hdr: Header, data: &[u8]) -> Result<(), InterfaceSendError> {
260 assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
262
263 let inner = self.get_or_init_inner();
264 let Ok(idx) = inner
268 .interfaces
269 .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
270 else {
271 return Err(InterfaceSendError::NoRouteToDest);
272 };
273
274 let interface = &inner.interfaces[idx];
275 if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
277 return Err(InterfaceSendError::DestinationLocal);
278 }
279 if hdr.src.net_node_any() {
282 hdr.src.network_id = interface.net_id;
286 hdr.src.node_id = 1;
287 }
288
289 let res = interface.skt_tx.try_send(OwnedFrame {
290 hdr: hdr.to_headerseq_or_with_seq(|| {
291 let seq_no = inner.seq_no;
292 inner.seq_no = inner.seq_no.wrapping_add(1);
293 seq_no
294 }),
295 body: data.to_vec(),
296 });
297 match res {
298 Ok(()) => Ok(()),
299 Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
300 Err(TrySendError::Closed(_)) => {
301 inner.interfaces.remove(idx);
302 Err(InterfaceSendError::NoRouteToDest)
303 }
304 }
305 }
306}
307
308impl Default for StdTcpIm {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
314impl ConstInit for StdTcpIm {
315 #[allow(clippy::declare_interior_mutable_const)]
316 const INIT: Self = Self::new();
317}
318
319unsafe impl Sync for StdTcpIm {}
320
321impl StdTcpImInner {
324 pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
325 let closer = Arc::new(WaitQueue::new());
326 if self.interfaces.is_empty() {
327 let (ctx, crx) = channel(64);
329 let net_id = 1;
330 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
332 self.interfaces.push(StdTcpTxHdl {
333 net_id,
334 skt_tx: ctx,
335 closer: closer.clone(),
336 });
337 println!("Alloc'd net_id 1");
338 return Some((net_id, closer));
339 } else if self.interfaces.len() >= 65534 {
340 println!("Out of netids!");
341 return None;
342 }
343
344 if self.any_closed {
346 self.interfaces.retain(|int| {
347 let closed = int.closer.is_closed();
348 if closed {
349 println!("Collecting interface {}", int.net_id);
350 }
351 !closed
352 });
353 }
354
355 let mut net_id = 1;
356 for intfc in self.interfaces.iter() {
359 if intfc.net_id > net_id {
360 println!("Found gap: {net_id}");
361 break;
362 }
363 debug_assert!(intfc.net_id == net_id);
364 net_id += 1;
365 }
366 debug_assert!(net_id > 0 && net_id != u16::MAX);
370 let (ctx, crx) = channel(64);
371 println!("allocated net_id {net_id}");
372
373 tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
374 self.interfaces.push(StdTcpTxHdl {
375 net_id,
376 skt_tx: ctx,
377 closer: closer.clone(),
378 });
379 self.interfaces.sort_unstable_by_key(|i| i.net_id);
380 Some((net_id, closer))
381 }
382}
383
384async fn tx_worker(
387 net_id: u16,
388 mut tx: OwnedWriteHalf,
389 mut rx: Receiver<OwnedFrame>,
390 closer: Arc<WaitQueue>,
391) {
392 println!("Started tx_worker for net_id {net_id}");
393 loop {
394 let rxf = rx.recv();
395 let clf = closer.wait();
396
397 let frame = select! {
398 r = rxf => {
399 if let Some(frame) = r {
400 frame
401 } else {
402 println!("tx_worker {net_id} rx closed!");
403 closer.close();
404 break;
405 }
406 }
407 _c = clf => {
408 break;
409 }
410 };
411
412 let msg = ser_frame(frame);
413 println!("sending pkt len:{} on net_id {net_id}", msg.len());
414 let res = tx.write_all(&msg).await;
415 if let Err(e) = res {
416 println!("Err: {e:?}");
417 break;
418 }
419 }
420 println!("Closing interface {net_id}");
422}
423
424pub fn register_interface<R: ScopedRawMutex>(
425 stack: &'static NetStack<R, StdTcpIm>,
426 socket: TcpStream,
427) -> Result<StdTcpRecvHdl<R>, Error> {
428 let (rx, tx) = socket.into_split();
429 stack.with_interface_manager(|im| {
430 let inner = im.get_or_init_inner();
431 if let Some((addr, closer)) = inner.alloc_intfc(tx) {
432 Ok(StdTcpRecvHdl {
433 stack,
434 net_id: addr,
435 skt: rx,
436 closer,
437 })
438 } else {
439 Err(Error::OutOfNetIds)
440 }
441 })
442}