// ergot_base/interface_manager/std_tcp_router.rs

/*
    Let's see, we're going to need:

    * Some kind of hashmap/vec of active interfaces, keyed by network id?
        * If we use a Vec, we should NOT use the index as the ID, since the IDs may be sparse
    * The actual interface type probably gets defined by the interface manager
    * The interface which follows the pinned rules, and removes itself on drop
    * THIS version of the routing interface probably will not allow for other routers;
        we probably have to assume we are the only one assigning network IDs until a
        later point
    * The associated "simple" version of a client probably needs a stub routing interface
        that picks up the network ID from the destination address
    * We might want to have an `Arc` version of the netstack, or we need some kind
        of `Once`-style construction.
    * The interface manager needs some kind of "handle" construction so that we can get
        `&mut` access to it, or we need an accessor via the netstack
*/
18
use std::fmt;
use std::sync::Arc;
use std::{cell::UnsafeCell, mem::MaybeUninit};

use crate::{Header, NetStack, interface_manager::std_utils::ser_frame};

use maitake_sync::WaitQueue;
use mutex::ScopedRawMutex;
use tokio::sync::mpsc::Sender;
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::{
        TcpStream,
        tcp::{OwnedReadHalf, OwnedWriteHalf},
    },
    select,
    sync::mpsc::{Receiver, channel, error::TrySendError},
};

use super::{
    ConstInit, InterfaceManager, InterfaceSendError,
    std_utils::{
        OwnedFrame, ReceiverError,
        acc::{CobsAccumulator, FeedResult},
        de_frame,
    },
};
45
/// Receive half of a TCP interface, returned by [`register_interface`].
///
/// Call [`StdTcpRecvHdl::run`] to read COBS-framed packets from the socket
/// and hand them to `stack`. The `closer` is shared with this interface's
/// `tx_worker` task and its `StdTcpTxHdl`; closing it stops both the receive
/// loop and the transmit worker.
pub struct StdTcpRecvHdl<R: ScopedRawMutex + 'static> {
    /// Net stack that decoded incoming frames are forwarded to.
    stack: &'static NetStack<R, StdTcpIm>,
    // TODO: when we have more real networking and we could possibly
    // have conflicting net_id assignments, we might need to have a
    // shared ref to an Arc<AtomicU16> or something for net_id?
    //
    // for now, stdtcp assumes it is the only "seed" router, meaning that
    // it is solely in charge of assigning netids
    net_id: u16,
    /// Read half of the TCP socket.
    skt: OwnedReadHalf,
    /// Shutdown signal shared with the transmit side of this interface.
    closer: Arc<WaitQueue>,
}
58
/// Interface manager that routes outgoing frames to TCP interfaces by
/// destination network id.
///
/// `new` is a `const fn` (it backs `ConstInit::INIT`), so the inner state is
/// not built there; it is default-initialized on first access by
/// `get_or_init_inner`.
///
/// NOTE(review): this file has no `Drop` impl for `StdTcpIm`, so an
/// initialized `inner` is never dropped. That is fine for a `static`, but it
/// leaks if a non-static instance (e.g. from `Default`) is dropped.
pub struct StdTcpIm {
    /// `true` once `inner` has been written.
    init: bool,
    /// Lazily-initialized state; only valid to read while `init` is `true`.
    inner: UnsafeCell<MaybeUninit<StdTcpImInner>>,
}
63
/// Mutable routing state owned by [`StdTcpIm`].
#[derive(Default)]
pub struct StdTcpImInner {
    // TODO: we probably want something like iddqd for a hashset sorted by
    // net_id, as well as a list of "allocated" netids, mapped to the
    // interface they are associated with
    //
    // TODO: for the no-std version of this, we will need to use the same
    // intrusive list stuff that we use for sockets for holding interfaces.
    //
    // Kept sorted by `net_id`: lookups use `binary_search_by_key`.
    interfaces: Vec<StdTcpTxHdl>,
    // Next sequence number handed to `to_headerseq_or_with_seq`; wraps on overflow.
    seq_no: u16,
    // Set when a receive handle shuts down; tells `alloc_intfc` to collect
    // interfaces whose closer has been closed.
    any_closed: bool,
}
76
/// Errors returned by [`register_interface`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Every assignable network id is already in use.
    OutOfNetIds,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::OutOfNetIds => f.write_str("no free network ids are available"),
        }
    }
}

impl std::error::Error for Error {}
81
/// Transmit side of one registered interface, stored in
/// [`StdTcpImInner`]'s `interfaces` list.
struct StdTcpTxHdl {
    // Network id assigned to this interface by `alloc_intfc`.
    net_id: u16,
    // Queue feeding this interface's `tx_worker` task; dropping it ends that task's loop.
    skt_tx: Sender<OwnedFrame>,
    // Shutdown signal shared with the interface's receive handle and `tx_worker`.
    closer: Arc<WaitQueue>,
}
87
// ---- impls ----

// impl StdTcpRecvHdl
91
92impl<R: ScopedRawMutex + 'static> StdTcpRecvHdl<R> {
93    pub async fn run(mut self) -> Result<(), ReceiverError> {
94        let res = self.run_inner().await;
95        self.closer.close();
96        // todo: this could live somewhere else?
97        self.stack.with_interface_manager(|im| {
98            let inner = im.get_or_init_inner();
99            inner.any_closed = true;
100        });
101        res
102    }
103
104    pub async fn run_inner(&mut self) -> Result<(), ReceiverError> {
105        let mut cobs_buf = CobsAccumulator::new(1024 * 1024);
106        let mut raw_buf = [0u8; 4096];
107
108        loop {
109            let rd = self.skt.read(&mut raw_buf);
110            let close = self.closer.wait();
111
112            let ct = select! {
113                r = rd => {
114                    match r {
115                        Ok(0) | Err(_) => {
116                            println!("recv run {} closed", self.net_id);
117                            return Err(ReceiverError::SocketClosed)
118                        },
119                        Ok(ct) => ct,
120                    }
121                }
122                _c = close => {
123                    return Err(ReceiverError::SocketClosed);
124                }
125            };
126
127            let buf = &raw_buf[..ct];
128            let mut window = buf;
129
130            'cobs: while !window.is_empty() {
131                window = match cobs_buf.feed_raw(window) {
132                    FeedResult::Consumed => break 'cobs,
133                    FeedResult::OverFull(new_wind) => new_wind,
134                    FeedResult::DeserError(new_wind) => new_wind,
135                    FeedResult::Success { data, remaining } => {
136                        // Successfully de-cobs'd a packet, now we need to
137                        // do something with it.
138                        if let Some(mut frame) = de_frame(data) {
139                            // If the message comes in and has a src net_id of zero,
140                            // we should rewrite it so it isn't later understood as a
141                            // local packet.
142                            if frame.hdr.src.network_id == 0 {
143                                assert_ne!(
144                                    frame.hdr.src.node_id, 0,
145                                    "we got a local packet remotely?"
146                                );
147                                assert_ne!(
148                                    frame.hdr.src.node_id, 1,
149                                    "someone is pretending to be us?"
150                                );
151
152                                frame.hdr.src.network_id = self.net_id;
153                            }
154                            // TODO: if the destination IS self.net_id, we could rewrite the
155                            // dest net_id as zero to avoid a pass through the interface manager.
156                            //
157                            // If the dest is 0, should we rewrite the dest as self.net_id? This
158                            // is the opposite as above, but I dunno how that will work with responses
159                            let res = self.stack.send_raw(frame.hdr.into(), &frame.body);
160                            match res {
161                                Ok(()) => {}
162                                Err(e) => {
163                                    // TODO: match on error, potentially try to send NAK?
164                                    panic!("recv->send error: {e:?}");
165                                }
166                            }
167                        } else {
168                            println!("Decode error! Ignoring frame on net_id {}", self.net_id);
169                        }
170
171                        remaining
172                    }
173                };
174            }
175        }
176    }
177}
178
// impl StdTcpIm
180
181impl StdTcpIm {
182    const fn new() -> Self {
183        Self {
184            init: false,
185            inner: UnsafeCell::new(MaybeUninit::uninit()),
186        }
187    }
188
189    pub fn get_nets(&mut self) -> Vec<u16> {
190        let inner = self.get_or_init_inner();
191        inner.interfaces.iter().map(|i| i.net_id).collect()
192    }
193
194    fn get_or_init_inner(&mut self) -> &mut StdTcpImInner {
195        let inner = self.inner.get_mut();
196        if self.init {
197            unsafe { inner.assume_init_mut() }
198        } else {
199            let imr = inner.write(StdTcpImInner::default());
200            self.init = true;
201            imr
202        }
203    }
204}
205
206impl InterfaceManager for StdTcpIm {
207    fn send<T: serde::Serialize>(
208        &mut self,
209        mut hdr: Header,
210        data: &T,
211    ) -> Result<(), InterfaceSendError> {
212        // todo: make this state impossible? enum of dst w/ or w/o key?
213        assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
214
215        let inner = self.get_or_init_inner();
216        // todo: we only handle direct dests, we will probably also want to search
217        // some kind of net_id:interface routing table
218        let Ok(idx) = inner
219            .interfaces
220            .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
221        else {
222            return Err(InterfaceSendError::NoRouteToDest);
223        };
224
225        let interface = &inner.interfaces[idx];
226        // TODO: Assumption: "we" are always node_id==1
227        if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
228            return Err(InterfaceSendError::DestinationLocal);
229        }
230        // If the source is local, rewrite the source using this interface's
231        // information so responses can find their way back here
232        if hdr.src.net_node_any() {
233            // todo: if we know the destination is EXACTLY this network,
234            // we could leave the network_id local to allow for shorter
235            // addresses
236            hdr.src.network_id = interface.net_id;
237            hdr.src.node_id = 1;
238        }
239
240        let res = interface.skt_tx.try_send(OwnedFrame {
241            hdr: hdr.to_headerseq_or_with_seq(|| {
242                let seq_no = inner.seq_no;
243                inner.seq_no = inner.seq_no.wrapping_add(1);
244                seq_no
245            }),
246            body: postcard::to_stdvec(data).unwrap(),
247        });
248        match res {
249            Ok(()) => Ok(()),
250            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
251            Err(TrySendError::Closed(_)) => {
252                let rem = inner.interfaces.remove(idx);
253                rem.closer.close();
254                Err(InterfaceSendError::NoRouteToDest)
255            }
256        }
257    }
258
259    fn send_raw(&mut self, mut hdr: Header, data: &[u8]) -> Result<(), InterfaceSendError> {
260        // todo: make this state impossible? enum of dst w/ or w/o key?
261        assert!(!(hdr.dst.port_id == 0 && hdr.key.is_none()));
262
263        let inner = self.get_or_init_inner();
264        // todo: dedupe w/ send
265        //
266        // todo: we only handle direct dests
267        let Ok(idx) = inner
268            .interfaces
269            .binary_search_by_key(&hdr.dst.network_id, |int| int.net_id)
270        else {
271            return Err(InterfaceSendError::NoRouteToDest);
272        };
273
274        let interface = &inner.interfaces[idx];
275        // TODO: Assumption: "we" are always node_id==1
276        if hdr.dst.network_id == interface.net_id && hdr.dst.node_id == 1 {
277            return Err(InterfaceSendError::DestinationLocal);
278        }
279        // If the source is local, rewrite the source using this interface's
280        // information so responses can find their way back here
281        if hdr.src.net_node_any() {
282            // todo: if we know the destination is EXACTLY this network,
283            // we could leave the network_id local to allow for shorter
284            // addresses
285            hdr.src.network_id = interface.net_id;
286            hdr.src.node_id = 1;
287        }
288
289        let res = interface.skt_tx.try_send(OwnedFrame {
290            hdr: hdr.to_headerseq_or_with_seq(|| {
291                let seq_no = inner.seq_no;
292                inner.seq_no = inner.seq_no.wrapping_add(1);
293                seq_no
294            }),
295            body: data.to_vec(),
296        });
297        match res {
298            Ok(()) => Ok(()),
299            Err(TrySendError::Full(_)) => Err(InterfaceSendError::InterfaceFull),
300            Err(TrySendError::Closed(_)) => {
301                inner.interfaces.remove(idx);
302                Err(InterfaceSendError::NoRouteToDest)
303            }
304        }
305    }
306}
307
308impl Default for StdTcpIm {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
314impl ConstInit for StdTcpIm {
315    #[allow(clippy::declare_interior_mutable_const)]
316    const INIT: Self = Self::new();
317}
318
// SAFETY: the only non-`Sync` field is `inner: UnsafeCell<..>`. In this file,
// `inner` is only accessed through `get_or_init_inner(&mut self)`, which uses
// `UnsafeCell::get_mut`, and no `&self` method touches it. Sharing a
// `&StdTcpIm` between threads therefore cannot produce aliased mutable access.
// NOTE(review): this stays sound only while no `&self` method reads `inner`.
unsafe impl Sync for StdTcpIm {}
320
// impl StdTcpImInner
322
323impl StdTcpImInner {
324    pub fn alloc_intfc(&mut self, tx: OwnedWriteHalf) -> Option<(u16, Arc<WaitQueue>)> {
325        let closer = Arc::new(WaitQueue::new());
326        if self.interfaces.is_empty() {
327            // todo: configurable channel depth
328            let (ctx, crx) = channel(64);
329            let net_id = 1;
330            // TODO: We are spawning in a non-async context!
331            tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
332            self.interfaces.push(StdTcpTxHdl {
333                net_id,
334                skt_tx: ctx,
335                closer: closer.clone(),
336            });
337            println!("Alloc'd net_id 1");
338            return Some((net_id, closer));
339        } else if self.interfaces.len() >= 65534 {
340            println!("Out of netids!");
341            return None;
342        }
343
344        // If we closed any interfaces, then collect
345        if self.any_closed {
346            self.interfaces.retain(|int| {
347                let closed = int.closer.is_closed();
348                if closed {
349                    println!("Collecting interface {}", int.net_id);
350                }
351                !closed
352            });
353        }
354
355        let mut net_id = 1;
356        // we're not empty, find the lowest free address by counting the
357        // indexes, and if we find a discontinuity, allocate the first one.
358        for intfc in self.interfaces.iter() {
359            if intfc.net_id > net_id {
360                println!("Found gap: {net_id}");
361                break;
362            }
363            debug_assert!(intfc.net_id == net_id);
364            net_id += 1;
365        }
366        // EITHER: We've found a gap that we can use, OR we've iterated all
367        // interfaces, which means that we had contiguous allocations but we
368        // have not exhausted the range.
369        debug_assert!(net_id > 0 && net_id != u16::MAX);
370        let (ctx, crx) = channel(64);
371        println!("allocated net_id {net_id}");
372
373        tokio::task::spawn(tx_worker(net_id, tx, crx, closer.clone()));
374        self.interfaces.push(StdTcpTxHdl {
375            net_id,
376            skt_tx: ctx,
377            closer: closer.clone(),
378        });
379        self.interfaces.sort_unstable_by_key(|i| i.net_id);
380        Some((net_id, closer))
381    }
382}
383
// Helper functions
385
386async fn tx_worker(
387    net_id: u16,
388    mut tx: OwnedWriteHalf,
389    mut rx: Receiver<OwnedFrame>,
390    closer: Arc<WaitQueue>,
391) {
392    println!("Started tx_worker for net_id {net_id}");
393    loop {
394        let rxf = rx.recv();
395        let clf = closer.wait();
396
397        let frame = select! {
398            r = rxf => {
399                if let Some(frame) = r {
400                    frame
401                } else {
402                    println!("tx_worker {net_id} rx closed!");
403                    closer.close();
404                    break;
405                }
406            }
407            _c = clf => {
408                break;
409            }
410        };
411
412        let msg = ser_frame(frame);
413        println!("sending pkt len:{} on net_id {net_id}", msg.len());
414        let res = tx.write_all(&msg).await;
415        if let Err(e) = res {
416            println!("Err: {e:?}");
417            break;
418        }
419    }
420    // TODO: GC waker?
421    println!("Closing interface {net_id}");
422}
423
424pub fn register_interface<R: ScopedRawMutex>(
425    stack: &'static NetStack<R, StdTcpIm>,
426    socket: TcpStream,
427) -> Result<StdTcpRecvHdl<R>, Error> {
428    let (rx, tx) = socket.into_split();
429    stack.with_interface_manager(|im| {
430        let inner = im.get_or_init_inner();
431        if let Some((addr, closer)) = inner.alloc_intfc(tx) {
432            Ok(StdTcpRecvHdl {
433                stack,
434                net_id: addr,
435                skt: rx,
436                closer,
437            })
438        } else {
439            Err(Error::OutOfNetIds)
440        }
441    })
442}