ergot_base/
net_stack.rs

//! The Ergot NetStack
//!
//! The [`NetStack`] is the core of Ergot. It is intended to be placed
//! in a `static` variable for the duration of your application.
//!
//! The NetStack is used directly for a few main responsibilities:
//!
//! 1. Sending a message, either from user code, or to deliver/forward messages
//!    received from an interface
//! 2. Attaching a socket, allowing the NetStack to route messages to it
//! 3. Interacting with the [interface manager], in order to add/remove
//!    interfaces, or obtain other information
//!
//! [interface manager]: crate::interface_manager
//!
//! In general, interacting with anything contained by the [`NetStack`] requires
//! locking the [`BlockingMutex`] which protects the inner contents. This
//! allows the inner contents to be shared, and also allows `Drop` impls
//! to remove themselves from the stack in a blocking manner.
20
21use core::{any::TypeId, ops::Deref, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29    FrameKind, Header, ProtocolError,
30    interface_manager::{self, InterfaceManager, InterfaceSendError},
31    socket::{SocketHeader, SocketSendError, SocketVTable},
32};
33
/// The Ergot Netstack.
///
/// All of the stack's state (the attached sockets, the interface manager,
/// the port allocator cache, and the sequence number counter) lives behind a
/// single [`BlockingMutex`], so every operation on the stack briefly locks it.
pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
    // The lock-protected contents of the stack.
    inner: BlockingMutex<R, NetStackInner<M>>,
}
38
39pub trait NetStackHandle
40where
41    Self: Sized,
42{
43    type Target: Deref<Target = NetStack<Self::Mutex, Self::Interface>> + Clone;
44    type Mutex: ScopedRawMutex;
45    type Interface: InterfaceManager;
46    fn stack(&self) -> Self::Target;
47}
48
/// The lock-protected contents of a `NetStack`.
pub(crate) struct NetStackInner<M: InterfaceManager> {
    /// Every attached socket, including broadcast sockets (which use port 255).
    sockets: List<SocketHeader>,
    /// The interface manager; unicast messages are offered to it before local
    /// delivery is attempted, and broadcasts are also forwarded to it.
    manager: M,
    /// Occupancy bitmap for the 32 ports starting at `pcache_start`:
    /// bit `n` set means port `pcache_start + n` is taken (or reserved).
    pcache_bits: u32,
    /// First port of the cached range; always a multiple of 32.
    pcache_start: u8,
    /// Counter handed to `to_headerseq_or_with_seq` when delivering to local
    /// sockets; incremented (wrapping) each time it is consumed.
    seq_no: u16,
}
56
/// An error from calling a [`NetStack`] "send" method
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum NetStackSendError {
    /// Delivery to the selected local socket failed.
    SocketSend(SocketSendError),
    /// The interface manager failed to send the message (for any reason
    /// other than reporting that the destination is local).
    InterfaceSend(InterfaceSendError),
    /// No local socket or interface accepted the message.
    NoRoute,
    /// A message addressed to the "any" port (0) carried no key.
    AnyPortMissingKey,
    /// A socket exists on the destination port, but its frame kind does not
    /// match the message's kind.
    WrongPortKind,
    /// More than one local socket matched a message addressed to the "any"
    /// port (0), so the destination is ambiguous.
    AnyPortNotUnique,
    /// A broadcast message (port 255) carried no key.
    AllPortMissingKey,
}
69
70// ---- impl NetStack ----
71
72// TODO: Impl for Arc
73impl<R, M> NetStackHandle for &'_ NetStack<R, M>
74where
75    R: ScopedRawMutex,
76    M: InterfaceManager,
77{
78    type Mutex = R;
79    type Interface = M;
80    type Target = Self;
81
82    fn stack(&self) -> Self::Target {
83        self
84    }
85}
86
87impl<R, M> NetStack<R, M>
88where
89    R: ScopedRawMutex + ConstInit,
90    M: InterfaceManager + interface_manager::ConstInit,
91{
92    /// Create a new, uninitialized [`NetStack`].
93    ///
94    /// Requires that the [`ScopedRawMutex`] implements the [`mutex::ConstInit`]
95    /// trait, and the [`InterfaceManager`] implements the
96    /// [`interface_manager::ConstInit`] trait.
97    ///
98    /// ## Example
99    ///
100    /// ```rust
101    /// use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
102    /// use ergot_base::NetStack;
103    /// use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
104    ///
105    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
106    /// ```
107    pub const fn new() -> Self {
108        Self {
109            inner: BlockingMutex::new(NetStackInner::new()),
110        }
111    }
112}
113
114impl<R, M> NetStack<R, M>
115where
116    R: ScopedRawMutex,
117    M: InterfaceManager,
118{
119    /// Manually create a new, uninitialized [`NetStack`].
120    ///
121    /// This method is useful if your [`ScopedRawMutex`] or [`InterfaceManager`]
122    /// do not implement their corresponding `ConstInit` trait.
123    ///
124    /// In general, this is most often only needed for `loom` testing, and
125    /// [`NetStack::new()`] should be used when possible.
126    pub const fn const_new(r: R, m: M) -> Self {
127        Self {
128            inner: BlockingMutex::const_new(
129                r,
130                NetStackInner {
131                    sockets: List::new(),
132                    manager: m,
133                    seq_no: 0,
134                    pcache_start: 0,
135                    pcache_bits: 0,
136                },
137            ),
138        }
139    }
140
141    /// Access the contained [`InterfaceManager`].
142    ///
143    /// Access to the [`InterfaceManager`] is made via the provided closure.
144    /// The [`BlockingMutex`] is locked for the duration of this access,
145    /// inhibiting all other usage of this [`NetStack`].
146    ///
147    /// This can be used to add new interfaces, obtain metadata, or other
148    /// actions supported by the chosen [`InterfaceManager`].
149    ///
150    /// ## Example
151    ///
152    /// ```rust
153    /// # use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
154    /// # use ergot_base::NetStack;
155    /// # use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
156    /// #
157    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
158    ///
159    /// let res = STACK.with_interface_manager(|im| {
160    ///    // The mutex is locked for the full duration of this closure.
161    ///    # _ = im;
162    ///    // We can return whatever we want from this context, though not
163    ///    // anything borrowed from `im`.
164    ///    42
165    /// });
166    /// assert_eq!(res, 42);
167    /// ```
168    pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&self, f: F) -> U {
169        self.inner.with_lock(|inner| f(&mut inner.manager))
170    }
171
172    /// Send a raw (pre-serialized) message.
173    ///
174    /// This interface should almost never be used by end-users, and is instead
175    /// typically used by interfaces to feed received messages into the
176    /// [`NetStack`].
177    pub fn send_raw(
178        &self,
179        hdr: &Header,
180        hdr_raw: &[u8],
181        body: &[u8],
182    ) -> Result<(), NetStackSendError> {
183        self.inner
184            .with_lock(|inner| inner.send_raw(hdr, hdr_raw, body))
185    }
186
187    /// Send a typed message
188    pub fn send_ty<T: 'static + Serialize + Clone>(
189        &self,
190        hdr: &Header,
191        t: &T,
192    ) -> Result<(), NetStackSendError> {
193        self.inner.with_lock(|inner| inner.send_ty(hdr, t))
194    }
195
196    pub fn send_err(&self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
197        self.inner.with_lock(|inner| inner.send_err(hdr, err))
198    }
199
200    pub(crate) unsafe fn try_attach_socket(&self, mut node: NonNull<SocketHeader>) -> Option<u8> {
201        self.inner.with_lock(|inner| {
202            let new_port = inner.alloc_port()?;
203            unsafe {
204                node.as_mut().port = new_port;
205            }
206
207            inner.sockets.push_front(node);
208            Some(new_port)
209        })
210    }
211
212    pub(crate) unsafe fn attach_broadcast_socket(&self, mut node: NonNull<SocketHeader>) {
213        self.inner.with_lock(|inner| {
214            unsafe {
215                node.as_mut().port = 255;
216            }
217            inner.sockets.push_back(node);
218        });
219    }
220
221    pub(crate) unsafe fn attach_socket(&self, node: NonNull<SocketHeader>) -> u8 {
222        let res = unsafe { self.try_attach_socket(node) };
223        let Some(new_port) = res else {
224            panic!("exhausted all addrs");
225        };
226        new_port
227    }
228
229    pub(crate) unsafe fn detach_socket(&self, node: NonNull<SocketHeader>) {
230        self.inner.with_lock(|inner| unsafe {
231            let port = node.as_ref().port;
232            if port != 255 {
233                inner.free_port(port);
234            }
235            inner.sockets.remove(node)
236        });
237    }
238
239    pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&self, f: F) -> U {
240        self.inner.with_lock(|_inner| f())
241    }
242}
243
244impl<R, M> Default for NetStack<R, M>
245where
246    R: ScopedRawMutex + ConstInit,
247    M: InterfaceManager + interface_manager::ConstInit,
248{
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
254// ---- impl NetStackInner ----
255
256impl<M> NetStackInner<M>
257where
258    M: InterfaceManager,
259    M: interface_manager::ConstInit,
260{
261    pub const fn new() -> Self {
262        Self {
263            sockets: List::new(),
264            manager: M::INIT,
265            seq_no: 0,
266            pcache_bits: 0,
267            pcache_start: 0,
268        }
269    }
270}
271
272impl<M> NetStackInner<M>
273where
274    M: InterfaceManager,
275{
276    /// Method that handles broadcast logic
277    ///
278    /// Takes closures for sending to a socket or sending to the manager to allow
279    /// for abstracting over send_raw/send_ty.
280    fn broadcast<SendSocket, SendMgr>(
281        sockets: &mut List<SocketHeader>,
282        hdr: &Header,
283        mut sskt: SendSocket,
284        smgr: SendMgr,
285    ) -> Result<(), NetStackSendError>
286    where
287        SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
288        SendMgr: FnOnce() -> bool,
289    {
290        trace!("Sending msg broadcast w/ header: {hdr:?}");
291        let res_lcl = {
292            let bcast_iter = Self::find_all_local(sockets, hdr)?;
293            let mut any_found = false;
294            for dst in bcast_iter {
295                let res = sskt(dst);
296                if res {
297                    debug!("delivered broadcast message locally");
298                }
299                any_found |= res;
300            }
301            any_found
302        };
303
304        let res_rmt = smgr();
305        if res_rmt {
306            debug!("delivered broadcast message remotely");
307        }
308
309        if res_lcl || res_rmt {
310            Ok(())
311        } else {
312            Err(NetStackSendError::NoRoute)
313        }
314    }
315
316    /// Method that handles unicast logic
317    ///
318    /// Takes closures for sending to a socket or sending to the manager to allow
319    /// for abstracting over send_raw/send_ty.
320    fn unicast<SendSocket, SendMgr>(
321        sockets: &mut List<SocketHeader>,
322        hdr: &Header,
323        sskt: SendSocket,
324        smgr: SendMgr,
325    ) -> Result<(), NetStackSendError>
326    where
327        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
328        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
329    {
330        trace!("Sending msg unicast w/ header: {hdr:?}");
331        // Can we assume the destination is local?
332        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
333
334        let res = if !local_bypass {
335            // Not local: offer to the interface manager to send
336            debug!("Offering msg externally unicast w/ header: {hdr:?}");
337            smgr()
338        } else {
339            // just skip to local sending
340            Err(InterfaceSendError::DestinationLocal)
341        };
342
343        match res {
344            Ok(()) => {
345                debug!("Externally routed msg unicast");
346                return Ok(());
347            }
348            Err(InterfaceSendError::DestinationLocal) => {
349                debug!("No external interest in msg unicast");
350            }
351            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
352        }
353
354        // It was a destination local error, try to honor that
355        let socket = if hdr.dst.port_id == 0 {
356            debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
357            Self::find_any_local(sockets, hdr)
358        } else {
359            debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
360            Self::find_one_local(sockets, hdr)
361        }?;
362
363        sskt(socket)
364    }
365
366    /// Method that handles unicast logic
367    ///
368    /// Takes closures for sending to a socket or sending to the manager to allow
369    /// for abstracting over send_raw/send_ty.
370    fn unicast_err<SendSocket, SendMgr>(
371        sockets: &mut List<SocketHeader>,
372        hdr: &Header,
373        sskt: SendSocket,
374        smgr: SendMgr,
375    ) -> Result<(), NetStackSendError>
376    where
377        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
378        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
379    {
380        trace!("Sending err unicast w/ header: {hdr:?}");
381        // Can we assume the destination is local?
382        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
383
384        let res = if !local_bypass {
385            // Not local: offer to the interface manager to send
386            debug!("Offering err externally unicast w/ header: {hdr:?}");
387            smgr()
388        } else {
389            // just skip to local sending
390            Err(InterfaceSendError::DestinationLocal)
391        };
392
393        match res {
394            Ok(()) => {
395                debug!("Externally routed err unicast");
396                return Ok(());
397            }
398            Err(InterfaceSendError::DestinationLocal) => {
399                debug!("No external interest in err unicast");
400            }
401            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
402        }
403
404        // It was a destination local error, try to honor that
405        let socket = Self::find_one_err_local(sockets, hdr)?;
406
407        sskt(socket)
408    }
409
410    /// Handle sending of a raw (serialized) message
411    fn send_raw(
412        &mut self,
413        hdr: &Header,
414        hdr_raw: &[u8],
415        body: &[u8],
416    ) -> Result<(), NetStackSendError> {
417        let Self {
418            sockets,
419            seq_no,
420            manager,
421            ..
422        } = self;
423        trace!("Sending msg raw w/ header: {hdr:?}");
424
425        if hdr.kind == FrameKind::PROTOCOL_ERROR {
426            todo!("Don't do that");
427        }
428
429        // Is this a broadcast message?
430        if hdr.dst.port_id == 255 {
431            Self::broadcast(
432                sockets,
433                hdr,
434                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no).is_ok(),
435                || manager.send_raw(hdr, hdr_raw, body).is_ok(),
436            )
437        } else {
438            Self::unicast(
439                sockets,
440                hdr,
441                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no),
442                || manager.send_raw(hdr, hdr_raw, body),
443            )
444        }
445    }
446
447    /// Handle sending of a typed message
448    fn send_ty<T: 'static + Serialize + Clone>(
449        &mut self,
450        hdr: &Header,
451        t: &T,
452    ) -> Result<(), NetStackSendError> {
453        let Self {
454            sockets,
455            seq_no,
456            manager,
457            ..
458        } = self;
459        trace!("Sending msg ty w/ header: {hdr:?}");
460
461        if hdr.kind == FrameKind::PROTOCOL_ERROR {
462            todo!("Don't do that");
463        }
464
465        // Is this a broadcast message?
466        if hdr.dst.port_id == 255 {
467            Self::broadcast(
468                sockets,
469                hdr,
470                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
471                || manager.send(hdr, t).is_ok(),
472            )
473        } else {
474            Self::unicast(
475                sockets,
476                hdr,
477                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
478                || manager.send(hdr, t),
479            )
480        }
481    }
482
483    /// Handle sending of a typed message
484    fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
485        let Self {
486            sockets,
487            seq_no,
488            manager,
489            ..
490        } = self;
491        trace!("Sending msg ty w/ header: {hdr:?}");
492
493        if hdr.dst.port_id == 255 {
494            todo!("Don't do that");
495        }
496
497        // Is this a broadcast message?
498        Self::unicast_err(
499            sockets,
500            hdr,
501            |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
502            || manager.send_err(hdr, err),
503        )
504    }
505
506    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
507    /// the given header.
508    fn find_one_local(
509        sockets: &mut List<SocketHeader>,
510        hdr: &Header,
511    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
512        // Find the specific matching port
513        let mut iter = sockets.iter_raw();
514        let socket = loop {
515            let Some(skt) = iter.next() else {
516                return Err(NetStackSendError::NoRoute);
517            };
518            let skt_ref = unsafe { skt.as_ref() };
519            if skt_ref.port != hdr.dst.port_id {
520                continue;
521            }
522            if skt_ref.attrs.kind != hdr.kind {
523                return Err(NetStackSendError::WrongPortKind);
524            }
525            break skt;
526        };
527        Ok(socket)
528    }
529
530    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
531    /// the given header.
532    fn find_one_err_local(
533        sockets: &mut List<SocketHeader>,
534        hdr: &Header,
535    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
536        // Find the specific matching port
537        let mut iter = sockets.iter_raw();
538        let socket = loop {
539            let Some(skt) = iter.next() else {
540                return Err(NetStackSendError::NoRoute);
541            };
542            let skt_ref = unsafe { skt.as_ref() };
543            if skt_ref.port != hdr.dst.port_id {
544                continue;
545            }
546            break skt;
547        };
548        Ok(socket)
549    }
550
551    /// Find a wildcard (e.g. port_id == 0) destination port matching the given header.
552    ///
553    /// If more than one port matches the wildcard, an error is returned.
554    /// Does not match sockets that does not have the `discoverable` [`Attributes`].
555    fn find_any_local(
556        sockets: &mut List<SocketHeader>,
557        hdr: &Header,
558    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
559        // Find ONE specific matching port
560        let Some(apdx) = hdr.any_all.as_ref() else {
561            return Err(NetStackSendError::AnyPortMissingKey);
562        };
563        let mut iter = sockets.iter_raw();
564        let mut socket: Option<NonNull<SocketHeader>> = None;
565
566        loop {
567            let Some(skt) = iter.next() else {
568                break;
569            };
570            let skt_ref = unsafe { skt.as_ref() };
571
572            // Check for things that would disqualify a socket from being an
573            // "ANY" destination
574            let mut illegal = false;
575            illegal |= skt_ref.attrs.kind != hdr.kind;
576            illegal |= !skt_ref.attrs.discoverable;
577            illegal |= skt_ref.key != apdx.key;
578            if let Some(nash) = apdx.nash {
579                illegal |= Some(nash) != skt_ref.nash;
580            }
581
582            if illegal {
583                // Wait, that's illegal
584                continue;
585            }
586
587            // It's a match! Is it a second match?
588            if socket.is_some() {
589                return Err(NetStackSendError::AnyPortNotUnique);
590            }
591            // Nope! Store this one, then we keep going to ensure that no
592            // other socket matches this description.
593            socket = Some(skt);
594        }
595
596        socket.ok_or(NetStackSendError::NoRoute)
597    }
598
599    /// Find ALL broadcast (e.g. port_id == 255) sockets matching the given header.
600    ///
601    /// Returns an error if the header does not contain a Key. May return zero
602    /// matches.
603    fn find_all_local(
604        sockets: &mut List<SocketHeader>,
605        hdr: &Header,
606    ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
607        let Some(any_all) = hdr.any_all.as_ref() else {
608            return Err(NetStackSendError::AllPortMissingKey);
609        };
610        Ok(sockets.iter_raw().filter(move |socket| {
611            let skt_ref = unsafe { socket.as_ref() };
612            let bport = skt_ref.port == 255;
613            let dkind = skt_ref.attrs.kind == hdr.kind;
614            let dkey = skt_ref.key == any_all.key;
615
616            // If the any/all message DOES contain a name hash, then ONLY match
617            // sockets with the same name hash.
618            let name = if let Some(nash) = any_all.nash {
619                Some(nash) == skt_ref.nash
620            } else {
621                true
622            };
623            bport && dkind && dkey && name
624        }))
625    }
626
627    /// Helper method for sending a type to a given socket
628    fn send_ty_to_socket<T: 'static + Serialize + Clone>(
629        this: NonNull<SocketHeader>,
630        t: &T,
631        hdr: &Header,
632        seq_no: &mut u16,
633    ) -> Result<(), NetStackSendError> {
634        let vtable: &'static SocketVTable = {
635            let skt_ref = unsafe { this.as_ref() };
636            skt_ref.vtable
637        };
638
639        if let Some(f) = vtable.recv_owned {
640            let this: NonNull<()> = this.cast();
641            let that: NonNull<T> = NonNull::from(t);
642            let that: NonNull<()> = that.cast();
643            let hdr = hdr.to_headerseq_or_with_seq(|| {
644                let seq = *seq_no;
645                *seq_no = seq_no.wrapping_add(1);
646                seq
647            });
648            (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
649        } else if let Some(_f) = vtable.recv_bor {
650            // TODO: support send borrowed
651            todo!()
652        } else {
653            // todo: keep going? If we found the "right" destination and
654            // sending fails, then there's not much we can do. Probably: there
655            // is no case where a socket has NEITHER send_owned NOR send_bor,
656            // can we make this state impossible instead?
657            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
658        }
659    }
660
661    /// Helper method for sending a type to a given socket
662    fn send_err_to_socket(
663        this: NonNull<SocketHeader>,
664        err: ProtocolError,
665        hdr: &Header,
666        seq_no: &mut u16,
667    ) -> Result<(), NetStackSendError> {
668        let vtable: &'static SocketVTable = {
669            let skt_ref = unsafe { this.as_ref() };
670            skt_ref.vtable
671        };
672
673        if let Some(f) = vtable.recv_err {
674            let this: NonNull<()> = this.cast();
675            let hdr = hdr.to_headerseq_or_with_seq(|| {
676                let seq = *seq_no;
677                *seq_no = seq_no.wrapping_add(1);
678                seq
679            });
680            (f)(this, hdr, err);
681            Ok(())
682        } else {
683            // todo: keep going? If we found the "right" destination and
684            // sending fails, then there's not much we can do. Probably: there
685            // is no case where a socket has NEITHER send_owned NOR send_bor,
686            // can we make this state impossible instead?
687            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
688        }
689    }
690
691    // /// Helper message for sending a raw message to a given socket
692    // fn send_err_raw_to_socket(
693    //     this: NonNull<SocketHeader>,
694    //     body: &[u8],
695    //     hdr: &Header,
696    //     seq_no: &mut u16,
697    // ) -> Result<(), NetStackSendError> {
698    //     let vtable: &'static SocketVTable = {
699    //         let skt_ref = unsafe { this.as_ref() };
700    //         skt_ref.vtable
701    //     };
702    //     let f = vtable.recv_raw;
703
704    //     let this: NonNull<()> = this.cast();
705    //     let hdr = hdr.to_headerseq_or_with_seq(|| {
706    //         let seq = *seq_no;
707    //         *seq_no = seq_no.wrapping_add(1);
708    //         seq
709    //     });
710
711    //     (f)(this, body, hdr).map_err(NetStackSendError::SocketSend)
712    // }
713
714    /// Helper message for sending a raw message to a given socket
715    fn send_raw_to_socket(
716        this: NonNull<SocketHeader>,
717        body: &[u8],
718        hdr: &Header,
719        hdr_raw: &[u8],
720        seq_no: &mut u16,
721    ) -> Result<(), NetStackSendError> {
722        let vtable: &'static SocketVTable = {
723            let skt_ref = unsafe { this.as_ref() };
724            skt_ref.vtable
725        };
726        let f = vtable.recv_raw;
727
728        let this: NonNull<()> = this.cast();
729        let hdr = hdr.to_headerseq_or_with_seq(|| {
730            let seq = *seq_no;
731            *seq_no = seq_no.wrapping_add(1);
732            seq
733        });
734
735        (f)(this, body, hdr, hdr_raw).map_err(NetStackSendError::SocketSend)
736    }
737}
738
739impl<M> NetStackInner<M>
740where
741    M: InterfaceManager,
742{
743    /// Cache-based allocator inspired by littlefs2 ID allocator
744    ///
745    /// We remember 32 ports at a time, from the current base, which is always
746    /// a multiple of 32. Allocating from this range does not require moving thru
747    /// the socket lists.
748    ///
749    /// If the current 32 ports are all taken, we will start over from a base port
750    /// of 0, and attempt to
751    fn alloc_port(&mut self) -> Option<u8> {
752        // ports 0 is always taken (could be clear on first alloc)
753        self.pcache_bits |= (self.pcache_start == 0) as u32;
754
755        if self.pcache_bits != u32::MAX {
756            // We can allocate from the current slot
757            let ldg = self.pcache_bits.trailing_ones();
758            debug_assert!(ldg < 32);
759            self.pcache_bits |= 1 << ldg;
760            return Some(self.pcache_start + (ldg as u8));
761        }
762
763        // Nope, cache is all taken. try to find a base with available items.
764        // We always start from the bottom to keep ports small, but if we know
765        // we just exhausted a range, don't waste time checking that
766        let old_start = self.pcache_start;
767        for base in 0..8 {
768            let start = base * 32;
769            if start == old_start {
770                continue;
771            }
772            // Clear/reset cache
773            self.pcache_start = start;
774            self.pcache_bits = 0;
775            // port 0 is not allowed
776            self.pcache_bits |= (self.pcache_start == 0) as u32;
777            // port 255 is not allowed
778            self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
779
780            // TODO: If we trust that sockets are always sorted, we could early-return
781            // when we reach a `pupper > self.pcache_start`. We could also maybe be smart
782            // and iterate forwards for 0..4 and backwards for 4..8 (and switch the early
783            // return check to < instead). NOTE: We currently do NOT guarantee sockets are
784            // sorted!
785            self.sockets.iter().for_each(|s| {
786                if s.port == 255 {
787                    return;
788                }
789
790                // The upper 3 bits of the port
791                let pupper = s.port & !(32 - 1);
792                // The lower 5 bits of the port
793                let plower = s.port & (32 - 1);
794
795                if pupper == self.pcache_start {
796                    self.pcache_bits |= 1 << plower;
797                }
798            });
799
800            if self.pcache_bits != u32::MAX {
801                // We can allocate from the current slot
802                let ldg = self.pcache_bits.trailing_ones();
803                debug_assert!(ldg < 32);
804                self.pcache_bits |= 1 << ldg;
805                return Some(self.pcache_start + (ldg as u8));
806            }
807        }
808
809        // Nope, nothing found
810        None
811    }
812
813    fn free_port(&mut self, port: u8) {
814        debug_assert!(port != 255);
815        // The upper 3 bits of the port
816        let pupper = port & !(32 - 1);
817        // The lower 5 bits of the port
818        let plower = port & (32 - 1);
819
820        // TODO: If the freed port is in the 0..32 range, or just less than
821        // the current start range, maybe do an opportunistic re-look?
822        if pupper == self.pcache_start {
823            self.pcache_bits &= !(1 << plower);
824        }
825    }
826}
827
828impl NetStackSendError {
829    pub fn to_error(&self) -> ProtocolError {
830        match self {
831            NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
832            NetStackSendError::InterfaceSend(interface_send_error) => {
833                interface_send_error.to_error()
834            }
835            NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
836            NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
837            NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
838            NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
839            NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
840        }
841    }
842}
843
#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack,
        interface_manager::null::NullInterfaceManager,
        socket::{Attributes, owned::single::Socket},
    };

    /// A socket held open by a helper thread: the port it was expected to
    /// receive, the thread's handle, and the sender that tells the thread to
    /// drop the socket and exit.
    type HeldSocket = (u8, JoinHandle<()>, oneshot::Sender<()>);

    /// Exercises the port allocator: sequential allocation, reuse of freed
    /// ports inside and outside the cached 32-port range, and exhaustion
    /// (port 255 is never handed out).
    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();

        /// Attach a socket on a new thread, assert that it received port `id`,
        /// and keep it attached until the returned sender fires.
        fn spawn_skt(id: u8) -> HeldSocket {
            let (txdone, rxdone) = oneshot::channel();
            let (txwait, rxwait) = oneshot::channel();
            let hdl = std::thread::spawn(move || {
                let skt = Socket::<u64, &_>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    Attributes {
                        kind: FrameKind::ENDPOINT_REQ,
                        discoverable: true,
                    },
                    None,
                );
                let skt = pin!(skt);
                let hdl = skt.attach();
                assert_eq!(hdl.port(), id);
                txwait.send(()).unwrap();
                let _: () = rxdone.blocking_recv().unwrap();
            });
            // Don't allocate the next socket until this one is attached.
            let _ = rxwait.blocking_recv();
            (id, hdl, txdone)
        }

        /// Detach the held socket that owns port `id` and wait for its thread.
        fn drop_skt(held: &mut Vec<HeldSocket>, id: u8) {
            let pos = held.iter().position(|(i, _, _)| *i == id).unwrap();
            let (_i, hdl, tx) = held.remove(pos);
            tx.send(()).unwrap();
            hdl.join().unwrap();
        }

        let mut held: Vec<HeldSocket> = Vec::new();

        // Ports 1..32 fill the first range (port 0 is reserved); 32..40 then
        // come from the second range.
        held.extend((1..32).map(spawn_skt));
        held.extend((32..40).map(spawn_skt));

        // A port freed inside the cached range is handed out again right away.
        drop_skt(&mut held, 35);
        held.push(spawn_skt(35));

        // A port freed outside the cached range is not reused yet...
        drop_skt(&mut held, 4);
        held.push(spawn_skt(40));

        // ...until the cached range is exhausted and the allocator rescans.
        held.extend((41..64).map(spawn_skt));
        held.push(spawn_skt(4));

        // Fill every remaining port below 255.
        held.extend((64..255).map(spawn_skt));

        drop_skt(&mut held, 212);
        held.push(spawn_skt(212));

        // Sockets exhausted (we never see 255)
        let overflow = std::thread::spawn(move || {
            let skt = Socket::<u64, &_>::new(
                &STACK,
                Key(*b"TEST1234"),
                Attributes {
                    kind: FrameKind::ENDPOINT_REQ,
                    discoverable: true,
                },
                None,
            );
            let skt = pin!(skt);
            let hdl = skt.attach();
            println!("{}", hdl.port());
        });
        assert!(overflow.join().is_err());

        for (_i, hdl, tx) in held.drain(..) {
            tx.send(()).unwrap();
            hdl.join().unwrap();
        }
    }
}