ergot_base/interface_manager/
std_tcp_router.rs

1/*
2    Let's see, we're going to need:
3
4    * Some kind of hashmap/vec of active interfaces, by network id?
5        * IF we use a vec, we should NOT use the index as the ID, it may be sparse
6    * The actual interface type probably gets defined by the interface manager
7    * The interface which follows the pinned rules, and removes itself on drop
8    * THIS version of the routing interface probably will not allow for other routers,
9        we probably have to assume we are the only one assigning network IDs until a
10        later point
11    * The associated "simple" version of a client probably needs a stub routing interface
12        that picks up the network ID from the destination address
13    * Honestly we might want to have an `Arc` version of the netstack, or we need some kind
14        of Once construction.
15    * The interface manager needs some kind of "handle" construction so that we can get mut
16        access to it, or we need an accessor via the netstack
17*/
18
19use std::sync::Arc;
20use std::{cell::UnsafeCell, mem::MaybeUninit};
21
22use crate::{
23    Header, NetStack,
24    interface_manager::{
25        ConstInit, InterfaceManager, InterfaceSendError,
26        cobs_stream::{self, Interface},
27        std_utils::{
28            ReceiverError, StdQueue,
29            acc::{CobsAccumulator, FeedResult},
30        },
31        wire_frames::{CommonHeader, de_frame},
32    },
33};
34
35use bbq2::prod_cons::stream::StreamConsumer;
36use bbq2::traits::storage::BoxedSlice;
37use log::{debug, error, info, trace, warn};
38use maitake_sync::WaitQueue;
39use mutex::ScopedRawMutex;
40use tokio::{
41    io::{AsyncReadExt, AsyncWriteExt},
42    net::{
43        TcpStream,
44        tcp::{OwnedReadHalf, OwnedWriteHalf},
45    },
46    select,
47};
48
/// Receive half of a TCP interface created by [`register_interface`].
///
/// Drive [`StdTcpRecvHdl::run`] to read frames from the socket and feed them
/// into the net stack.
pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
    // Net stack that decoded frames are forwarded into.
    stack: &'static NetStack<R, StdTcpIm>,
    // TODO: when we have more real networking and we could possibly
    // have conflicting net_id assignments, we might need to have a
    // shared ref to an Arc<AtomicU16> or something for net_id?
    //
    // for now, stdtcp assumes it is the only "seed" router, meaning that
    // it is solely in charge of assigning netids
    //
    // Incoming frames with a source net_id of 0 are rewritten to this value.
    net_id: u16,
    // Read half of the TCP socket; the write half is owned by this
    // interface's `tx_worker` task.
    skt: OwnedReadHalf,
    // Shared with this interface's tx handle and `tx_worker`; closing it stops
    // both the receive loop and the `tx_worker`.
    closer: Arc<WaitQueue>,
}
61
/// Interface manager that routes frames over a set of TCP connections, one
/// net_id per connection.
///
/// The routing state in `inner` is created lazily on first access (see
/// `get_or_init_inner`), which lets `new` be a `const fn` even though
/// `StdTcpImInner` is only `Default`-constructible.
pub struct StdTcpIm {
    // `true` once `inner` has been written; until then `inner` is uninitialized.
    init: bool,
    inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
}
66
/// Mutable routing state owned by a [`StdTcpIm`].
#[derive(Default)]
pub struct StdTcpImInner {
    // TODO: we probably want something like iddqd for a hashset sorted by
    // net_id, as well as a list of "allocated" netids, mapped to the
    // interface they are associated with
    //
    // TODO: for the no-std version of this, we will need to use the same
    // intrusive list stuff that we use for sockets for holding interfaces.
    //
    // Kept sorted by `net_id` (`common_send` binary-searches it), with each
    // net_id appearing at most once. Closed interfaces remain here until
    // `alloc_intfc` collects them.
    interfaces: Vec<StdTcpTxHdl>,
    // Sequence number stamped on the next outgoing frame; wraps on overflow.
    seq_no: u16,
    // Set when an interface's receiver exits, so the next `alloc_intfc`
    // sweeps closed interfaces out of `interfaces`.
    any_closed: bool,
}
79
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Every assignable net_id is already in use.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("no free net_id available for a new interface"),
        }
    }
}

// Lets callers propagate this with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for Error {}
84
/// Transmit side of one registered interface, as stored in the interface
/// manager's table.
struct StdTcpTxHdl {
    // The net_id assigned to this interface by `alloc_intfc`.
    net_id: u16,
    // Producer end of the outgoing frame queue; a spawned `tx_worker` drains
    // the consumer end into the socket.
    skt_tx: Interface<StdQueue>,
    // Shared with the interface's receive handle and `tx_worker`; once it is
    // closed, `alloc_intfc` may collect this entry.
    closer: Arc<WaitQueue>,
}
90
91// ---- impls ----
92
93// impl StdTcpRecvHdl
94
95impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
96    pub async fn run(mut self) -> Result<(), ReceiverError> {
97        let res = self.run_inner().await;
98        self.closer.close();
99        // todo: this could live somewhere else?
100        self.stack.with_interface_manager(|im| {
101            let inner = im.get_or_init_inner();
102            inner.any_closed = true;
103        });
104        res
105    }
106
107    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
108        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
109        let mut raw_buf = [0u8; 4096];
110
111        loop {
112            let rd = self.skt.read(&mut raw_buf);
113            let close = self.closer.wait();
114
115            let ct = select! {
116                r = rd => {
117                    match r {
118                        Ok(0) | Err(_) => {
119                            warn!("recv run {} closed", self.net_id);
120                            return Err(ReceiverError::SocketClosed)
121                        },
122                        Ok(ct) => ct,
123                    }
124                }
125                _c = close => {
126                    return Err(ReceiverError::SocketClosed);
127                }
128            };
129
130            let buf = &raw_buf[..ct];
131            let mut window = buf;
132
133            'cobs: while !window.is_empty() {
134                window = match cobs_buf.feed_raw(window) {
135                    FeedResult::Consumed => break 'cobs,
136                    FeedResult::OverFull(new_wind) => new_wind,
137                    FeedResult::DeserError(new_wind) => new_wind,
138                    FeedResult::Success { data, remaining } => {
139                        // Successfully de-cobs'd a packet, now we need to
140                        // do something with it.
141                        if let Some(mut frame) = de_frame(data) {
142                            // If the message comes in and has a src net_id of zero,
143                            // we should rewrite it so it isn't later understood as a
144                            // local packet.
145                            if frame.hdr.src.network_id == 0 {
146                                assert_ne!(
147                                    frame.hdr.src.node_id, 0,
148                                    "we got a local packet remotely?"
149                                );
150                                assert_ne!(
151                                    frame.hdr.src.node_id, 1,
152                                    "someone is pretending to be us?"
153                                );
154
155                                frame.hdr.src.network_id = self.net_id;
156                            }
157                            // TODO: if the destination IS self.net_id, we could rewrite the
158                            // dest net_id as zero to avoid a pass through the interface manager.
159                            //
160                            // If the dest is 0, should we rewrite the dest as self.net_id? This
161                            // is the opposite as above, but I dunno how that will work with responses
162                            let hdr = frame.hdr.clone();
163                            let hdr: Header = hdr.into();
164
165                            let res = match frame.body {
166                                Ok(body) => self.stack.send_raw(&hdr, frame.hdr_raw, body),
167                                Err(e) => self.stack.send_err(&hdr, e),
168                            };
169                            match res {
170                                Ok(()) => {}
171                                Err(e) => {
172                                    // TODO: match on error, potentially try to send NAK?
173                                    warn!("recv->send error: {e:?}");
174                                }
175                            }
176                        } else {
177                            warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
178                        }
179
180                        remaining
181                    }
182                };
183            }
184        }
185    }
186}
187
188// impl StdTcpIm
189
190impl StdTcpIm {
191    const fn new() -> Self {
192        Self {
193            init: false,
194            inner: UnsafeCell::new(MaybeUninit::uninit()),
195        }
196    }
197
198    pub fn get_nets(&mut self) -> Vec<u16> {
199        let inner = self.get_or_init_inner();
200        inner.interfaces.iter().map(|i| i.net_id).collect()
201    }
202
203    fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
204        let inner = self.inner.get_mut();
205        if self.init {
206            unsafe { inner.assume_init_mut() }
207        } else {
208            let imr = inner.write(StdTcpImInner::default());
209            self.init = true;
210            imr
211        }
212    }
213}
214
215impl StdTcpIm {
216    fn common_send<'a, 'b>(
217        &'b mut self,
218        ihdr: &'a Header,
219    ) -> Result<(&'b mut StdTcpTxHdl, CommonHeader), InterfaceSendError> {
220        // todo: make this state impossible? enum of dst w/ or w/o key?
221        assert!(!(ihdr.dst.port_id == 0 && ihdr.any_all.is_none()));
222
223        let inner = self.get_or_init_inner();
224        // todo: dedupe w/ send
225        //
226        // todo: we only handle direct dests
227        let Ok(idx) = inner
228            .interfaces
229            .binary_search_by_key(&ihdr.dst.network_id, |int| int.net_id)
230        else {
231            return Err(InterfaceSendError::NoRouteToDest);
232        };
233
234        let interface = &mut inner.interfaces[idx];
235        // TODO: Assumption: "we" are always node_id==1
236        if ihdr.dst.network_id == interface.net_id && ihdr.dst.node_id == 1 {
237            return Err(InterfaceSendError::DestinationLocal);
238        }
239
240        // Now that we've filtered out "dest local" checks, see if there is
241        // any TTL left before we send to the next hop
242        let mut hdr = ihdr.clone();
243        hdr.decrement_ttl()?;
244
245        // If the source is local, rewrite the source using this interface's
246        // information so responses can find their way back here
247        if hdr.src.net_node_any() {
248            // todo: if we know the destination is EXACTLY this network,
249            // we could leave the network_id local to allow for shorter
250            // addresses
251            hdr.src.network_id = interface.net_id;
252            hdr.src.node_id = 1;
253        }
254
255        let seq_no = inner.seq_no;
256        inner.seq_no = inner.seq_no.wrapping_add(1);
257
258        let header = CommonHeader {
259            src: hdr.src.as_u32(),
260            dst: hdr.dst.as_u32(),
261            seq_no,
262            kind: hdr.kind.0,
263            ttl: hdr.ttl,
264        };
265        if [0, 255].contains(&hdr.dst.port_id) {
266            if ihdr.any_all.is_none() {
267                return Err(InterfaceSendError::AnyPortMissingKey);
268            }
269        }
270
271        Ok((interface, header))
272    }
273}
274
275impl InterfaceManager for StdTcpIm {
276    fn send<T: serde::Serialize>(
277        &mut self,
278        hdr: &Header,
279        data: &T,
280    ) -> Result<(), InterfaceSendError> {
281        let (intfc, header) = self.common_send(hdr)?;
282        let res = intfc.skt_tx.send_ty(&header, hdr.any_all.as_ref(), data);
283
284        match res {
285            Ok(()) => Ok(()),
286            Err(()) => Err(InterfaceSendError::InterfaceFull),
287        }
288    }
289
290    fn send_raw(
291        &mut self,
292        hdr: &Header,
293        hdr_raw: &[u8],
294        data: &[u8],
295    ) -> Result<(), InterfaceSendError> {
296        let (intfc, header) = self.common_send(hdr)?;
297        let res = intfc.skt_tx.send_raw(&header, hdr_raw, data);
298
299        match res {
300            Ok(()) => Ok(()),
301            Err(()) => Err(InterfaceSendError::InterfaceFull),
302        }
303    }
304
305    fn send_err(
306        &mut self,
307        hdr: &Header,
308        err: crate::ProtocolError,
309    ) -> Result<(), InterfaceSendError> {
310        let (intfc, header) = self.common_send(hdr)?;
311        let res = intfc.skt_tx.send_err(&header, err);
312
313        match res {
314            Ok(()) => Ok(()),
315            Err(()) => Err(InterfaceSendError::InterfaceFull),
316        }
317    }
318}
319
320impl Default for StdTcpIm {
321    fn default() -> Self {
322        Self::new()
323    }
324}
325
impl ConstInit for StdTcpIm {
    // Clippy flags this because `StdTcpIm` contains an `UnsafeCell`. Every use
    // of a `const` produces a fresh copy, so the lint's concern (mutations to
    // one copy not being visible in another) doesn't apply to an initializer
    // value. NOTE(review): presumably `INIT` is used to initialize a `static`
    // NetStack — confirm.
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: Self = Self::new();
}

// SAFETY: the `UnsafeCell` in `inner` is only accessed through
// `UnsafeCell::get_mut`, which requires `&mut self`. A shared `&StdTcpIm`
// therefore gives no way to touch `inner`, concurrently or otherwise.
// NOTE(review): this holds only while no `&self` method reaches into `inner`;
// re-check if one is added.
unsafe impl Sync for StdTcpIm {}
332
333// impl StdTcpImInner
334
335impl StdTcpImInner {
336    pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
337        let closer = Arc::new(WaitQueue::new());
338        if self.interfaces.is_empty() {
339            // todo: configurable channel depth
340            let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
341            let ctx = q.stream_producer();
342            let crx = q.stream_consumer();
343
344            let ctx = cobs_stream::Interface {
345                mtu: 1024,
346                prod: ctx,
347            };
348
349            let net_id = 1;
350            // TODO: We are spawning in a non-async context!
351            tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
352            self.interfaces.push(StdTcpTxHdl {
353                net_id,
354                skt_tx: ctx,
355                closer: closer.clone(),
356            });
357            debug!("Alloc'd net_id 1");
358            return Some((net_id, closer));
359        } else if self.interfaces.len() >= 65534 {
360            warn!("Out of netids!");
361            return None;
362        }
363
364        // If we closed any interfaces, then collect
365        if self.any_closed {
366            self.interfaces.retain(|int| {
367                let closed = int.closer.is_closed();
368                if closed {
369                    info!("Collecting interface {}", int.net_id);
370                }
371                !closed
372            });
373        }
374
375        let mut net_id = 1;
376        // we're not empty, find the lowest free address by counting the
377        // indexes, and if we find a discontinuity, allocate the first one.
378        for intfc in self.interfaces.iter() {
379            if intfc.net_id > net_id {
380                trace!("Found gap: {net_id}");
381                break;
382            }
383            debug_assert!(intfc.net_id == net_id);
384            net_id += 1;
385        }
386        // EITHER: We've found a gap that we can use, OR we've iterated all
387        // interfaces, which means that we had contiguous allocations but we
388        // have not exhausted the range.
389        debug_assert!(net_id > 0 && net_id != u16::MAX);
390
391        let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
392        let ctx = q.stream_producer();
393        let crx = q.stream_consumer();
394
395        let ctx = cobs_stream::Interface {
396            mtu: 1024,
397            prod: ctx,
398        };
399
400        debug!("allocated net_id {net_id}");
401
402        tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
403        self.interfaces.push(StdTcpTxHdl {
404            net_id,
405            skt_tx: ctx,
406            closer: closer.clone(),
407        });
408        self.interfaces.sort_unstable_by_key(|i| i.net_id);
409        Some((net_id, closer))
410    }
411}
412
413// Helper functions
414
415async fn tx_worker(
416    net_id: u16,
417    mut tx: OwnedWriteHalf,
418    rx: StreamConsumer<StdQueue>,
419    closer: Arc<WaitQueue>,
420) {
421    info!("Started tx_worker for net_id {net_id}");
422    loop {
423        let rxf = rx.wait_read();
424        let clf = closer.wait();
425
426        let frame = select! {
427            r = rxf => r,
428            _c = clf => {
429                break;
430            }
431        };
432
433        let len = frame.len();
434        debug!("sending pkt len:{} on net_id {net_id}", len);
435        let res = tx.write_all(&frame).await;
436        frame.release(len);
437        if let Err(e) = res {
438            error!("Err: {e:?}");
439            break;
440        }
441    }
442    // TODO: GC waker?
443    warn!("Closing interface {net_id}");
444}
445
446pub fn register_interface<R: ScopedRawMutex>(
447    stack: &'static NetStack<R, StdTcpIm>,
448    socket: TcpStream,
449) -> Result<StdTcpRecvHdl<R>, Error> {
450    let (rx, tx) = socket.into_split();
451    stack.with_interface_manager(|im| {
452        let inner = im.get_or_init_inner();
453        if let Some((addr, closer)) = inner.alloc_intfc(tx) {
454            Ok(StdTcpRecvHdl {
455                stack,
456                net_id: addr,
457                skt: rx,
458                closer,
459            })
460        } else {
461            Err(Error::OutOfNetIds)
462        }
463    })
464}