ergot_base/interface_manager/
std_tcp_router.rs

1/*
2    Let's see, we're going to need:
3
4    * Some kind of hashmap/vec of active interfaces, by network id?
5        * IF we use a vec, we should NOT use the index as the ID, it may be sparse
6    * The actual interface type probably gets defined by the interface manager
7    * The interface which follows the pinned rules, and removes itself on drop
8    * THIS version of the routing interface probably will not allow for other routers,
9        we probably have to assume we are the only one assigning network IDs until a
10        later point
11    * The associated "simple" version of a client probably needs a stub routing interface
12        that picks up the network ID from the destination address
13    * Honestly we might want to have an `Arc` version of the netstack, or we need some kind
14        of Once construction.
15    * The interface manager needs some kind of "handle" construction so that we can get mut
16        access to it, or we need an accessor via the netstack
17*/
18
19use std::sync::Arc;
20use std::{cell::UnsafeCell, mem::MaybeUninit};
21
22use crate::{
23    Header, Key, NetStack,
24    interface_manager::{
25        ConstInit, InterfaceManager, InterfaceSendError,
26        cobs_stream::{self, Interface},
27        std_utils::{
28            ReceiverError, StdQueue,
29            acc::{CobsAccumulator, FeedResult},
30        },
31        wire_frames::{CommonHeader, de_frame},
32    },
33};
34
35use bbq2::prod_cons::stream::StreamConsumer;
36use bbq2::traits::storage::BoxedSlice;
37use log::{debug, error, info, trace, warn};
38use maitake_sync::WaitQueue;
39use mutex::ScopedRawMutex;
40use tokio::{
41    io::{AsyncReadExt, AsyncWriteExt},
42    net::{
43        TcpStream,
44        tcp::{OwnedReadHalf, OwnedWriteHalf},
45    },
46    select,
47};
48
49pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
50    stack: &'static NetStack<R, StdTcpIm>,
51    // TODO: when we have more real networking and we could possibly
52    // have conflicting net_id assignments, we might need to have a
53    // shared ref to an Arc<AtomicU16> or something for net_id?
54    //
55    // for now, stdtcp assumes it is the only "seed" router, meaning that
56    // it is solely in charge of assigning netids
57    net_id: u16,
58    skt: OwnedReadHalf,
59    closer: Arc<WaitQueue>,
60}
61
62pub struct StdTcpIm {
63    init: bool,
64    inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
65}
66
67#[derive(Default)]
68pub struct StdTcpImInner {
69    // TODO: we probably want something like iddqd for a hashset sorted by
70    // net_id, as well as a list of "allocated" netids, mapped to the
71    // interface they are associated with
72    //
73    // TODO: for the no-std version of this, we will need to use the same
74    // intrusive list stuff that we use for sockets for holding interfaces.
75    interfaces: Vec<StdTcpTxHdl>,
76    seq_no: u16,
77    any_closed: bool,
78}
79
/// Errors returned when registering a TCP interface.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Every usable network ID is already assigned to an interface.
    OutOfNetIds,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("out of network IDs"),
        }
    }
}

impl std::error::Error for Error {}
84
85struct StdTcpTxHdl {
86    net_id: u16,
87    skt_tx: Interface<StdQueue>,
88    closer: Arc<WaitQueue>,
89}
90
91// ---- impls ----
92
93// impl StdTcpRecvHdl
94
95impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
96    pub async fn run(mut self) -> Result<(), ReceiverError> {
97        let res = self.run_inner().await;
98        self.closer.close();
99        // todo: this could live somewhere else?
100        self.stack.with_interface_manager(|im| {
101            let inner = im.get_or_init_inner();
102            inner.any_closed = true;
103        });
104        res
105    }
106
107    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
108        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
109        let mut raw_buf = [0u8; 4096];
110
111        loop {
112            let rd = self.skt.read(&mut raw_buf);
113            let close = self.closer.wait();
114
115            let ct = select! {
116                r = rd => {
117                    match r {
118                        Ok(0) | Err(_) => {
119                            warn!("recv run {} closed", self.net_id);
120                            return Err(ReceiverError::SocketClosed)
121                        },
122                        Ok(ct) => ct,
123                    }
124                }
125                _c = close => {
126                    return Err(ReceiverError::SocketClosed);
127                }
128            };
129
130            let buf = &raw_buf[..ct];
131            let mut window = buf;
132
133            'cobs: while !window.is_empty() {
134                window = match cobs_buf.feed_raw(window) {
135                    FeedResult::Consumed => break 'cobs,
136                    FeedResult::OverFull(new_wind) => new_wind,
137                    FeedResult::DeserError(new_wind) => new_wind,
138                    FeedResult::Success { data, remaining } => {
139                        // Successfully de-cobs'd a packet, now we need to
140                        // do something with it.
141                        if let Some(mut frame) = de_frame(data) {
142                            // If the message comes in and has a src net_id of zero,
143                            // we should rewrite it so it isn't later understood as a
144                            // local packet.
145                            if frame.hdr.src.network_id == 0 {
146                                assert_ne!(
147                                    frame.hdr.src.node_id, 0,
148                                    "we got a local packet remotely?"
149                                );
150                                assert_ne!(
151                                    frame.hdr.src.node_id, 1,
152                                    "someone is pretending to be us?"
153                                );
154
155                                frame.hdr.src.network_id = self.net_id;
156                            }
157                            // TODO: if the destination IS self.net_id, we could rewrite the
158                            // dest net_id as zero to avoid a pass through the interface manager.
159                            //
160                            // If the dest is 0, should we rewrite the dest as self.net_id? This
161                            // is the opposite as above, but I dunno how that will work with responses
162                            let hdr = frame.hdr.clone();
163                            let hdr: Header = hdr.into();
164
165                            let res = match frame.body {
166                                Ok(body) => self.stack.send_raw(&hdr, body),
167                                Err(e) => self.stack.send_err(&hdr, e),
168                            };
169                            match res {
170                                Ok(()) => {}
171                                Err(e) => {
172                                    // TODO: match on error, potentially try to send NAK?
173                                    warn!("recv->send error: {e:?}");
174                                }
175                            }
176                        } else {
177                            warn!("Decode error! Ignoring frame on net_id {}", self.net_id);
178                        }
179
180                        remaining
181                    }
182                };
183            }
184        }
185    }
186}
187
188// impl StdTcpIm
189
190impl StdTcpIm {
191    const fn new() -> Self {
192        Self {
193            init: false,
194            inner: UnsafeCell::new(MaybeUninit::uninit()),
195        }
196    }
197
198    pub fn get_nets(&mut self) -> Vec<u16> {
199        let inner = self.get_or_init_inner();
200        inner.interfaces.iter().map(|i| i.net_id).collect()
201    }
202
203    fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
204        let inner = self.inner.get_mut();
205        if self.init {
206            unsafe { inner.assume_init_mut() }
207        } else {
208            let imr = inner.write(StdTcpImInner::default());
209            self.init = true;
210            imr
211        }
212    }
213}
214
215impl StdTcpIm {
216    fn common_send<'a, 'b>(
217        &'b mut self,
218        ihdr: &'a Header,
219    ) -> Result<(&'b mut StdTcpTxHdl, CommonHeader, Option<&'a Key>), InterfaceSendError> {
220        // todo: make this state impossible? enum of dst w/ or w/o key?
221        assert!(!(ihdr.dst.port_id == 0 && ihdr.key.is_none()));
222
223        let inner = self.get_or_init_inner();
224        // todo: dedupe w/ send
225        //
226        // todo: we only handle direct dests
227        let Ok(idx) = inner
228            .interfaces
229            .binary_search_by_key(&ihdr.dst.network_id, |int| int.net_id)
230        else {
231            return Err(InterfaceSendError::NoRouteToDest);
232        };
233
234        let interface = &mut inner.interfaces[idx];
235        // TODO: Assumption: "we" are always node_id==1
236        if ihdr.dst.network_id == interface.net_id && ihdr.dst.node_id == 1 {
237            return Err(InterfaceSendError::DestinationLocal);
238        }
239
240        // Now that we've filtered out "dest local" checks, see if there is
241        // any TTL left before we send to the next hop
242        let mut hdr = ihdr.clone();
243        hdr.decrement_ttl()?;
244
245        // If the source is local, rewrite the source using this interface's
246        // information so responses can find their way back here
247        if hdr.src.net_node_any() {
248            // todo: if we know the destination is EXACTLY this network,
249            // we could leave the network_id local to allow for shorter
250            // addresses
251            hdr.src.network_id = interface.net_id;
252            hdr.src.node_id = 1;
253        }
254
255        let seq_no = inner.seq_no;
256        inner.seq_no = inner.seq_no.wrapping_add(1);
257
258        let header = CommonHeader {
259            src: hdr.src.as_u32(),
260            dst: hdr.dst.as_u32(),
261            seq_no,
262            kind: hdr.kind.0,
263            ttl: hdr.ttl,
264        };
265        let key = if [0, 255].contains(&hdr.dst.port_id) {
266            Some(ihdr.key.as_ref().unwrap())
267        } else {
268            None
269        };
270
271        Ok((interface, header, key))
272    }
273}
274
275impl InterfaceManager for StdTcpIm {
276    fn send<T: serde::Serialize>(
277        &mut self,
278        hdr: &Header,
279        data: &T,
280    ) -> Result<(), InterfaceSendError> {
281        let (intfc, header, key) = self.common_send(hdr)?;
282        let res = intfc.skt_tx.send_ty(&header, key, data);
283
284        match res {
285            Ok(()) => Ok(()),
286            Err(()) => Err(InterfaceSendError::InterfaceFull),
287        }
288    }
289
290    fn send_raw(&mut self, hdr: &Header, data: &[u8]) -> Result<(), InterfaceSendError> {
291        let (intfc, header, key) = self.common_send(hdr)?;
292        let res = intfc.skt_tx.send_raw(&header, key, data);
293
294        match res {
295            Ok(()) => Ok(()),
296            Err(()) => Err(InterfaceSendError::InterfaceFull),
297        }
298    }
299
300    fn send_err(
301        &mut self,
302        hdr: &Header,
303        err: crate::ProtocolError,
304    ) -> Result<(), InterfaceSendError> {
305        let (intfc, header, _key) = self.common_send(hdr)?;
306        let res = intfc.skt_tx.send_err(&header, err);
307
308        match res {
309            Ok(()) => Ok(()),
310            Err(()) => Err(InterfaceSendError::InterfaceFull),
311        }
312    }
313}
314
315impl Default for StdTcpIm {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321impl ConstInit for StdTcpIm {
322    #[allow(clippy::declare_interior_mutable_const)]
323    const INIT: Self = Self::new();
324}
325
326unsafe impl Sync for StdTcpIm {}
327
328// impl StdTcpImInner
329
330impl StdTcpImInner {
331    pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
332        let closer = Arc::new(WaitQueue::new());
333        if self.interfaces.is_empty() {
334            // todo: configurable channel depth
335            let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
336            let ctx = q.stream_producer();
337            let crx = q.stream_consumer();
338
339            let ctx = cobs_stream::Interface {
340                mtu: 1024,
341                prod: ctx,
342            };
343
344            let net_id = 1;
345            // TODO: We are spawning in a non-async context!
346            tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
347            self.interfaces.push(StdTcpTxHdl {
348                net_id,
349                skt_tx: ctx,
350                closer: closer.clone(),
351            });
352            debug!("Alloc'd net_id 1");
353            return Some((net_id, closer));
354        } else if self.interfaces.len() >= 65534 {
355            warn!("Out of netids!");
356            return None;
357        }
358
359        // If we closed any interfaces, then collect
360        if self.any_closed {
361            self.interfaces.retain(|int| {
362                let closed = int.closer.is_closed();
363                if closed {
364                    info!("Collecting interface {}", int.net_id);
365                }
366                !closed
367            });
368        }
369
370        let mut net_id = 1;
371        // we're not empty, find the lowest free address by counting the
372        // indexes, and if we find a discontinuity, allocate the first one.
373        for intfc in self.interfaces.iter() {
374            if intfc.net_id > net_id {
375                trace!("Found gap: {net_id}");
376                break;
377            }
378            debug_assert!(intfc.net_id == net_id);
379            net_id += 1;
380        }
381        // EITHER: We've found a gap that we can use, OR we've iterated all
382        // interfaces, which means that we had contiguous allocations but we
383        // have not exhausted the range.
384        debug_assert!(net_id > 0 && net_id != u16::MAX);
385
386        let q = bbq2::nicknames::Lechon::new_with_storage(BoxedSlice::new(4096));
387        let ctx = q.stream_producer();
388        let crx = q.stream_consumer();
389
390        let ctx = cobs_stream::Interface {
391            mtu: 1024,
392            prod: ctx,
393        };
394
395        debug!("allocated net_id {net_id}");
396
397        tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
398        self.interfaces.push(StdTcpTxHdl {
399            net_id,
400            skt_tx: ctx,
401            closer: closer.clone(),
402        });
403        self.interfaces.sort_unstable_by_key(|i| i.net_id);
404        Some((net_id, closer))
405    }
406}
407
408// Helper functions
409
410async fn tx_worker(
411    net_id: u16,
412    mut tx: OwnedWriteHalf,
413    rx: StreamConsumer<StdQueue>,
414    closer: Arc<WaitQueue>,
415) {
416    info!("Started tx_worker for net_id {net_id}");
417    loop {
418        let rxf = rx.wait_read();
419        let clf = closer.wait();
420
421        let frame = select! {
422            r = rxf => r,
423            _c = clf => {
424                break;
425            }
426        };
427
428        let len = frame.len();
429        debug!("sending pkt len:{} on net_id {net_id}", len);
430        let res = tx.write_all(&frame).await;
431        frame.release(len);
432        if let Err(e) = res {
433            error!("Err: {e:?}");
434            break;
435        }
436    }
437    // TODO: GC waker?
438    warn!("Closing interface {net_id}");
439}
440
441pub fn register_interface<R: ScopedRawMutex>(
442    stack: &'static NetStack<R, StdTcpIm>,
443    socket: TcpStream,
444) -> Result<StdTcpRecvHdl<R>, Error> {
445    let (rx, tx) = socket.into_split();
446    stack.with_interface_manager(|im| {
447        let inner = im.get_or_init_inner();
448        if let Some((addr, closer)) = inner.alloc_intfc(tx) {
449            Ok(StdTcpRecvHdl {
450                stack,
451                net_id: addr,
452                skt: rx,
453                closer,
454            })
455        } else {
456            Err(Error::OutOfNetIds)
457        }
458    })
459}