ergot_base/net_stack.rs

//! The Ergot NetStack
//!
//! The [`NetStack`] is the core of Ergot. It is intended to be placed
//! in a `static` variable for the duration of your application.
//!
//! The Netstack is used directly for a few main responsibilities:
//!
//! 1. Sending a message, either from user code, or to deliver/forward messages
//!    received from an interface
//! 2. Attaching a socket, allowing the NetStack to route messages to it
//! 3. Interacting with the [interface manager], in order to add/remove
//!    interfaces, or obtain other information
//!
//! [interface manager]: crate::interface_manager
//!
//! In general, interacting with anything contained by the [`NetStack`] requires
//! locking the [`BlockingMutex`] which protects the inner contents. This is
//! used both to allow sharing of the inner contents, and to allow `Drop`
//! impls to remove themselves from the stack in a blocking manner.
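//!
//! ## Example
//!
//! A minimal sketch of typical setup, reusing the types from the examples
//! further down in this module:
//!
//! ```rust
//! use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
//! use ergot_base::NetStack;
//! use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
//!
//! static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
//!
//! // All access to the stack's contents goes through the blocking mutex.
//! let answer = STACK.with_interface_manager(|_im| 42);
//! assert_eq!(answer, 42);
//! ```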

use core::{any::TypeId, mem::ManuallyDrop, ptr::NonNull};

use cordyceps::List;
use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
use serde::Serialize;

use crate::{
    Header,
    interface_manager::{self, InterfaceManager, InterfaceSendError},
    socket::{SocketHeader, SocketSendError, SocketVTable},
};

/// The Ergot Netstack
pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
    inner: BlockingMutex<R, NetStackInner<M>>,
}

pub(crate) struct NetStackInner<M: InterfaceManager> {
    sockets: List<SocketHeader>,
    manager: M,
    pcache_bits: u32,
    pcache_start: u8,
    seq_no: u16,
}

/// An error from calling a [`NetStack`] "send" method
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum NetStackSendError {
    /// A matching local socket was found, but sending to it failed
    SocketSend(SocketSendError),
    /// The interface manager failed to send the message
    InterfaceSend(InterfaceSendError),
    /// No interface or local socket matched the destination
    NoRoute,
    /// The destination used the "any port" wildcard (port 0) without a key
    AnyPortMissingKey,
    /// The destination port exists, but has a different frame kind
    WrongPortKind,
}

// ---- impl NetStack ----

impl<R, M> NetStack<R, M>
where
    R: ScopedRawMutex + ConstInit,
    M: InterfaceManager + interface_manager::ConstInit,
{
    /// Create a new, uninitialized [`NetStack`].
    ///
    /// Requires that the [`ScopedRawMutex`] implements the [`mutex::ConstInit`]
    /// trait, and the [`InterfaceManager`] implements the
    /// [`interface_manager::ConstInit`] trait.
    ///
    /// ## Example
    ///
    /// ```rust
    /// use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
    /// use ergot_base::NetStack;
    /// use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
    ///
    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
    /// ```
    pub const fn new() -> Self {
        Self {
            inner: BlockingMutex::new(NetStackInner::new()),
        }
    }
}

impl<R, M> NetStack<R, M>
where
    R: ScopedRawMutex,
    M: InterfaceManager,
{
    /// Manually create a new, uninitialized [`NetStack`].
    ///
    /// This method is useful if your [`ScopedRawMutex`] or [`InterfaceManager`]
    /// do not implement their corresponding `ConstInit` trait.
    ///
    /// In general, this is most often only needed for `loom` testing, and
    /// [`NetStack::new()`] should be used when possible.
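    ///
    /// ## Example
    ///
    /// A sketch of manual construction. This assumes both `ConstInit` traits
    /// expose an `INIT` constant (as used by [`NetStack::new()`]); any
    /// `const`-constructible values would work equally well:
    ///
    /// ```rust
    /// use mutex::{ConstInit, raw_impls::cs::CriticalSectionRawMutex as CSRMutex};
    /// use ergot_base::NetStack;
    /// use ergot_base::interface_manager::{self, null::NullInterfaceManager as NullIM};
    ///
    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::const_new(
    ///     <CSRMutex as ConstInit>::INIT,
    ///     <NullIM as interface_manager::ConstInit>::INIT,
    /// );
    /// ```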
    pub const fn const_new(r: R, m: M) -> Self {
        Self {
            inner: BlockingMutex::const_new(
                r,
                NetStackInner {
                    sockets: List::new(),
                    manager: m,
                    seq_no: 0,
                    pcache_start: 0,
                    pcache_bits: 0,
                },
            ),
        }
    }

    /// Access the contained [`InterfaceManager`].
    ///
    /// Access to the [`InterfaceManager`] is made via the provided closure.
    /// The [`BlockingMutex`] is locked for the duration of this access,
    /// inhibiting all other usage of this [`NetStack`].
    ///
    /// This can be used to add new interfaces, obtain metadata, or other
    /// actions supported by the chosen [`InterfaceManager`].
    ///
    /// ## Example
    ///
    /// ```rust
    /// # use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
    /// # use ergot_base::NetStack;
    /// # use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
    /// #
    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
    ///
    /// let res = STACK.with_interface_manager(|im| {
    ///    // The mutex is locked for the full duration of this closure.
    ///    # _ = im;
    ///    // We can return whatever we want from this context, though not
    ///    // anything borrowed from `im`.
    ///    42
    /// });
    /// assert_eq!(res, 42);
    /// ```
    pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&'static self, f: F) -> U {
        self.inner.with_lock(|inner| f(&mut inner.manager))
    }

    /// Send a raw (pre-serialized) message.
    ///
    /// This interface should almost never be used by end-users, and is instead
    /// typically used by interfaces to feed received messages into the
    /// [`NetStack`].
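    ///
    /// Note that sends to the "any port" wildcard (`port_id == 0`) must
    /// include a key in the header; otherwise
    /// [`NetStackSendError::AnyPortMissingKey`] is returned.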
    pub fn send_raw(&'static self, hdr: Header, body: &[u8]) -> Result<(), NetStackSendError> {
        if hdr.dst.port_id == 0 && hdr.key.is_none() {
            return Err(NetStackSendError::AnyPortMissingKey);
        }
        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();

        self.inner
            .with_lock(|inner| inner.send_raw(local_bypass, hdr, body))
    }

    /// Send a typed message.
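    ///
    /// If the destination is not local, the (still unserialized) message is
    /// offered to the [`InterfaceManager`] to serialize and send. Otherwise,
    /// the value is delivered by-move to the first matching local socket, and
    /// is dropped if no socket accepts it.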
    pub fn send_ty<T: 'static + Serialize>(
        &'static self,
        hdr: Header,
        t: T,
    ) -> Result<(), NetStackSendError> {
        // Can we assume the destination is local?
        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();

        self.inner
            .with_lock(|inner| inner.send_ty(local_bypass, hdr, t))
    }

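    /// Attempt to attach a socket, assigning it a newly allocated port.
    ///
    /// Returns `None` if all ports are currently in use.
    ///
    /// Safety: `node` must point to a valid [`SocketHeader`] that remains
    /// live (and pinned) until it is removed via [`Self::detach_socket`].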
    pub(crate) unsafe fn try_attach_socket(
        &'static self,
        mut node: NonNull<SocketHeader>,
    ) -> Option<u8> {
        self.inner.with_lock(|inner| {
            let new_port = inner.alloc_port()?;
            unsafe {
                node.as_mut().port = new_port;
            }

            inner.sockets.push_front(node);
            Some(new_port)
        })
    }

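    /// Attach a socket, assigning it a newly allocated port.
    ///
    /// Like [`Self::try_attach_socket`], but panics if all ports are in use.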
    pub(crate) unsafe fn attach_socket(&'static self, node: NonNull<SocketHeader>) -> u8 {
        let res = unsafe { self.try_attach_socket(node) };
        let Some(new_port) = res else {
            panic!("exhausted all addrs");
        };
        new_port
    }

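    /// Detach a socket, freeing its port for reuse.
    ///
    /// Safety: `node` must be a socket that was previously attached to this
    /// stack, and must not be used as an attached socket afterwards.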
    pub(crate) unsafe fn detach_socket(&'static self, node: NonNull<SocketHeader>) {
        self.inner.with_lock(|inner| unsafe {
            let port = node.as_ref().port;
            inner.free_port(port);
            inner.sockets.remove(node)
        });
    }

    pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&'static self, f: F) -> U {
        self.inner.with_lock(|_inner| f())
    }
}

impl<R, M> Default for NetStack<R, M>
where
    R: ScopedRawMutex + ConstInit,
    M: InterfaceManager + interface_manager::ConstInit,
{
    fn default() -> Self {
        Self::new()
    }
}

// ---- impl NetStackInner ----

impl<M> NetStackInner<M>
where
    M: InterfaceManager,
    M: interface_manager::ConstInit,
{
    pub const fn new() -> Self {
        Self {
            sockets: List::new(),
            manager: M::INIT,
            seq_no: 0,
            pcache_bits: 0,
            pcache_start: 0,
        }
    }
}

impl<M> NetStackInner<M>
where
    M: InterfaceManager,
{
    fn send_raw(
        &mut self,
        local_bypass: bool,
        hdr: Header,
        body: &[u8],
    ) -> Result<(), NetStackSendError> {
        let res = if !local_bypass {
            self.manager.send_raw(hdr.clone(), body)
        } else {
            Err(InterfaceSendError::DestinationLocal)
        };

        match res {
            Ok(()) => return Ok(()),
            Err(InterfaceSendError::DestinationLocal) => {}
            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
        }
        // It was a destination local error, try to honor that
        for socket in self.sockets.iter_raw() {
            let skt_ref = unsafe { socket.as_ref() };
            if hdr.kind != skt_ref.kind {
                if hdr.dst.port_id != 0 && hdr.dst.port_id == skt_ref.port {
                    // If kind mismatch and not wildcard: report error
                    return Err(NetStackSendError::WrongPortKind);
                } else {
                    continue;
                }
            }
            // TODO: only allow port_id == 0 if there is only one matching port
            // with this key.
            if (skt_ref.port == hdr.dst.port_id)
                || (hdr.dst.port_id == 0 && hdr.key.is_some_and(|k| k == skt_ref.key))
            {
                let res = {
                    let f = skt_ref.vtable.send_raw;

                    // SAFETY: skt_ref is now dead to us!

                    let this: NonNull<SocketHeader> = socket;
                    let this: NonNull<()> = this.cast();
                    let hdr = hdr.to_headerseq_or_with_seq(|| {
                        let seq = self.seq_no;
                        self.seq_no = self.seq_no.wrapping_add(1);
                        seq
                    });

                    (f)(this, body, hdr).map_err(NetStackSendError::SocketSend)
                };
                return res;
            }
        }
        Err(NetStackSendError::NoRoute)
    }

    fn send_ty<T: 'static + Serialize>(
        &mut self,
        local_bypass: bool,
        hdr: Header,
        t: T,
    ) -> Result<(), NetStackSendError> {
        let res = if !local_bypass {
            // Not local: offer to the interface manager to send
            self.manager.send(hdr.clone(), &t)
        } else {
            // just skip to local sending
            Err(InterfaceSendError::DestinationLocal)
        };

        match res {
            Ok(()) => return Ok(()),
            Err(InterfaceSendError::DestinationLocal) => {}
            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
        }

        // It was a destination local error, try to honor that
        //
        // Sending to a local interface means a potential move. Create a
        // `ManuallyDrop`; if a send succeeds, then we have "moved from" here
        // into the destination. If no send succeeds (e.g. no socket match
        // or sending to the socket failed) then we will need to drop the
        // value ourselves.
        let mut t = ManuallyDrop::new(t);

        // Check each socket to see if we want to send it there...
        for socket in self.sockets.iter_raw() {
            let skt_ref = unsafe { socket.as_ref() };

            if hdr.kind != skt_ref.kind {
                if hdr.dst.port_id != 0 && hdr.dst.port_id == skt_ref.port {
                    // If kind mismatch and not wildcard: report error
                    return Err(NetStackSendError::WrongPortKind);
                } else {
                    continue;
                }
            }

            // TODO: only allow port_id == 0 if there is only one matching port
            // with this key.
            if (skt_ref.port == hdr.dst.port_id || hdr.dst.port_id == 0)
                && hdr.key.is_some_and(|k| k == skt_ref.key)
            {
                let vtable: &'static SocketVTable = skt_ref.vtable;
                // SAFETY: skt_ref is now dead to us!

                let res = if let Some(f) = vtable.send_owned {
                    let this: NonNull<SocketHeader> = socket;
                    let this: NonNull<()> = this.cast();
                    let that: NonNull<ManuallyDrop<T>> = NonNull::from(&mut t);
                    let that: NonNull<()> = that.cast();
                    let hdr = hdr.to_headerseq_or_with_seq(|| {
                        let seq = self.seq_no;
                        self.seq_no = self.seq_no.wrapping_add(1);
                        seq
                    });
                    (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
                } else if let Some(_f) = vtable.send_bor {
                    // TODO: if we support send borrowed, then we need to
                    // drop the `ManuallyDrop` here, success or failure.
                    todo!()
                } else {
                    // TODO: keep going? If we found the "right" destination and
                    // sending fails, then there's not much we can do. Probably: there
                    // is no case where a socket has NEITHER send_owned NOR send_bor,
                    // can we make this state impossible instead?
                    Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
                };

                // If sending failed, we did NOT move the T, which means it's on us
                // to drop it.
                if res.is_err() {
                    unsafe {
                        ManuallyDrop::drop(&mut t);
                    }
                }
                return res;
            }
        }

        // We reached the end of sockets. We need to drop this item.
        unsafe {
            ManuallyDrop::drop(&mut t);
        }
        Err(NetStackSendError::NoRoute)
    }
}

impl<M> NetStackInner<M>
where
    M: InterfaceManager,
{
    /// Cache-based port allocator, inspired by the littlefs2 ID allocator
    ///
    /// We remember 32 ports at a time, from the current base, which is always
    /// a multiple of 32. Allocating from this range does not require walking
    /// the socket list.
    ///
    /// If the current 32 ports are all taken, we start over from a base of 0
    /// (skipping the range we just exhausted), rebuilding the cache for each
    /// 32-port block in turn until we find one with a free port.
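    ///
    /// As a worked example (illustrative values only): with
    /// `pcache_start == 32` and `pcache_bits == 0b0111`, ports 32, 33, and 34
    /// are taken, `trailing_ones()` returns 3, and we allocate port
    /// `32 + 3 == 35`, setting bit 3 in the cache.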
    fn alloc_port(&mut self) -> Option<u8> {
        // port 0 is always taken (could be clear on first alloc)
        self.pcache_bits |= (self.pcache_start == 0) as u32;

        if self.pcache_bits != u32::MAX {
            // We can allocate from the current slot
            let ldg = self.pcache_bits.trailing_ones();
            debug_assert!(ldg < 32);
            self.pcache_bits |= 1 << ldg;
            return Some(self.pcache_start + (ldg as u8));
        }

        // Nope, cache is all taken. Try to find a base with available items.
        // We always start from the bottom to keep ports small, but if we know
        // we just exhausted a range, don't waste time checking that
        let old_start = self.pcache_start;
        for base in 0..8 {
            let start = base * 32;
            if start == old_start {
                continue;
            }
            // Clear/reset cache
            self.pcache_start = start;
            self.pcache_bits = 0;
            // port 0 is not allowed
            self.pcache_bits |= (self.pcache_start == 0) as u32;
            // port 255 is not allowed
            self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;

            // TODO: If we trust that sockets are always sorted, we could early-return
            // when we reach a `pupper > self.pcache_start`. We could also maybe be smart
            // and iterate forwards for 0..4 and backwards for 4..8 (and switch the early
            // return check to < instead). NOTE: We currently do NOT guarantee sockets are
            // sorted!
            self.sockets.iter().for_each(|s| {
                // The upper 3 bits of the port
                let pupper = s.port & !(32 - 1);
                // The lower 5 bits of the port
                let plower = s.port & (32 - 1);

                if pupper == self.pcache_start {
                    self.pcache_bits |= 1 << plower;
                }
            });

            if self.pcache_bits != u32::MAX {
                // We can allocate from the current slot
                let ldg = self.pcache_bits.trailing_ones();
                debug_assert!(ldg < 32);
                self.pcache_bits |= 1 << ldg;
                return Some(self.pcache_start + (ldg as u8));
            }
        }

        // Nope, nothing found
        None
    }

    fn free_port(&mut self, port: u8) {
        // The upper 3 bits of the port
        let pupper = port & !(32 - 1);
        // The lower 5 bits of the port
        let plower = port & (32 - 1);

        // TODO: If the freed port is in the 0..32 range, or just less than
        // the current start range, maybe do an opportunistic re-look?
        if pupper == self.pcache_start {
            self.pcache_bits &= !(1 << plower);
        }
    }
}

#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack, interface_manager::null::NullInterfaceManager,
        socket::owned::OwnedSocket,
    };

    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();

        let mut v = vec![];

        fn spawn_skt(id: u8) -> (u8, JoinHandle<()>, oneshot::Sender<()>) {
            let (txdone, rxdone) = oneshot::channel();
            let (txwait, rxwait) = oneshot::channel();
            let hdl = std::thread::spawn(move || {
                let skt = OwnedSocket::<u64, _, _>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    FrameKind::ENDPOINT_REQ,
                );
                let skt = pin!(skt);
                let hdl = skt.attach();
                assert_eq!(hdl.port(), id);
                txwait.send(()).unwrap();
                let _: () = rxdone.blocking_recv().unwrap();
            });
            let _ = rxwait.blocking_recv();
            (id, hdl, txdone)
        }

        // make sockets 1..32
        for i in 1..32 {
            v.push(spawn_skt(i));
        }

        // make sockets 32..40
        for i in 32..40 {
            v.push(spawn_skt(i));
        }

        // drop socket 35
        let pos = v.iter().position(|(i, _, _)| *i == 35).unwrap();
        let (_i, hdl, tx) = v.remove(pos);
        tx.send(()).unwrap();
        hdl.join().unwrap();

        // make a new socket, it should be 35
        v.push(spawn_skt(35));

        // drop socket 4
        let pos = v.iter().position(|(i, _, _)| *i == 4).unwrap();
        let (_i, hdl, tx) = v.remove(pos);
        tx.send(()).unwrap();
        hdl.join().unwrap();

        // make a new socket, it should be 40
        v.push(spawn_skt(40));

        // make sockets 41..64
        for i in 41..64 {
            v.push(spawn_skt(i));
        }

        // make a new socket, it should be 4
        v.push(spawn_skt(4));

        // make sockets 64..255
        for i in 64..255 {
            v.push(spawn_skt(i));
        }

        // drop socket 212
        let pos = v.iter().position(|(i, _, _)| *i == 212).unwrap();
        let (_i, hdl, tx) = v.remove(pos);
        tx.send(()).unwrap();
        hdl.join().unwrap();

        // make a new socket, it should be 212
        v.push(spawn_skt(212));

        // Sockets exhausted (we never see 255)
        let hdl = std::thread::spawn(move || {
            let skt =
                OwnedSocket::<u64, _, _>::new(&STACK, Key(*b"TEST1234"), FrameKind::ENDPOINT_REQ);
            let skt = pin!(skt);
            let hdl = skt.attach();
            println!("{}", hdl.port());
        });
        assert!(hdl.join().is_err());

        for (_i, hdl, tx) in v.drain(..) {
            tx.send(()).unwrap();
            hdl.join().unwrap();
        }
    }
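
    /// A minimal sketch (not exercising the real allocator) of the `u32`
    /// bitmap math used by `alloc_port`: bit `n` marks `pcache_start + n` as
    /// taken, and `trailing_ones()` finds the lowest free slot.
    #[test]
    fn pcache_bitmask_sketch() {
        let mut bits: u32 = 0;
        // At base 0, port 0 is reserved, so its bit starts set.
        bits |= 1;

        // Each allocation takes the lowest clear bit.
        let mut got = vec![];
        for _ in 0..3 {
            let ldg = bits.trailing_ones();
            bits |= 1 << ldg;
            got.push(ldg as u8);
        }
        assert_eq!(got, vec![1, 2, 3]);

        // Freeing a port clears its bit, making it the next candidate.
        bits &= !(1 << 2);
        assert_eq!(bits.trailing_ones(), 2);

        // A full bitmap means the 32-port block is exhausted.
        assert_eq!(u32::MAX.trailing_ones(), 32);
    }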
}