ergot_base/
net_stack.rs

1//! The Ergot NetStack
2//!
3//! The [`NetStack`] is the core of Ergot. It is intended to be placed
4//! in a `static` variable for the duration of your application.
5//!
//! The [`NetStack`] is used directly for three main responsibilities:
7//!
8//! 1. Sending a message, either from user code, or to deliver/forward messages
9//!    received from an interface
10//! 2. Attaching a socket, allowing the NetStack to route messages to it
11//! 3. Interacting with the [interface manager], in order to add/remove
12//!    interfaces, or obtain other information
13//!
14//! [interface manager]: crate::interface_manager
15//!
16//! In general, interacting with anything contained by the [`NetStack`] requires
17//! locking of the [`BlockingMutex`] which protects the inner contents. This
18//! is used both to allow sharing of the inner contents, but also to allow
19//! `Drop` impls to remove themselves from the stack in a blocking manner.
20
21use core::{any::TypeId, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29    FrameKind, Header, ProtocolError,
30    interface_manager::{self, InterfaceManager, InterfaceSendError},
31    socket::{SocketHeader, SocketSendError, SocketVTable},
32};
33
34/// The Ergot Netstack
35pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
36    inner: BlockingMutex<R, NetStackInner<M>>,
37}
38
39pub(crate) struct NetStackInner<M: InterfaceManager> {
40    sockets: List<SocketHeader>,
41    manager: M,
42    pcache_bits: u32,
43    pcache_start: u8,
44    seq_no: u16,
45}
46
47/// An error from calling a [`NetStack`] "send" method
48#[derive(Debug, PartialEq, Eq)]
49#[non_exhaustive]
50pub enum NetStackSendError {
51    SocketSend(SocketSendError),
52    InterfaceSend(InterfaceSendError),
53    NoRoute,
54    AnyPortMissingKey,
55    WrongPortKind,
56    AnyPortNotUnique,
57    AllPortMissingKey,
58}
59
60// ---- impl NetStack ----
61
62impl<R, M> NetStack<R, M>
63where
64    R: ScopedRawMutex + ConstInit,
65    M: InterfaceManager + interface_manager::ConstInit,
66{
67    /// Create a new, uninitialized [`NetStack`].
68    ///
69    /// Requires that the [`ScopedRawMutex`] implements the [`mutex::ConstInit`]
70    /// trait, and the [`InterfaceManager`] implements the
71    /// [`interface_manager::ConstInit`] trait.
72    ///
73    /// ## Example
74    ///
75    /// ```rust
76    /// use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
77    /// use ergot_base::NetStack;
78    /// use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
79    ///
80    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
81    /// ```
82    pub const fn new() -> Self {
83        Self {
84            inner: BlockingMutex::new(NetStackInner::new()),
85        }
86    }
87}
88
89impl<R, M> NetStack<R, M>
90where
91    R: ScopedRawMutex,
92    M: InterfaceManager,
93{
94    /// Manually create a new, uninitialized [`NetStack`].
95    ///
96    /// This method is useful if your [`ScopedRawMutex`] or [`InterfaceManager`]
97    /// do not implement their corresponding `ConstInit` trait.
98    ///
99    /// In general, this is most often only needed for `loom` testing, and
100    /// [`NetStack::new()`] should be used when possible.
101    pub const fn const_new(r: R, m: M) -> Self {
102        Self {
103            inner: BlockingMutex::const_new(
104                r,
105                NetStackInner {
106                    sockets: List::new(),
107                    manager: m,
108                    seq_no: 0,
109                    pcache_start: 0,
110                    pcache_bits: 0,
111                },
112            ),
113        }
114    }
115
116    /// Access the contained [`InterfaceManager`].
117    ///
118    /// Access to the [`InterfaceManager`] is made via the provided closure.
119    /// The [`BlockingMutex`] is locked for the duration of this access,
120    /// inhibiting all other usage of this [`NetStack`].
121    ///
122    /// This can be used to add new interfaces, obtain metadata, or other
123    /// actions supported by the chosen [`InterfaceManager`].
124    ///
125    /// ## Example
126    ///
127    /// ```rust
128    /// # use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
129    /// # use ergot_base::NetStack;
130    /// # use ergot_base::interface_manager::null::NullInterfaceManager as NullIM;
131    /// #
132    /// static STACK: NetStack<CSRMutex, NullIM> = NetStack::new();
133    ///
134    /// let res = STACK.with_interface_manager(|im| {
135    ///    // The mutex is locked for the full duration of this closure.
136    ///    # _ = im;
137    ///    // We can return whatever we want from this context, though not
138    ///    // anything borrowed from `im`.
139    ///    42
140    /// });
141    /// assert_eq!(res, 42);
142    /// ```
143    pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&'static self, f: F) -> U {
144        self.inner.with_lock(|inner| f(&mut inner.manager))
145    }
146
147    /// Send a raw (pre-serialized) message.
148    ///
149    /// This interface should almost never be used by end-users, and is instead
150    /// typically used by interfaces to feed received messages into the
151    /// [`NetStack`].
152    pub fn send_raw(&'static self, hdr: &Header, body: &[u8]) -> Result<(), NetStackSendError> {
153        self.inner.with_lock(|inner| inner.send_raw(hdr, body))
154    }
155
156    /// Send a typed message
157    pub fn send_ty<T: 'static + Serialize + Clone>(
158        &'static self,
159        hdr: &Header,
160        t: &T,
161    ) -> Result<(), NetStackSendError> {
162        self.inner.with_lock(|inner| inner.send_ty(hdr, t))
163    }
164
165    pub fn send_err(
166        &'static self,
167        hdr: &Header,
168        err: ProtocolError,
169    ) -> Result<(), NetStackSendError> {
170        self.inner.with_lock(|inner| inner.send_err(hdr, err))
171    }
172
173    pub(crate) unsafe fn try_attach_socket(
174        &'static self,
175        mut node: NonNull<SocketHeader>,
176    ) -> Option<u8> {
177        self.inner.with_lock(|inner| {
178            let new_port = inner.alloc_port()?;
179            unsafe {
180                node.as_mut().port = new_port;
181            }
182
183            inner.sockets.push_front(node);
184            Some(new_port)
185        })
186    }
187
188    pub(crate) unsafe fn attach_broadcast_socket(&'static self, mut node: NonNull<SocketHeader>) {
189        self.inner.with_lock(|inner| {
190            unsafe {
191                node.as_mut().port = 255;
192            }
193            inner.sockets.push_back(node);
194        });
195    }
196
197    pub(crate) unsafe fn attach_socket(&'static self, node: NonNull<SocketHeader>) -> u8 {
198        let res = unsafe { self.try_attach_socket(node) };
199        let Some(new_port) = res else {
200            panic!("exhausted all addrs");
201        };
202        new_port
203    }
204
205    pub(crate) unsafe fn detach_socket(&'static self, node: NonNull<SocketHeader>) {
206        self.inner.with_lock(|inner| unsafe {
207            let port = node.as_ref().port;
208            if port != 255 {
209                inner.free_port(port);
210            }
211            inner.sockets.remove(node)
212        });
213    }
214
215    pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&'static self, f: F) -> U {
216        self.inner.with_lock(|_inner| f())
217    }
218}
219
220impl<R, M> Default for NetStack<R, M>
221where
222    R: ScopedRawMutex + ConstInit,
223    M: InterfaceManager + interface_manager::ConstInit,
224{
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
230// ---- impl NetStackInner ----
231
232impl<M> NetStackInner<M>
233where
234    M: InterfaceManager,
235    M: interface_manager::ConstInit,
236{
237    pub const fn new() -> Self {
238        Self {
239            sockets: List::new(),
240            manager: M::INIT,
241            seq_no: 0,
242            pcache_bits: 0,
243            pcache_start: 0,
244        }
245    }
246}
247
248impl<M> NetStackInner<M>
249where
250    M: InterfaceManager,
251{
252    /// Method that handles broadcast logic
253    ///
254    /// Takes closures for sending to a socket or sending to the manager to allow
255    /// for abstracting over send_raw/send_ty.
256    fn broadcast<SendSocket, SendMgr>(
257        sockets: &mut List<SocketHeader>,
258        hdr: &Header,
259        mut sskt: SendSocket,
260        smgr: SendMgr,
261    ) -> Result<(), NetStackSendError>
262    where
263        SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
264        SendMgr: FnOnce() -> bool,
265    {
266        trace!("Sending msg broadcast w/ header: {hdr:?}");
267        let res_lcl = {
268            let bcast_iter = Self::find_all_local(sockets, hdr)?;
269            let mut any_found = false;
270            for dst in bcast_iter {
271                let res = sskt(dst);
272                if res {
273                    debug!("delivered broadcast message locally");
274                }
275                any_found |= res;
276            }
277            any_found
278        };
279
280        let res_rmt = smgr();
281        if res_rmt {
282            debug!("delivered broadcast message remotely");
283        }
284
285        if res_lcl || res_rmt {
286            Ok(())
287        } else {
288            Err(NetStackSendError::NoRoute)
289        }
290    }
291
292    /// Method that handles unicast logic
293    ///
294    /// Takes closures for sending to a socket or sending to the manager to allow
295    /// for abstracting over send_raw/send_ty.
296    fn unicast<SendSocket, SendMgr>(
297        sockets: &mut List<SocketHeader>,
298        hdr: &Header,
299        sskt: SendSocket,
300        smgr: SendMgr,
301    ) -> Result<(), NetStackSendError>
302    where
303        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
304        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
305    {
306        trace!("Sending msg unicast w/ header: {hdr:?}");
307        // Can we assume the destination is local?
308        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
309
310        let res = if !local_bypass {
311            // Not local: offer to the interface manager to send
312            debug!("Offering msg externally unicast w/ header: {hdr:?}");
313            smgr()
314        } else {
315            // just skip to local sending
316            Err(InterfaceSendError::DestinationLocal)
317        };
318
319        match res {
320            Ok(()) => {
321                debug!("Externally routed msg unicast");
322                return Ok(());
323            }
324            Err(InterfaceSendError::DestinationLocal) => {
325                debug!("No external interest in msg unicast");
326            }
327            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
328        }
329
330        // It was a destination local error, try to honor that
331        let socket = if hdr.dst.port_id == 0 {
332            debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
333            Self::find_any_local(sockets, hdr)
334        } else {
335            debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
336            Self::find_one_local(sockets, hdr)
337        }?;
338
339        sskt(socket)
340    }
341
342    /// Method that handles unicast logic
343    ///
344    /// Takes closures for sending to a socket or sending to the manager to allow
345    /// for abstracting over send_raw/send_ty.
346    fn unicast_err<SendSocket, SendMgr>(
347        sockets: &mut List<SocketHeader>,
348        hdr: &Header,
349        sskt: SendSocket,
350        smgr: SendMgr,
351    ) -> Result<(), NetStackSendError>
352    where
353        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
354        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
355    {
356        trace!("Sending err unicast w/ header: {hdr:?}");
357        // Can we assume the destination is local?
358        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
359
360        let res = if !local_bypass {
361            // Not local: offer to the interface manager to send
362            debug!("Offering err externally unicast w/ header: {hdr:?}");
363            smgr()
364        } else {
365            // just skip to local sending
366            Err(InterfaceSendError::DestinationLocal)
367        };
368
369        match res {
370            Ok(()) => {
371                debug!("Externally routed err unicast");
372                return Ok(());
373            }
374            Err(InterfaceSendError::DestinationLocal) => {
375                debug!("No external interest in err unicast");
376            }
377            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
378        }
379
380        // It was a destination local error, try to honor that
381        let socket = Self::find_one_err_local(sockets, hdr)?;
382
383        sskt(socket)
384    }
385
386    /// Handle sending of a raw (serialized) message
387    fn send_raw(&mut self, hdr: &Header, body: &[u8]) -> Result<(), NetStackSendError> {
388        let Self {
389            sockets,
390            seq_no,
391            manager,
392            ..
393        } = self;
394        trace!("Sending msg raw w/ header: {hdr:?}");
395
396        if hdr.kind == FrameKind::PROTOCOL_ERROR {
397            todo!("Don't do that");
398        }
399
400        // Is this a broadcast message?
401        if hdr.dst.port_id == 255 {
402            Self::broadcast(
403                sockets,
404                hdr,
405                |skt| Self::send_raw_to_socket(skt, body, hdr, seq_no).is_ok(),
406                || manager.send_raw(hdr, body).is_ok(),
407            )
408        } else {
409            Self::unicast(
410                sockets,
411                hdr,
412                |skt| Self::send_raw_to_socket(skt, body, hdr, seq_no),
413                || manager.send_raw(hdr, body),
414            )
415        }
416    }
417
418    /// Handle sending of a typed message
419    fn send_ty<T: 'static + Serialize + Clone>(
420        &mut self,
421        hdr: &Header,
422        t: &T,
423    ) -> Result<(), NetStackSendError> {
424        let Self {
425            sockets,
426            seq_no,
427            manager,
428            ..
429        } = self;
430        trace!("Sending msg ty w/ header: {hdr:?}");
431
432        if hdr.kind == FrameKind::PROTOCOL_ERROR {
433            todo!("Don't do that");
434        }
435
436        // Is this a broadcast message?
437        if hdr.dst.port_id == 255 {
438            Self::broadcast(
439                sockets,
440                hdr,
441                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
442                || manager.send(hdr, t).is_ok(),
443            )
444        } else {
445            Self::unicast(
446                sockets,
447                hdr,
448                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
449                || manager.send(hdr, t),
450            )
451        }
452    }
453
454    /// Handle sending of a typed message
455    fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
456        let Self {
457            sockets,
458            seq_no,
459            manager,
460            ..
461        } = self;
462        trace!("Sending msg ty w/ header: {hdr:?}");
463
464        if hdr.dst.port_id == 255 {
465            todo!("Don't do that");
466        }
467
468        // Is this a broadcast message?
469        Self::unicast_err(
470            sockets,
471            hdr,
472            |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
473            || manager.send_err(hdr, err),
474        )
475    }
476
477    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
478    /// the given header.
479    fn find_one_local(
480        sockets: &mut List<SocketHeader>,
481        hdr: &Header,
482    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
483        // Find the specific matching port
484        let mut iter = sockets.iter_raw();
485        let socket = loop {
486            let Some(skt) = iter.next() else {
487                return Err(NetStackSendError::NoRoute);
488            };
489            let skt_ref = unsafe { skt.as_ref() };
490            if skt_ref.port != hdr.dst.port_id {
491                continue;
492            }
493            if skt_ref.attrs.kind != hdr.kind {
494                return Err(NetStackSendError::WrongPortKind);
495            }
496            break skt;
497        };
498        Ok(socket)
499    }
500
501    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
502    /// the given header.
503    fn find_one_err_local(
504        sockets: &mut List<SocketHeader>,
505        hdr: &Header,
506    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
507        // Find the specific matching port
508        let mut iter = sockets.iter_raw();
509        let socket = loop {
510            let Some(skt) = iter.next() else {
511                return Err(NetStackSendError::NoRoute);
512            };
513            let skt_ref = unsafe { skt.as_ref() };
514            if skt_ref.port != hdr.dst.port_id {
515                continue;
516            }
517            break skt;
518        };
519        Ok(socket)
520    }
521
522    /// Find a wildcard (e.g. port_id == 0) destination port matching the given header.
523    ///
524    /// If more than one port matches the wildcard, an error is returned.
525    /// Does not match sockets that does not have the `discoverable` [`Attributes`].
526    fn find_any_local(
527        sockets: &mut List<SocketHeader>,
528        hdr: &Header,
529    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
530        // Find ONE specific matching port
531        let Some(key) = hdr.key.as_ref() else {
532            return Err(NetStackSendError::AnyPortMissingKey);
533        };
534        let mut iter = sockets.iter_raw();
535        let mut socket: Option<NonNull<SocketHeader>> = None;
536
537        loop {
538            let Some(skt) = iter.next() else {
539                break;
540            };
541            let skt_ref = unsafe { skt.as_ref() };
542
543            // Check for things that would disqualify a socket from being an
544            // "ANY" destination
545            let mut illegal = false;
546            illegal |= skt_ref.attrs.kind != hdr.kind;
547            illegal |= !skt_ref.attrs.discoverable;
548            illegal |= &skt_ref.key != key;
549
550            if illegal {
551                // Wait, that's illegal
552                continue;
553            }
554
555            // It's a match! Is it a second match?
556            if socket.is_some() {
557                return Err(NetStackSendError::AnyPortNotUnique);
558            }
559            // Nope! Store this one, then we keep going to ensure that no
560            // other socket matches this description.
561            socket = Some(skt);
562        }
563
564        socket.ok_or(NetStackSendError::NoRoute)
565    }
566
567    /// Find ALL broadcast (e.g. port_id == 255) sockets matching the given header.
568    ///
569    /// Returns an error if the header does not contain a Key. May return zero
570    /// matches.
571    fn find_all_local(
572        sockets: &mut List<SocketHeader>,
573        hdr: &Header,
574    ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
575        let Some(key) = hdr.key.as_ref() else {
576            return Err(NetStackSendError::AllPortMissingKey);
577        };
578        Ok(sockets.iter_raw().filter(move |socket| {
579            let skt_ref = unsafe { socket.as_ref() };
580            let bport = skt_ref.port == 255;
581            let dkind = skt_ref.attrs.kind == hdr.kind;
582            let dkey = &skt_ref.key == key;
583            bport && dkind && dkey
584        }))
585    }
586
587    /// Helper method for sending a type to a given socket
588    fn send_ty_to_socket<T: 'static + Serialize + Clone>(
589        this: NonNull<SocketHeader>,
590        t: &T,
591        hdr: &Header,
592        seq_no: &mut u16,
593    ) -> Result<(), NetStackSendError> {
594        let vtable: &'static SocketVTable = {
595            let skt_ref = unsafe { this.as_ref() };
596            skt_ref.vtable
597        };
598
599        if let Some(f) = vtable.recv_owned {
600            let this: NonNull<()> = this.cast();
601            let that: NonNull<T> = NonNull::from(t);
602            let that: NonNull<()> = that.cast();
603            let hdr = hdr.to_headerseq_or_with_seq(|| {
604                let seq = *seq_no;
605                *seq_no = seq_no.wrapping_add(1);
606                seq
607            });
608            (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
609        } else if let Some(_f) = vtable.recv_bor {
610            // TODO: support send borrowed
611            todo!()
612        } else {
613            // todo: keep going? If we found the "right" destination and
614            // sending fails, then there's not much we can do. Probably: there
615            // is no case where a socket has NEITHER send_owned NOR send_bor,
616            // can we make this state impossible instead?
617            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
618        }
619    }
620
621    /// Helper method for sending a type to a given socket
622    fn send_err_to_socket(
623        this: NonNull<SocketHeader>,
624        err: ProtocolError,
625        hdr: &Header,
626        seq_no: &mut u16,
627    ) -> Result<(), NetStackSendError> {
628        let vtable: &'static SocketVTable = {
629            let skt_ref = unsafe { this.as_ref() };
630            skt_ref.vtable
631        };
632
633        if let Some(f) = vtable.recv_err {
634            let this: NonNull<()> = this.cast();
635            let hdr = hdr.to_headerseq_or_with_seq(|| {
636                let seq = *seq_no;
637                *seq_no = seq_no.wrapping_add(1);
638                seq
639            });
640            (f)(this, hdr, err);
641            Ok(())
642        } else {
643            // todo: keep going? If we found the "right" destination and
644            // sending fails, then there's not much we can do. Probably: there
645            // is no case where a socket has NEITHER send_owned NOR send_bor,
646            // can we make this state impossible instead?
647            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
648        }
649    }
650
651    // /// Helper message for sending a raw message to a given socket
652    // fn send_err_raw_to_socket(
653    //     this: NonNull<SocketHeader>,
654    //     body: &[u8],
655    //     hdr: &Header,
656    //     seq_no: &mut u16,
657    // ) -> Result<(), NetStackSendError> {
658    //     let vtable: &'static SocketVTable = {
659    //         let skt_ref = unsafe { this.as_ref() };
660    //         skt_ref.vtable
661    //     };
662    //     let f = vtable.recv_raw;
663
664    //     let this: NonNull<()> = this.cast();
665    //     let hdr = hdr.to_headerseq_or_with_seq(|| {
666    //         let seq = *seq_no;
667    //         *seq_no = seq_no.wrapping_add(1);
668    //         seq
669    //     });
670
671    //     (f)(this, body, hdr).map_err(NetStackSendError::SocketSend)
672    // }
673
674    /// Helper message for sending a raw message to a given socket
675    fn send_raw_to_socket(
676        this: NonNull<SocketHeader>,
677        body: &[u8],
678        hdr: &Header,
679        seq_no: &mut u16,
680    ) -> Result<(), NetStackSendError> {
681        let vtable: &'static SocketVTable = {
682            let skt_ref = unsafe { this.as_ref() };
683            skt_ref.vtable
684        };
685        let f = vtable.recv_raw;
686
687        let this: NonNull<()> = this.cast();
688        let hdr = hdr.to_headerseq_or_with_seq(|| {
689            let seq = *seq_no;
690            *seq_no = seq_no.wrapping_add(1);
691            seq
692        });
693
694        (f)(this, body, hdr).map_err(NetStackSendError::SocketSend)
695    }
696}
697
698impl<M> NetStackInner<M>
699where
700    M: InterfaceManager,
701{
702    /// Cache-based allocator inspired by littlefs2 ID allocator
703    ///
704    /// We remember 32 ports at a time, from the current base, which is always
705    /// a multiple of 32. Allocating from this range does not require moving thru
706    /// the socket lists.
707    ///
708    /// If the current 32 ports are all taken, we will start over from a base port
709    /// of 0, and attempt to
710    fn alloc_port(&mut self) -> Option<u8> {
711        // ports 0 is always taken (could be clear on first alloc)
712        self.pcache_bits |= (self.pcache_start == 0) as u32;
713
714        if self.pcache_bits != u32::MAX {
715            // We can allocate from the current slot
716            let ldg = self.pcache_bits.trailing_ones();
717            debug_assert!(ldg < 32);
718            self.pcache_bits |= 1 << ldg;
719            return Some(self.pcache_start + (ldg as u8));
720        }
721
722        // Nope, cache is all taken. try to find a base with available items.
723        // We always start from the bottom to keep ports small, but if we know
724        // we just exhausted a range, don't waste time checking that
725        let old_start = self.pcache_start;
726        for base in 0..8 {
727            let start = base * 32;
728            if start == old_start {
729                continue;
730            }
731            // Clear/reset cache
732            self.pcache_start = start;
733            self.pcache_bits = 0;
734            // port 0 is not allowed
735            self.pcache_bits |= (self.pcache_start == 0) as u32;
736            // port 255 is not allowed
737            self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
738
739            // TODO: If we trust that sockets are always sorted, we could early-return
740            // when we reach a `pupper > self.pcache_start`. We could also maybe be smart
741            // and iterate forwards for 0..4 and backwards for 4..8 (and switch the early
742            // return check to < instead). NOTE: We currently do NOT guarantee sockets are
743            // sorted!
744            self.sockets.iter().for_each(|s| {
745                if s.port == 255 {
746                    return;
747                }
748
749                // The upper 3 bits of the port
750                let pupper = s.port & !(32 - 1);
751                // The lower 5 bits of the port
752                let plower = s.port & (32 - 1);
753
754                if pupper == self.pcache_start {
755                    self.pcache_bits |= 1 << plower;
756                }
757            });
758
759            if self.pcache_bits != u32::MAX {
760                // We can allocate from the current slot
761                let ldg = self.pcache_bits.trailing_ones();
762                debug_assert!(ldg < 32);
763                self.pcache_bits |= 1 << ldg;
764                return Some(self.pcache_start + (ldg as u8));
765            }
766        }
767
768        // Nope, nothing found
769        None
770    }
771
772    fn free_port(&mut self, port: u8) {
773        debug_assert!(port != 255);
774        // The upper 3 bits of the port
775        let pupper = port & !(32 - 1);
776        // The lower 5 bits of the port
777        let plower = port & (32 - 1);
778
779        // TODO: If the freed port is in the 0..32 range, or just less than
780        // the current start range, maybe do an opportunistic re-look?
781        if pupper == self.pcache_start {
782            self.pcache_bits &= !(1 << plower);
783        }
784    }
785}
786
787impl NetStackSendError {
788    pub fn to_error(&self) -> ProtocolError {
789        match self {
790            NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
791            NetStackSendError::InterfaceSend(interface_send_error) => {
792                interface_send_error.to_error()
793            }
794            NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
795            NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
796            NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
797            NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
798            NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
799        }
800    }
801}
802
#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack,
        interface_manager::null::NullInterfaceManager,
        socket::{Attributes, owned::OwnedSocket},
    };

    /// A live test socket: its expected port, the thread keeping it attached,
    /// and the sender that tells that thread to detach and exit.
    type LiveSocket = (u8, JoinHandle<()>, oneshot::Sender<()>);

    /// Exercise the port allocator in these phases:
    /// - sequential allocation;
    /// - reuse of a freed port inside the cached window;
    /// - delayed reuse of a freed port outside it;
    /// - exhaustion.
    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();

        /// Attach a socket on a new thread and assert that it received port
        /// `id`. The socket stays attached until the returned sender fires.
        fn spawn_skt(id: u8) -> LiveSocket {
            let (txdone, rxdone) = oneshot::channel();
            let (txwait, rxwait) = oneshot::channel();
            let hdl = std::thread::spawn(move || {
                let skt = OwnedSocket::<u64, _, _>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    Attributes {
                        kind: FrameKind::ENDPOINT_REQ,
                        discoverable: true,
                    },
                );
                let skt = pin!(skt);
                let hdl = skt.attach();
                assert_eq!(hdl.port(), id);
                txwait.send(()).unwrap();
                let _: () = rxdone.blocking_recv().unwrap();
            });
            let _ = rxwait.blocking_recv();
            (id, hdl, txdone)
        }

        /// Detach the live socket holding port `id` and wait for its thread.
        fn drop_skt(live: &mut Vec<LiveSocket>, id: u8) {
            let pos = live.iter().position(|(port, _, _)| *port == id).unwrap();
            let (_port, hdl, tx) = live.remove(pos);
            tx.send(()).unwrap();
            hdl.join().unwrap();
        }

        let mut live: Vec<LiveSocket> = Vec::new();

        // Ports 1..32 fill the first window (port 0 is never handed out),
        // then 32..40 start the second one.
        live.extend((1..32).map(spawn_skt));
        live.extend((32..40).map(spawn_skt));

        // A port freed inside the cached window is reused right away.
        drop_skt(&mut live, 35);
        live.push(spawn_skt(35));

        // A port freed outside the cached window is not reused yet...
        drop_skt(&mut live, 4);
        live.push(spawn_skt(40));
        live.extend((41..64).map(spawn_skt));

        // ...but is found again once the current window fills up.
        live.push(spawn_skt(4));

        // Take every remaining unicast port.
        live.extend((64..255).map(spawn_skt));

        drop_skt(&mut live, 212);
        live.push(spawn_skt(212));

        // Sockets exhausted (we never see 255), so attaching must panic.
        let exhausted = std::thread::spawn(move || {
            let skt = OwnedSocket::<u64, _, _>::new(
                &STACK,
                Key(*b"TEST1234"),
                Attributes {
                    kind: FrameKind::ENDPOINT_REQ,
                    discoverable: true,
                },
            );
            let skt = pin!(skt);
            let hdl = skt.attach();
            println!("{}", hdl.port());
        });
        assert!(exhausted.join().is_err());

        for (_port, hdl, tx) in live.drain(..) {
            tx.send(()).unwrap();
            hdl.join().unwrap();
        }
    }
}