ergot_base/interface_manager/
std_tcp_router.rs

1/*
2    Let's see, we're going to need:
3
4    * Some kind of hashmap/vec of active interfaces, by network id?
5        * IF we use a vec, we should NOT use the index as the ID, it may be sparse
6    * The actual interface type probably gets defined by the interface manager
7    * The interface which follows the pinned rules, and removes itself on drop
8    * THIS version of the routing interface probably will not allow for other routers,
9        we probably have to assume we are the only one assigning network IDs until a
10        later point
11    * The associated "simple" version of a client probably needs a stub routing interface
12        that picks up the network ID from the destination address
13    * Honestly we might want to have an `Arc` version of the netstack, or we need some kind
14        of Once construction.
15    * The interface manager needs some kind of "handle" construction so that we can get mut
16        access to it, or we need an accessor via the netstack
17*/
18
19use std::sync::Arc;
20use std::{cell::UnsafeCell, mem::MaybeUninit};
21
22use crate::{Header, NetStack, interface_manager::std_utils::ser_frame};
23
24use log::{debug, error, info, trace, warn};
25use maitake_sync::WaitQueue;
26use mutex::ScopedRawMutex;
27use tokio::sync::mpsc::Sender;
28use tokio::{
29    io::{AsyncReadExt, AsyncWriteExt},
30    net::{
31        TcpStream,
32        tcp::{OwnedReadHalf, OwnedWriteHalf},
33    },
34    select,
35    sync::mpsc::{Receiver, channel, error::TrySendError},
36};
37
38use super::{
39    ConstInit, InterfaceManager, InterfaceSendError,
40    std_utils::{
41        OwnedFrame, ReceiverError,
42        acc::{CobsAccumulator, FeedResult},
43        de_frame,
44    },
45};
46
/// Receive-side handle for one TCP interface registered with the router.
///
/// Created by `register_interface`; the caller drives it by awaiting
/// [`StdTcpRecvHdl::run`].
pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
    /// Netstack that decoded frames are handed to (`send_raw` / `send_err`).
    stack: &'static NetStack<R, StdTcpIm>,
    // TODO: when we have more real networking and we could possibly
    // have conflicting net_id assignments, we might need to have a
    // shared ref to an Arc<AtomicU16> or something for net_id?
    //
    // for now, stdtcp assumes it is the only "seed" router, meaning that
    // it is solely in charge of assigning netids
    /// Network id allocated to this interface when it was registered. Written
    /// into the source address of incoming frames whose source net_id is zero.
    net_id: u16,
    /// Read half of the TCP socket that frames are received from.
    skt: OwnedReadHalf,
    /// Shutdown signal shared with the interface's transmit worker: the
    /// receive loop waits on it, and `run` closes it when the loop ends.
    closer: Arc<WaitQueue>,
}
59
/// Interface manager that routes frames to TCP-backed interfaces.
///
/// The actual state lives in `inner` and is created lazily by
/// `get_or_init_inner`; `new` is a `const fn` that leaves it uninitialized.
///
/// NOTE(review): no `Drop` impl is visible in this file, and `MaybeUninit`
/// never drops its contents, so an initialized `inner` appears to be leaked if
/// a `StdTcpIm` is ever dropped — confirm whether instances are always
/// `'static`.
pub struct StdTcpIm {
    /// `true` once `inner` has been written and may be treated as initialized.
    init: bool,
    /// Lazily-initialized state; only valid to read while `init` is `true`.
    inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
}
64
/// Mutable state of [`StdTcpIm`]: the registered interfaces plus bookkeeping.
#[derive(Default)]
pub struct StdTcpImInner {
    // TODO: we probably want something like iddqd for a hashset sorted by
    // net_id, as well as a list of "allocated" netids, mapped to the
    // interface they are associated with
    //
    // TODO: for the no-std version of this, we will need to use the same
    // intrusive list stuff that we use for sockets for holding interfaces.
    /// Registered interfaces. Must stay sorted by `net_id`: `alloc_intfc`
    /// re-sorts after every insertion and the send paths use a binary search.
    interfaces: Vec<StdTcpTxHdl>,
    /// Counter read (and then incremented, wrapping) by the closure the send
    /// paths pass to `to_headerseq_or_with_seq`.
    seq_no: u16,
    /// Set by `StdTcpRecvHdl::run` when a receive loop ends; tells
    /// `alloc_intfc` that there may be closed interfaces to collect.
    any_closed: bool,
}
77
/// Errors returned when registering a new interface with the router.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Every usable network id is already assigned to a registered interface.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::OutOfNetIds => f.write_str("no free network ids left to allocate"),
        }
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for Error {}
82
/// Transmit-side record the interface manager keeps for each registered
/// interface.
struct StdTcpTxHdl {
    /// Network id assigned to this interface; the sort/search key of the
    /// interface list.
    net_id: u16,
    /// Queue feeding the interface's `tx_worker` task, which writes the frames
    /// to the socket.
    skt_tx: Sender<OwnedFrame>,
    /// Shutdown signal shared with the `tx_worker` and the receive handle.
    closer: Arc<WaitQueue>,
}
88
89// ---- impls ----
90
91// impl StdTcpRecvHdl
92
93impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
94    pub async fn run(mut self) -> Result<(), ReceiverError> {
95        let res = self.run_inner().await;
96        self.closer.close();
97        // todo: this could live somewhere else?
98        self.stack.with_interface_manager(|im| {
99            let inner = im.get_or_init_inner();
100            inner.any_closed = true;
101        });
102        res
103    }
104
105    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
106        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
107        let mut raw_buf = [0u8; 4096];
108
109        loop {
110            let rd = self.skt.read(&mut raw_buf);
111            let close = self.closer.wait();
112
113            let ct = select! {
114                r = rd => {
115                    match r {
116                        Ok(0) | Err(_) => {
117                            warn!("recv run {} closed", self.net_id);
118                            return Err(ReceiverError::SocketClosed)
119                        },
120                        Ok(ct) => ct,
121                    }
122                }
123                _c = close => {
124                    return Err(ReceiverError::SocketClosed);
125                }
126            };
127
128            let buf = &raw_buf[..ct];
129            let mut window = buf;
130
131            'cobs: while !window.is_empty() {
132                window = match cobs_buf.feed_raw(window) {
133                    FeedResult::Consumed => break 'cobs,
134                    FeedResult::OverFull(new_wind) => new_wind,
135                    FeedResult::DeserError(new_wind) => new_wind,
136                    FeedResult::Success { data, remaining } => {
137                        // Successfully de-cobs'd a packet, now we need to
138                        // do something with it.
139                        if let Some(mut frame) = de_frame(data) {
140                            // If the message comes in and has a src net_id of zero,
141                            // we should rewrite it so it isn't later understood as a
142                            // local packet.
143                            if frame.hdr.src.network_id == 0 {
144                                assert_ne!(
145                                    frame.hdr.src.node_id, 0,
146                                    "we got a local packet remotely?"
147                                );
148                                assert_ne!(
149                                    frame.hdr.src.node_id, 1,
150                                    "someone is pretending to be us?"
151                                );
152
153                                frame.hdr.src.network_id = self.net_id;
154                            }
155                            // TODO: if the destination IS self.net_id, we could rewrite the
156                            // dest net_id as zero to avoid a pass through the interface manager.
157                            //
158                            // If the dest is 0, should we rewrite the dest as self.net_id? This
159                            // is the opposite as above, but I dunno how that will work with responses
160                            let hdr = frame.hdr.clone();
161                            let hdr: Header = hdr.into();
162
163                            let res = match frame.body {
164                                Ok(body) => self.stack.send_raw(&hdr, &body),
165                                Err(e) => self.stack.send_err(&hdr, e),
166                            };
167                            match res {
168                                Ok(()) => {}
169                                Err(e) => {
170                                    // TODO: match on error, potentially try to send NAK?
171                                    warn!("recv->send error: {e:?}");
172                                }
173                            }
174                        } else {
175                            warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
176                        }
177
178                        remaining
179                    }
180                };
181            }
182        }
183    }
184}
185
186// impl StdTcpIm
187
188impl StdTcpIm {
189    const fn new() -> Self {
190        Self {
191            init: false,
192            inner: UnsafeCell::new(MaybeUninit::uninit()),
193        }
194    }
195
196    pub fn get_nets(&mut self) -> Vec<u16> {
197        let inner = self.get_or_init_inner();
198        inner.interfaces.iter().map(|i| i.net_id).collect()
199    }
200
201    fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
202        let inner = self.inner.get_mut();
203        if self.init {
204            unsafe { inner.assume_init_mut() }
205        } else {
206            let imr = inner.write(StdTcpImInner::default());
207            self.init = true;
208            imr
209        }
210    }
211}
212
213impl InterfaceManager for StdTcpIm {
214    fn send<T: serde::Serialize>(
215        &mut self,
216        hdr: &Header,
217        data: &T,
218    ) -> Result<(), InterfaceSendError> {
219        // todo: make this state impossible? enum of dst w/ or w/o key?
220        assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
221
222        let inner = self.get_or_init_inner();
223        // todo: we only handle direct dests, we will probably also want to search
224        // some kind of net_id:interface routing table
225        let Ok(idx) = inner
226            .interfaces
227            .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
228        else {
229            return Err(InterfaceSendError::NoRouteToDest);
230        };
231
232        let interface = &inner.interfaces[idx];
233        // TODO: Assumption: "we" are always node_id==1
234        if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
235            return Err(InterfaceSendError::DestinationLocal);
236        }
237
238        // Now that we've filtered out "dest local" checks, see if there is
239        // any TTL left before we send to the next hop
240        let mut hdr = hdr.clone();
241        hdr.decrement_ttl()?;
242
243        // If the source is local, rewrite the source using this interface's
244        // information so responses can find their way back here
245        if hdr.src.net_node_any() {
246            // todo: if we know the destination is EXACTLY this network,
247            // we could leave the network_id local to allow for shorter
248            // addresses
249            hdr.src.network_id = interface.net_id;
250            hdr.src.node_id = 1;
251        }
252
253        let res = interface.skt_tx.try_send(OwnedFrame {
254            hdr: hdr.to_headerseq_or_with_seq(|| {
255                let seq_no = inner.seq_no;
256                inner.seq_no = inner.seq_no.wrapping_add(1);
257                seq_no
258            }),
259            body: Ok(postcard::to_stdvec(data).unwrap()),
260        });
261        match res {
262            Ok(()) => Ok(()),
263            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
264            Err(TrySendError::Closed(_)) => {
265                let rem = inner.interfaces.remove(idx);
266                rem.closer.close();
267                Err(InterfaceSendError::NoRouteToDest)
268            }
269        }
270    }
271
272    fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
273        // todo: make this state impossible? enum of dst w/ or w/o key?
274        assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
275
276        let inner = self.get_or_init_inner();
277        // todo: dedupe w/ send
278        //
279        // todo: we only handle direct dests
280        let Ok(idx) = inner
281            .interfaces
282            .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
283        else {
284            return Err(InterfaceSendError::NoRouteToDest);
285        };
286
287        let interface = &inner.interfaces[idx];
288        // TODO: Assumption: "we" are always node_id==1
289        if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
290            return Err(InterfaceSendError::DestinationLocal);
291        }
292
293        // Now that we've filtered out "dest local" checks, see if there is
294        // any TTL left before we send to the next hop
295        let mut hdr = hdr.clone();
296        hdr.decrement_ttl()?;
297
298        // If the source is local, rewrite the source using this interface's
299        // information so responses can find their way back here
300        if hdr.src.net_node_any() {
301            // todo: if we know the destination is EXACTLY this network,
302            // we could leave the network_id local to allow for shorter
303            // addresses
304            hdr.src.network_id = interface.net_id;
305            hdr.src.node_id = 1;
306        }
307
308        let res = interface.skt_tx.try_send(OwnedFrame {
309            hdr: hdr.to_headerseq_or_with_seq(|| {
310                let seq_no = inner.seq_no;
311                inner.seq_no = inner.seq_no.wrapping_add(1);
312                seq_no
313            }),
314            body: Ok(data.to_vec()),
315        });
316        match res {
317            Ok(()) => Ok(()),
318            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
319            Err(TrySendError::Closed(_)) => {
320                inner.interfaces.remove(idx);
321                Err(InterfaceSendError::NoRouteToDest)
322            }
323        }
324    }
325
326    fn send_err(
327        &mut self,
328        hdr: &Header,
329        err: crate::ProtocolError,
330    ) -> Result<(), InterfaceSendError> {
331        // todo: make this state impossible? enum of dst w/ or w/o key?
332        assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
333
334        let inner = self.get_or_init_inner();
335        // todo: we only handle direct dests, we will probably also want to search
336        // some kind of net_id:interface routing table
337        let Ok(idx) = inner
338            .interfaces
339            .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
340        else {
341            return Err(InterfaceSendError::NoRouteToDest);
342        };
343
344        let interface = &inner.interfaces[idx];
345        // TODO: Assumption: "we" are always node_id==1
346        if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
347            return Err(InterfaceSendError::DestinationLocal);
348        }
349
350        // Now that we've filtered out "dest local" checks, see if there is
351        // any TTL left before we send to the next hop
352        let mut hdr = hdr.clone();
353        hdr.decrement_ttl()?;
354
355        // If the source is local, rewrite the source using this interface's
356        // information so responses can find their way back here
357        if hdr.src.net_node_any() {
358            // todo: if we know the destination is EXACTLY this network,
359            // we could leave the network_id local to allow for shorter
360            // addresses
361            hdr.src.network_id = interface.net_id;
362            hdr.src.node_id = 1;
363        }
364
365        let res = interface.skt_tx.try_send(OwnedFrame {
366            hdr: hdr.to_headerseq_or_with_seq(|| {
367                let seq_no = inner.seq_no;
368                inner.seq_no = inner.seq_no.wrapping_add(1);
369                seq_no
370            }),
371            body: Err(err),
372        });
373        match res {
374            Ok(()) => Ok(()),
375            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
376            Err(TrySendError::Closed(_)) => {
377                let rem = inner.interfaces.remove(idx);
378                rem.closer.close();
379                Err(InterfaceSendError::NoRouteToDest)
380            }
381        }
382    }
383}
384
/// `Default` yields the same not-yet-initialized manager as `StdTcpIm::new`.
impl Default for StdTcpIm {
    fn default() -> Self {
        Self::new()
    }
}
390
impl ConstInit for StdTcpIm {
    // `StdTcpIm` contains an `UnsafeCell`, which is what triggers this lint:
    // every use of the constant produces a fresh, independent value.
    // NOTE(review): presumably `INIT` is only used as the initial value of a
    // new manager, where a fresh copy is exactly what is wanted — confirm
    // against the users of `ConstInit`.
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: Self = Self::new();
}
395
// SAFETY: in the code visible here, `inner` is only reached through
// `UnsafeCell::get_mut` inside `get_or_init_inner`, which takes `&mut self`.
// A shared `&StdTcpIm` therefore gives no access to the cell's contents.
// NOTE(review): this stops holding as soon as any `&self` method touches
// `inner` — re-check when adding accessors.
unsafe impl Sync for StdTcpIm {}
397
398// impl StdTcpImInner
399
400impl StdTcpImInner {
401    pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
402        let closer = Arc::new(WaitQueue::new());
403        if self.interfaces.is_empty() {
404            // todo: configurable channel depth
405            let (ctx, crx) = channel(64);
406            let net_id = 1;
407            // TODO: We are spawning in a non-async context!
408            tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
409            self.interfaces.push(StdTcpTxHdl {
410                net_id,
411                skt_tx: ctx,
412                closer: closer.clone(),
413            });
414            debug!("Alloc'd net_id 1");
415            return Some((net_id, closer));
416        } else if self.interfaces.len() >= 65534 {
417            warn!("Out of netids!");
418            return None;
419        }
420
421        // If we closed any interfaces, then collect
422        if self.any_closed {
423            self.interfaces.retain(|int| {
424                let closed = int.closer.is_closed();
425                if closed {
426                    info!("Collecting interface {}", int.net_id);
427                }
428                !closed
429            });
430        }
431
432        let mut net_id = 1;
433        // we're not empty, find the lowest free address by counting the
434        // indexes, and if we find a discontinuity, allocate the first one.
435        for intfc in self.interfaces.iter() {
436            if intfc.net_id > net_id {
437                trace!("Found gap: {net_id}");
438                break;
439            }
440            debug_assert!(intfc.net_id == net_id);
441            net_id += 1;
442        }
443        // EITHER: We've found a gap that we can use, OR we've iterated all
444        // interfaces, which means that we had contiguous allocations but we
445        // have not exhausted the range.
446        debug_assert!(net_id > 0 && net_id != u16::MAX);
447        let (ctx, crx) = channel(64);
448        debug!("allocated net_id {net_id}");
449
450        tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
451        self.interfaces.push(StdTcpTxHdl {
452            net_id,
453            skt_tx: ctx,
454            closer: closer.clone(),
455        });
456        self.interfaces.sort_unstable_by_key(|i| i.net_id);
457        Some((net_id, closer))
458    }
459}
460
461// Helper functions
462
463async fn tx_worker(
464    net_id: u16,
465    mut tx: OwnedWriteHalf,
466    mut rx: Receiver<OwnedFrame>,
467    closer: Arc<WaitQueue>,
468) {
469    info!("Started tx_worker for net_id {net_id}");
470    loop {
471        let rxf = rx.recv();
472        let clf = closer.wait();
473
474        let frame = select! {
475            r = rxf => {
476                if let Some(frame) = r {
477                    frame
478                } else {
479                    warn!("tx_worker {net_id} rx closed!");
480                    closer.close();
481                    break;
482                }
483            }
484            _c = clf => {
485                break;
486            }
487        };
488
489        let msg = ser_frame(frame);
490        debug!("sending pkt len:{} on net_id {net_id}", msg.len());
491        let res = tx.write_all(&msg).await;
492        if let Err(e) = res {
493            error!("Err: {e:?}");
494            break;
495        }
496    }
497    // TODO: GC waker?
498    warn!("Closing interface {net_id}");
499}
500
501pub fn register_interface<R: ScopedRawMutex>(
502    stack: &'static NetStack<R, StdTcpIm>,
503    socket: TcpStream,
504) -> Result<StdTcpRecvHdl<R>, Error> {
505    let (rx, tx) = socket.into_split();
506    stack.with_interface_manager(|im| {
507        let inner = im.get_or_init_inner();
508        if let Some((addr, closer)) = inner.alloc_intfc(tx) {
509            Ok(StdTcpRecvHdl {
510                stack,
511                net_id: addr,
512                skt: rx,
513                closer,
514            })
515        } else {
516            Err(Error::OutOfNetIds)
517        }
518    })
519}