ergot_base/
net_stack.rs

1//! The Ergot NetStack
2//!
3//! The [`NetStack`] is the core of Ergot. It is intended to be placed
4//! in a `static` variable for the duration of your application.
5//!
6//! The Netstack is used directly for a couple of main responsibilities:
7//!
8//! 1. Sending a message, either from user code, or to deliver/forward messages
9//!    received from an interface
10//! 2. Attaching a socket, allowing the NetStack to route messages to it
11//! 3. Interacting with the [interface manager], in order to add/remove
12//!    interfaces, or obtain other information
13//!
14//! [interface manager]: crate::interface_manager
15//!
16//! In general, interacting with anything contained by the [`NetStack`] requires
17//! locking of the [`BlockingMutex`] which protects the inner contents. This
18//! is used both to allow sharing of the inner contents, but also to allow
19//! `Drop` impls to remove themselves from the stack in a blocking manner.
20
21use core::{any::TypeId, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29    FrameKind, Header, ProtocolError,
30    interface_manager::{self, InterfaceManager, InterfaceSendError},
31    socket::{SocketHeader, SocketSendError, SocketVTable},
32};
33
34/// The Ergot Netstack
35pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
36    inner: BlockingMutex<R, NetStackInner<M>>,
37}
38
39pub(crate) struct NetStackInner<M: InterfaceManager> {
40    sockets: List<SocketHeader>,
41    manager: M,
42    pcache_bits: u32,
43    pcache_start: u8,
44    seq_no: u16,
45}
46
47/// An error from calling a [`NetStack`] "send" method
48#[derive(Debug, PartialEq, Eq)]
49#[non_exhaustive]
50pub enum NetStackSendError {
51    SocketSend(SocketSendError),
52    InterfaceSend(InterfaceSendError),
53    NoRoute,
54    AnyPortMissingKey,
55    WrongPortKind,
56    AnyPortNotUnique,
57    AllPortMissingKey,
58}
59
60// ---- impl NetStack ----
61
62impl<R, M> NetStack<R, M>
63where
64    R: ScopedRawMutex + ConstInit,
65    M: InterfaceManager + interface_manager::ConstInit,
66{
67    /// Create a new, uninitialized [`NetStack`].
68    ///
69    /// Requires that the [`ScopedRawMutex`] implements the [`mutex::ConstInit`]
70    /// trait, and the [`InterfaceManager`] implements the
71    /// [`interface_manager::ConstInit`] trait.
72    ///
73    /// ## Example
74    ///
75    /// ```rust
76    /// use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
77    /// use ergot_base::NetStack;
78    /// use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
79    ///
80    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
81    /// ```
82    pub const fn new() -> Self {
83        Self {
84            inner: BlockingMutex::new(NetStackInner::new()),
85        }
86    }
87}
88
89impl<R, M> NetStack<R, M>
90where
91    R: ScopedRawMutex,
92    M: InterfaceManager,
93{
94    /// Manually create a new, uninitialized [`NetStack`].
95    ///
96    /// This method is useful if your [`ScopedRawMutex`] or [`InterfaceManager`]
97    /// do not implement their corresponding `ConstInit` trait.
98    ///
99    /// In general, this is most often only needed for `loom` testing, and
100    /// [`NetStack::new()`] should be used when possible.
101    pub const fn const_new(r: R, m: M) -> Self {
102        Self {
103            inner: BlockingMutex::const_new(
104                r,
105                NetStackInner {
106                    sockets: List::new(),
107                    manager: m,
108                    seq_no: 0,
109                    pcache_start: 0,
110                    pcache_bits: 0,
111                },
112            ),
113        }
114    }
115
116    /// Access the contained [`InterfaceManager`].
117    ///
118    /// Access to the [`InterfaceManager`] is made via the provided closure.
119    /// The [`BlockingMutex`] is locked for the duration of this access,
120    /// inhibiting all other usage of this [`NetStack`].
121    ///
122    /// This can be used to add new interfaces, obtain metadata, or other
123    /// actions supported by the chosen [`InterfaceManager`].
124    ///
125    /// ## Example
126    ///
127    /// ```rust
128    /// # use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
129    /// # use ergot_base::NetStack;
130    /// # use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
131    /// #
132    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
133    ///
134    /// let res = STACK.with_interface_manager(|im| {
135    ///    // The mutex is locked for the full duration of this closure.
136    ///    # _ = im;
137    ///    // We can return whatever we want from this context, though not
138    ///    // anything borrowed from `im`.
139    ///    42
140    /// });
141    /// assert_eq!(res, 42);
142    /// ```
143    pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&'static self, f: F) -> U {
144        self.inner.with_lock(|inner| f(&mut inner.manager))
145    }
146
147    /// Send a raw (pre-serialized) message.
148    ///
149    /// This interface should almost never be used by end-users, and is instead
150    /// typically used by interfaces to feed received messages into the
151    /// [`NetStack`].
152    pub fn send_raw(
153        &'static self,
154        hdr: &Header,
155        hdr_raw: &[u8],
156        body: &[u8],
157    ) -> Result<(), NetStackSendError> {
158        self.inner
159            .with_lock(|inner| inner.send_raw(hdr, hdr_raw, body))
160    }
161
162    /// Send a typed message
163    pub fn send_ty<T: 'static + Serialize + Clone>(
164        &'static self,
165        hdr: &Header,
166        t: &T,
167    ) -> Result<(), NetStackSendError> {
168        self.inner.with_lock(|inner| inner.send_ty(hdr, t))
169    }
170
171    pub fn send_err(
172        &'static self,
173        hdr: &Header,
174        err: ProtocolError,
175    ) -> Result<(), NetStackSendError> {
176        self.inner.with_lock(|inner| inner.send_err(hdr, err))
177    }
178
179    pub(crate) unsafe fn try_attach_socket(
180        &'static self,
181        mut node: NonNull<SocketHeader>,
182    ) -> Option<u8> {
183        self.inner.with_lock(|inner| {
184            let new_port = inner.alloc_port()?;
185            unsafe {
186                node.as_mut().port = new_port;
187            }
188
189            inner.sockets.push_front(node);
190            Some(new_port)
191        })
192    }
193
194    pub(crate) unsafe fn attach_broadcast_socket(&'static self, mut node: NonNull<SocketHeader>) {
195        self.inner.with_lock(|inner| {
196            unsafe {
197                node.as_mut().port = 255;
198            }
199            inner.sockets.push_back(node);
200        });
201    }
202
203    pub(crate) unsafe fn attach_socket(&'static self, node: NonNull<SocketHeader>) -> u8 {
204        let res = unsafe { self.try_attach_socket(node) };
205        let Some(new_port) = res else {
206            panic!("exhausted all addrs");
207        };
208        new_port
209    }
210
211    pub(crate) unsafe fn detach_socket(&'static self, node: NonNull<SocketHeader>) {
212        self.inner.with_lock(|inner| unsafe {
213            let port = node.as_ref().port;
214            if port != 255 {
215                inner.free_port(port);
216            }
217            inner.sockets.remove(node)
218        });
219    }
220
221    pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&'static self, f: F) -> U {
222        self.inner.with_lock(|_inner| f())
223    }
224}
225
226impl<R, M> Default for NetStack<R, M>
227where
228    R: ScopedRawMutex + ConstInit,
229    M: InterfaceManager + interface_manager::ConstInit,
230{
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
236// ---- impl NetStackInner ----
237
238impl<M> NetStackInner<M>
239where
240    M: InterfaceManager,
241    M: interface_manager::ConstInit,
242{
243    pub const fn new() -> Self {
244        Self {
245            sockets: List::new(),
246            manager: M::INIT,
247            seq_no: 0,
248            pcache_bits: 0,
249            pcache_start: 0,
250        }
251    }
252}
253
254impl<M> NetStackInner<M>
255where
256    M: InterfaceManager,
257{
258    /// Method that handles broadcast logic
259    ///
260    /// Takes closures for sending to a socket or sending to the manager to allow
261    /// for abstracting over send_raw/send_ty.
262    fn broadcast<SendSocket, SendMgr>(
263        sockets: &mut List<SocketHeader>,
264        hdr: &Header,
265        mut sskt: SendSocket,
266        smgr: SendMgr,
267    ) -> Result<(), NetStackSendError>
268    where
269        SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
270        SendMgr: FnOnce() -> bool,
271    {
272        trace!("Sending msg broadcast w/ header: {hdr:?}");
273        let res_lcl = {
274            let bcast_iter = Self::find_all_local(sockets, hdr)?;
275            let mut any_found = false;
276            for dst in bcast_iter {
277                let res = sskt(dst);
278                if res {
279                    debug!("delivered broadcast message locally");
280                }
281                any_found |= res;
282            }
283            any_found
284        };
285
286        let res_rmt = smgr();
287        if res_rmt {
288            debug!("delivered broadcast message remotely");
289        }
290
291        if res_lcl || res_rmt {
292            Ok(())
293        } else {
294            Err(NetStackSendError::NoRoute)
295        }
296    }
297
298    /// Method that handles unicast logic
299    ///
300    /// Takes closures for sending to a socket or sending to the manager to allow
301    /// for abstracting over send_raw/send_ty.
302    fn unicast<SendSocket, SendMgr>(
303        sockets: &mut List<SocketHeader>,
304        hdr: &Header,
305        sskt: SendSocket,
306        smgr: SendMgr,
307    ) -> Result<(), NetStackSendError>
308    where
309        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
310        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
311    {
312        trace!("Sending msg unicast w/ header: {hdr:?}");
313        // Can we assume the destination is local?
314        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
315
316        let res = if !local_bypass {
317            // Not local: offer to the interface manager to send
318            debug!("Offering msg externally unicast w/ header: {hdr:?}");
319            smgr()
320        } else {
321            // just skip to local sending
322            Err(InterfaceSendError::DestinationLocal)
323        };
324
325        match res {
326            Ok(()) => {
327                debug!("Externally routed msg unicast");
328                return Ok(());
329            }
330            Err(InterfaceSendError::DestinationLocal) => {
331                debug!("No external interest in msg unicast");
332            }
333            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
334        }
335
336        // It was a destination local error, try to honor that
337        let socket = if hdr.dst.port_id == 0 {
338            debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
339            Self::find_any_local(sockets, hdr)
340        } else {
341            debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
342            Self::find_one_local(sockets, hdr)
343        }?;
344
345        sskt(socket)
346    }
347
348    /// Method that handles unicast logic
349    ///
350    /// Takes closures for sending to a socket or sending to the manager to allow
351    /// for abstracting over send_raw/send_ty.
352    fn unicast_err<SendSocket, SendMgr>(
353        sockets: &mut List<SocketHeader>,
354        hdr: &Header,
355        sskt: SendSocket,
356        smgr: SendMgr,
357    ) -> Result<(), NetStackSendError>
358    where
359        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
360        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
361    {
362        trace!("Sending err unicast w/ header: {hdr:?}");
363        // Can we assume the destination is local?
364        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
365
366        let res = if !local_bypass {
367            // Not local: offer to the interface manager to send
368            debug!("Offering err externally unicast w/ header: {hdr:?}");
369            smgr()
370        } else {
371            // just skip to local sending
372            Err(InterfaceSendError::DestinationLocal)
373        };
374
375        match res {
376            Ok(()) => {
377                debug!("Externally routed err unicast");
378                return Ok(());
379            }
380            Err(InterfaceSendError::DestinationLocal) => {
381                debug!("No external interest in err unicast");
382            }
383            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
384        }
385
386        // It was a destination local error, try to honor that
387        let socket = Self::find_one_err_local(sockets, hdr)?;
388
389        sskt(socket)
390    }
391
392    /// Handle sending of a raw (serialized) message
393    fn send_raw(
394        &mut self,
395        hdr: &Header,
396        hdr_raw: &[u8],
397        body: &[u8],
398    ) -> Result<(), NetStackSendError> {
399        let Self {
400            sockets,
401            seq_no,
402            manager,
403            ..
404        } = self;
405        trace!("Sending msg raw w/ header: {hdr:?}");
406
407        if hdr.kind == FrameKind::PROTOCOL_ERROR {
408            todo!("Don't do that");
409        }
410
411        // Is this a broadcast message?
412        if hdr.dst.port_id == 255 {
413            Self::broadcast(
414                sockets,
415                hdr,
416                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no).is_ok(),
417                || manager.send_raw(hdr, hdr_raw, body).is_ok(),
418            )
419        } else {
420            Self::unicast(
421                sockets,
422                hdr,
423                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no),
424                || manager.send_raw(hdr, hdr_raw, body),
425            )
426        }
427    }
428
429    /// Handle sending of a typed message
430    fn send_ty<T: 'static + Serialize + Clone>(
431        &mut self,
432        hdr: &Header,
433        t: &T,
434    ) -> Result<(), NetStackSendError> {
435        let Self {
436            sockets,
437            seq_no,
438            manager,
439            ..
440        } = self;
441        trace!("Sending msg ty w/ header: {hdr:?}");
442
443        if hdr.kind == FrameKind::PROTOCOL_ERROR {
444            todo!("Don't do that");
445        }
446
447        // Is this a broadcast message?
448        if hdr.dst.port_id == 255 {
449            Self::broadcast(
450                sockets,
451                hdr,
452                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
453                || manager.send(hdr, t).is_ok(),
454            )
455        } else {
456            Self::unicast(
457                sockets,
458                hdr,
459                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
460                || manager.send(hdr, t),
461            )
462        }
463    }
464
465    /// Handle sending of a typed message
466    fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
467        let Self {
468            sockets,
469            seq_no,
470            manager,
471            ..
472        } = self;
473        trace!("Sending msg ty w/ header: {hdr:?}");
474
475        if hdr.dst.port_id == 255 {
476            todo!("Don't do that");
477        }
478
479        // Is this a broadcast message?
480        Self::unicast_err(
481            sockets,
482            hdr,
483            |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
484            || manager.send_err(hdr, err),
485        )
486    }
487
488    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
489    /// the given header.
490    fn find_one_local(
491        sockets: &mut List<SocketHeader>,
492        hdr: &Header,
493    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
494        // Find the specific matching port
495        let mut iter = sockets.iter_raw();
496        let socket = loop {
497            let Some(skt) = iter.next() else {
498                return Err(NetStackSendError::NoRoute);
499            };
500            let skt_ref = unsafe { skt.as_ref() };
501            if skt_ref.port != hdr.dst.port_id {
502                continue;
503            }
504            if skt_ref.attrs.kind != hdr.kind {
505                return Err(NetStackSendError::WrongPortKind);
506            }
507            break skt;
508        };
509        Ok(socket)
510    }
511
512    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
513    /// the given header.
514    fn find_one_err_local(
515        sockets: &mut List<SocketHeader>,
516        hdr: &Header,
517    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
518        // Find the specific matching port
519        let mut iter = sockets.iter_raw();
520        let socket = loop {
521            let Some(skt) = iter.next() else {
522                return Err(NetStackSendError::NoRoute);
523            };
524            let skt_ref = unsafe { skt.as_ref() };
525            if skt_ref.port != hdr.dst.port_id {
526                continue;
527            }
528            break skt;
529        };
530        Ok(socket)
531    }
532
533    /// Find a wildcard (e.g. port_id == 0) destination port matching the given header.
534    ///
535    /// If more than one port matches the wildcard, an error is returned.
536    /// Does not match sockets that does not have the `discoverable` [`Attributes`].
537    fn find_any_local(
538        sockets: &mut List<SocketHeader>,
539        hdr: &Header,
540    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
541        // Find ONE specific matching port
542        let Some(apdx) = hdr.any_all.as_ref() else {
543            return Err(NetStackSendError::AnyPortMissingKey);
544        };
545        let mut iter = sockets.iter_raw();
546        let mut socket: Option<NonNull<SocketHeader>> = None;
547
548        loop {
549            let Some(skt) = iter.next() else {
550                break;
551            };
552            let skt_ref = unsafe { skt.as_ref() };
553
554            // Check for things that would disqualify a socket from being an
555            // "ANY" destination
556            let mut illegal = false;
557            illegal |= skt_ref.attrs.kind != hdr.kind;
558            illegal |= !skt_ref.attrs.discoverable;
559            illegal |= skt_ref.key != apdx.key;
560            if let Some(nash) = apdx.nash {
561                illegal |= Some(nash) != skt_ref.nash;
562            }
563
564            if illegal {
565                // Wait, that's illegal
566                continue;
567            }
568
569            // It's a match! Is it a second match?
570            if socket.is_some() {
571                return Err(NetStackSendError::AnyPortNotUnique);
572            }
573            // Nope! Store this one, then we keep going to ensure that no
574            // other socket matches this description.
575            socket = Some(skt);
576        }
577
578        socket.ok_or(NetStackSendError::NoRoute)
579    }
580
581    /// Find ALL broadcast (e.g. port_id == 255) sockets matching the given header.
582    ///
583    /// Returns an error if the header does not contain a Key. May return zero
584    /// matches.
585    fn find_all_local(
586        sockets: &mut List<SocketHeader>,
587        hdr: &Header,
588    ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
589        let Some(any_all) = hdr.any_all.as_ref() else {
590            return Err(NetStackSendError::AllPortMissingKey);
591        };
592        Ok(sockets.iter_raw().filter(move |socket| {
593            let skt_ref = unsafe { socket.as_ref() };
594            let bport = skt_ref.port == 255;
595            let dkind = skt_ref.attrs.kind == hdr.kind;
596            let dkey = skt_ref.key == any_all.key;
597
598            // If the any/all message DOES contain a name hash, then ONLY match
599            // sockets with the same name hash.
600            let name = if let Some(nash) = any_all.nash {
601                Some(nash) == skt_ref.nash
602            } else {
603                true
604            };
605            bport && dkind && dkey && name
606        }))
607    }
608
609    /// Helper method for sending a type to a given socket
610    fn send_ty_to_socket<T: 'static + Serialize + Clone>(
611        this: NonNull<SocketHeader>,
612        t: &T,
613        hdr: &Header,
614        seq_no: &mut u16,
615    ) -> Result<(), NetStackSendError> {
616        let vtable: &'static SocketVTable = {
617            let skt_ref = unsafe { this.as_ref() };
618            skt_ref.vtable
619        };
620
621        if let Some(f) = vtable.recv_owned {
622            let this: NonNull<()> = this.cast();
623            let that: NonNull<T> = NonNull::from(t);
624            let that: NonNull<()> = that.cast();
625            let hdr = hdr.to_headerseq_or_with_seq(|| {
626                let seq = *seq_no;
627                *seq_no = seq_no.wrapping_add(1);
628                seq
629            });
630            (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
631        } else if let Some(_f) = vtable.recv_bor {
632            // TODO: support send borrowed
633            todo!()
634        } else {
635            // todo: keep going? If we found the "right" destination and
636            // sending fails, then there's not much we can do. Probably: there
637            // is no case where a socket has NEITHER send_owned NOR send_bor,
638            // can we make this state impossible instead?
639            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
640        }
641    }
642
643    /// Helper method for sending a type to a given socket
644    fn send_err_to_socket(
645        this: NonNull<SocketHeader>,
646        err: ProtocolError,
647        hdr: &Header,
648        seq_no: &mut u16,
649    ) -> Result<(), NetStackSendError> {
650        let vtable: &'static SocketVTable = {
651            let skt_ref = unsafe { this.as_ref() };
652            skt_ref.vtable
653        };
654
655        if let Some(f) = vtable.recv_err {
656            let this: NonNull<()> = this.cast();
657            let hdr = hdr.to_headerseq_or_with_seq(|| {
658                let seq = *seq_no;
659                *seq_no = seq_no.wrapping_add(1);
660                seq
661            });
662            (f)(this, hdr, err);
663            Ok(())
664        } else {
665            // todo: keep going? If we found the "right" destination and
666            // sending fails, then there's not much we can do. Probably: there
667            // is no case where a socket has NEITHER send_owned NOR send_bor,
668            // can we make this state impossible instead?
669            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
670        }
671    }
672
673    // /// Helper message for sending a raw message to a given socket
674    // fn send_err_raw_to_socket(
675    //     this: NonNull<SocketHeader>,
676    //     body: &[u8],
677    //     hdr: &Header,
678    //     seq_no: &mut u16,
679    // ) -> Result<(), NetStackSendError> {
680    //     let vtable: &'static SocketVTable = {
681    //         let skt_ref = unsafe { this.as_ref() };
682    //         skt_ref.vtable
683    //     };
684    //     let f = vtable.recv_raw;
685
686    //     let this: NonNull<()> = this.cast();
687    //     let hdr = hdr.to_headerseq_or_with_seq(|| {
688    //         let seq = *seq_no;
689    //         *seq_no = seq_no.wrapping_add(1);
690    //         seq
691    //     });
692
693    //     (f)(this, body, hdr).map_err(NetStackSendError::SocketSend)
694    // }
695
696    /// Helper message for sending a raw message to a given socket
697    fn send_raw_to_socket(
698        this: NonNull<SocketHeader>,
699        body: &[u8],
700        hdr: &Header,
701        hdr_raw: &[u8],
702        seq_no: &mut u16,
703    ) -> Result<(), NetStackSendError> {
704        let vtable: &'static SocketVTable = {
705            let skt_ref = unsafe { this.as_ref() };
706            skt_ref.vtable
707        };
708        let f = vtable.recv_raw;
709
710        let this: NonNull<()> = this.cast();
711        let hdr = hdr.to_headerseq_or_with_seq(|| {
712            let seq = *seq_no;
713            *seq_no = seq_no.wrapping_add(1);
714            seq
715        });
716
717        (f)(this, body, hdr, hdr_raw).map_err(NetStackSendError::SocketSend)
718    }
719}
720
721impl<M> NetStackInner<M>
722where
723    M: InterfaceManager,
724{
725    /// Cache-based allocator inspired by littlefs2 ID allocator
726    ///
727    /// We remember 32 ports at a time, from the current base, which is always
728    /// a multiple of 32. Allocating from this range does not require moving thru
729    /// the socket lists.
730    ///
731    /// If the current 32 ports are all taken, we will start over from a base port
732    /// of 0, and attempt to
733    fn alloc_port(&mut self) -> Option<u8> {
734        // ports 0 is always taken (could be clear on first alloc)
735        self.pcache_bits |= (self.pcache_start == 0) as u32;
736
737        if self.pcache_bits != u32::MAX {
738            // We can allocate from the current slot
739            let ldg = self.pcache_bits.trailing_ones();
740            debug_assert!(ldg < 32);
741            self.pcache_bits |= 1 << ldg;
742            return Some(self.pcache_start + (ldg as u8));
743        }
744
745        // Nope, cache is all taken. try to find a base with available items.
746        // We always start from the bottom to keep ports small, but if we know
747        // we just exhausted a range, don't waste time checking that
748        let old_start = self.pcache_start;
749        for base in 0..8 {
750            let start = base * 32;
751            if start == old_start {
752                continue;
753            }
754            // Clear/reset cache
755            self.pcache_start = start;
756            self.pcache_bits = 0;
757            // port 0 is not allowed
758            self.pcache_bits |= (self.pcache_start == 0) as u32;
759            // port 255 is not allowed
760            self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
761
762            // TODO: If we trust that sockets are always sorted, we could early-return
763            // when we reach a `pupper > self.pcache_start`. We could also maybe be smart
764            // and iterate forwards for 0..4 and backwards for 4..8 (and switch the early
765            // return check to < instead). NOTE: We currently do NOT guarantee sockets are
766            // sorted!
767            self.sockets.iter().for_each(|s| {
768                if s.port == 255 {
769                    return;
770                }
771
772                // The upper 3 bits of the port
773                let pupper = s.port & !(32 - 1);
774                // The lower 5 bits of the port
775                let plower = s.port & (32 - 1);
776
777                if pupper == self.pcache_start {
778                    self.pcache_bits |= 1 << plower;
779                }
780            });
781
782            if self.pcache_bits != u32::MAX {
783                // We can allocate from the current slot
784                let ldg = self.pcache_bits.trailing_ones();
785                debug_assert!(ldg < 32);
786                self.pcache_bits |= 1 << ldg;
787                return Some(self.pcache_start + (ldg as u8));
788            }
789        }
790
791        // Nope, nothing found
792        None
793    }
794
795    fn free_port(&mut self, port: u8) {
796        debug_assert!(port != 255);
797        // The upper 3 bits of the port
798        let pupper = port & !(32 - 1);
799        // The lower 5 bits of the port
800        let plower = port & (32 - 1);
801
802        // TODO: If the freed port is in the 0..32 range, or just less than
803        // the current start range, maybe do an opportunistic re-look?
804        if pupper == self.pcache_start {
805            self.pcache_bits &= !(1 << plower);
806        }
807    }
808}
809
810impl NetStackSendError {
811    pub fn to_error(&self) -> ProtocolError {
812        match self {
813            NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
814            NetStackSendError::InterfaceSend(interface_send_error) => {
815                interface_send_error.to_error()
816            }
817            NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
818            NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
819            NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
820            NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
821            NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
822        }
823    }
824}
825
#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack,
        interface_manager::null::NullInterfaceManager,
        socket::{Attributes, owned::single::Socket},
    };

    /// Exercises the port allocator end to end. It fills windows in order,
    /// reuses freed ports (both inside and outside the current window) and
    /// finally checks that attaching panics once ports 1..=254 are all taken.
    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();

        // A socket held open by a thread: its expected port, the thread, and
        // the channel used to let the thread drop it.
        type Live = (u8, JoinHandle<()>, oneshot::Sender<()>);

        // Attach a socket on a new thread, assert it received `expected`, and
        // keep it attached until released.
        fn spawn_skt(expected: u8) -> Live {
            let (release_tx, release_rx) = oneshot::channel();
            let (ready_tx, ready_rx) = oneshot::channel();
            let worker = std::thread::spawn(move || {
                let skt = Socket::<u64, _, _>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    Attributes {
                        kind: FrameKind::ENDPOINT_REQ,
                        discoverable: true,
                    },
                    None,
                );
                let skt = pin!(skt);
                let hdl = skt.attach();
                assert_eq!(hdl.port(), expected);
                ready_tx.send(()).unwrap();
                let _: () = release_rx.blocking_recv().unwrap();
            });
            // Wait until attached (or the thread died) so ports are allocated
            // in a deterministic order.
            let _ = ready_rx.blocking_recv();
            (expected, worker, release_tx)
        }

        // Release the socket holding `port` and wait for it to detach.
        fn release(live: &mut Vec<Live>, port: u8) {
            let idx = live.iter().position(|(p, _, _)| *p == port).unwrap();
            let (_, worker, release_tx) = live.remove(idx);
            release_tx.send(()).unwrap();
            worker.join().unwrap();
        }

        let mut live: Vec<Live> = Vec::new();

        // Fill the first window (port 0 is reserved) and start the second.
        live.extend((1..32).map(spawn_skt));
        live.extend((32..40).map(spawn_skt));

        // A port freed inside the current window is reused immediately.
        release(&mut live, 35);
        live.push(spawn_skt(35));

        // A port freed outside the current window is not seen yet...
        release(&mut live, 4);
        live.push(spawn_skt(40));
        live.extend((41..64).map(spawn_skt));

        // ...until the current window is exhausted and the rescan finds it.
        live.push(spawn_skt(4));

        live.extend((64..255).map(spawn_skt));

        release(&mut live, 212);
        live.push(spawn_skt(212));

        // Sockets exhausted (we never see 255)
        let overflow = std::thread::spawn(move || {
            let skt = Socket::<u64, _, _>::new(
                &STACK,
                Key(*b"TEST1234"),
                Attributes {
                    kind: FrameKind::ENDPOINT_REQ,
                    discoverable: true,
                },
                None,
            );
            let skt = pin!(skt);
            let hdl = skt.attach();
            println!("{}", hdl.port());
        });
        assert!(overflow.join().is_err());

        for (_, worker, release_tx) in live.drain(..) {
            release_tx.send(()).unwrap();
            worker.join().unwrap();
        }
    }
}