// ergot_base/net_stack.rs

//! The Ergot NetStack
//!
//! The [`NetStack`] is the core of Ergot. It is intended to be placed
//! in a `static` variable for the duration of your application.
//!
//! The [`NetStack`] is used directly for three main responsibilities:
//!
//! 1. Sending a message, either from user code, or to deliver/forward messages
//!    received from an interface
//! 2. Attaching a socket, allowing the [`NetStack`] to route messages to it
//! 3. Interacting with the [interface manager], in order to add/remove
//!    interfaces, or to obtain other information
//!
//! [interface manager]: crate::interface_manager
//!
//! In general, interacting with anything contained by the [`NetStack`] requires
//! locking the [`BlockingMutex`] that protects its inner contents. The lock
//! serves two purposes: it allows the inner contents to be shared, and it
//! allows `Drop` impls to remove themselves from the stack in a blocking manner.

21use core::{any::TypeId, ops::Deref, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29    FrameKind, Header, ProtocolError,
30    interface_manager::{self, InterfaceSendError, Profile},
31    socket::{SocketHeader, SocketSendError, SocketVTable},
32};
33
34/// The Ergot Netstack
35pub struct NetStack<R: ScopedRawMutex, P: Profile> {
36    inner: BlockingMutex<R, NetStackInner<P>>,
37}
38
39pub trait NetStackHandle
40where
41    Self: Sized + Clone,
42{
43    type Target: Deref<Target = NetStack<Self::Mutex, Self::Profile>> + Clone;
44    type Mutex: ScopedRawMutex;
45    type Profile: Profile;
46    fn stack(&self) -> Self::Target;
47}
48
49pub(crate) struct NetStackInner<P: Profile> {
50    sockets: List<SocketHeader>,
51    profile: P,
52    pcache_bits: u32,
53    pcache_start: u8,
54    seq_no: u16,
55}
56
57/// An error from calling a [`NetStack`] "send" method
58#[derive(Debug, PartialEq, Eq)]
59#[non_exhaustive]
60pub enum NetStackSendError {
61    SocketSend(SocketSendError),
62    InterfaceSend(InterfaceSendError),
63    NoRoute,
64    AnyPortMissingKey,
65    WrongPortKind,
66    AnyPortNotUnique,
67    AllPortMissingKey,
68}
69
70// ---- impl NetStack ----
71
72// TODO: Impl for Arc
73impl<R, P> NetStackHandle for &'_ NetStack<R, P>
74where
75    R: ScopedRawMutex,
76    P: Profile,
77{
78    type Mutex = R;
79    type Profile = P;
80    type Target = Self;
81
82    fn stack(&self) -> Self::Target {
83        self
84    }
85}
86
/// With the `std` feature, a reference-counted [`NetStack`] is itself a
/// handle; [`NetStackHandle::stack`] returns another strong reference.
#[cfg(feature = "std")]
impl<R, P> NetStackHandle for std::sync::Arc<NetStack<R, P>>
where
    R: ScopedRawMutex,
    P: Profile,
{
    type Target = Self;
    type Mutex = R;
    type Profile = P;

    fn stack(&self) -> Self::Target {
        std::sync::Arc::clone(self)
    }
}
101
102impl<R, P> NetStack<R, P>
103where
104    R: ScopedRawMutex + ConstInit,
105    P: Profile + interface_manager::ConstInit,
106{
107    /// Create a new, uninitialized [`NetStack`].
108    ///
109    /// Requires that the [`ScopedRawMutex`] implements the [`mutex::ConstInit`]
110    /// trait, and the [`Profile`] implements the
111    /// [`interface_manager::ConstInit`] trait.
112    ///
113    /// ## Example
114    ///
115    /// ```rust
116    /// use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
117    /// use ergot_base::NetStack;
118    /// use ergot_base::interface_manager::profiles::null::Null;
119    ///
120    /// static STACK: NetStack<CSRMutex, Null> = NetStack::new();
121    /// ```
122    pub const fn new() -> Self {
123        Self {
124            inner: BlockingMutex::new(NetStackInner::new()),
125        }
126    }
127}
128
129impl<R, P> NetStack<R, P>
130where
131    R: ScopedRawMutex + ConstInit,
132    P: Profile,
133{
134    pub const fn new_with_profile(p: P) -> Self {
135        Self {
136            inner: BlockingMutex::new(NetStackInner::new_with_profile(p)),
137        }
138    }
139}
140
#[cfg(feature = "std")]
impl<R, P> NetStack<R, P>
where
    R: ScopedRawMutex + ConstInit,
    P: Profile,
{
    /// Create a heap-allocated, reference-counted [`NetStack`] with the given
    /// profile, for use when a `static` is not practical.
    pub fn new_arc(p: P) -> std::sync::Arc<Self> {
        std::sync::Arc::new(Self::new_with_profile(p))
    }
}
153
154impl<R, P> NetStack<R, P>
155where
156    R: ScopedRawMutex,
157    P: Profile,
158{
159    /// Manually create a new, uninitialized [`NetStack`].
160    ///
161    /// This method is useful if your [`ScopedRawMutex`] or [`Profile`]
162    /// do not implement their corresponding `ConstInit` trait.
163    ///
164    /// In general, this is most often only needed for `loom` testing, and
165    /// [`NetStack::new()`] should be used when possible.
166    pub const fn const_new(r: R, p: P) -> Self {
167        Self {
168            inner: BlockingMutex::const_new(
169                r,
170                NetStackInner {
171                    sockets: List::new(),
172                    profile: p,
173                    seq_no: 0,
174                    pcache_start: 0,
175                    pcache_bits: 0,
176                },
177            ),
178        }
179    }
180
181    /// Access the contained [`Profile`].
182    ///
183    /// Access to the [`Profile`] is made via the provided closure.
184    /// The [`BlockingMutex`] is locked for the duration of this access,
185    /// inhibiting all other usage of this [`NetStack`].
186    ///
187    /// This can be used to add new interfaces, obtain metadata, or other
188    /// actions supported by the chosen [`Profile`].
189    ///
190    /// ## Example
191    ///
192    /// ```rust
193    /// # use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
194    /// # use ergot_base::NetStack;
195    /// # use ergot_base::interface_manager::profiles::null::Null;
196    /// #
197    /// static STACK: NetStack<CSRMutex, Null> = NetStack::new();
198    ///
199    /// let res = STACK.manage_profile(|im| {
200    ///    // The mutex is locked for the full duration of this closure.
201    ///    # _ = im;
202    ///    // We can return whatever we want from this context, though not
203    ///    // anything borrowed from `im`.
204    ///    42
205    /// });
206    /// assert_eq!(res, 42);
207    /// ```
208    pub fn manage_profile<F: FnOnce(&mut P) -> U, U>(&self, f: F) -> U {
209        self.inner.with_lock(|inner| f(&mut inner.profile))
210    }
211
212    /// Send a raw (pre-serialized) message.
213    ///
214    /// This interface should almost never be used by end-users, and is instead
215    /// typically used by interfaces to feed received messages into the
216    /// [`NetStack`].
217    pub fn send_raw(
218        &self,
219        hdr: &Header,
220        hdr_raw: &[u8],
221        body: &[u8],
222    ) -> Result<(), NetStackSendError> {
223        self.inner
224            .with_lock(|inner| inner.send_raw(hdr, hdr_raw, body))
225    }
226
227    /// Send a typed message
228    pub fn send_ty<T: 'static + Serialize + Clone>(
229        &self,
230        hdr: &Header,
231        t: &T,
232    ) -> Result<(), NetStackSendError> {
233        self.inner.with_lock(|inner| inner.send_ty(hdr, t))
234    }
235
236    pub fn send_err(&self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
237        self.inner.with_lock(|inner| inner.send_err(hdr, err))
238    }
239
240    pub(crate) unsafe fn try_attach_socket(&self, mut node: NonNull<SocketHeader>) -> Option<u8> {
241        self.inner.with_lock(|inner| {
242            let new_port = inner.alloc_port()?;
243            unsafe {
244                node.as_mut().port = new_port;
245            }
246
247            inner.sockets.push_front(node);
248            Some(new_port)
249        })
250    }
251
252    pub(crate) unsafe fn attach_broadcast_socket(&self, mut node: NonNull<SocketHeader>) {
253        self.inner.with_lock(|inner| {
254            unsafe {
255                node.as_mut().port = 255;
256            }
257            inner.sockets.push_back(node);
258        });
259    }
260
261    pub(crate) unsafe fn attach_socket(&self, node: NonNull<SocketHeader>) -> u8 {
262        let res = unsafe { self.try_attach_socket(node) };
263        let Some(new_port) = res else {
264            panic!("exhausted all addrs");
265        };
266        new_port
267    }
268
269    pub(crate) unsafe fn detach_socket(&self, node: NonNull<SocketHeader>) {
270        self.inner.with_lock(|inner| unsafe {
271            let port = node.as_ref().port;
272            if port != 255 {
273                inner.free_port(port);
274            }
275            inner.sockets.remove(node)
276        });
277    }
278
279    pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&self, f: F) -> U {
280        self.inner.with_lock(|_inner| f())
281    }
282}
283
284impl<R, P> Default for NetStack<R, P>
285where
286    R: ScopedRawMutex + ConstInit,
287    P: Profile + interface_manager::ConstInit,
288{
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
294// ---- impl NetStackInner ----
295
296impl<P> NetStackInner<P>
297where
298    P: Profile,
299    P: interface_manager::ConstInit,
300{
301    pub const fn new() -> Self {
302        Self {
303            sockets: List::new(),
304            profile: P::INIT,
305            seq_no: 0,
306            pcache_bits: 0,
307            pcache_start: 0,
308        }
309    }
310}
311
312impl<P> NetStackInner<P>
313where
314    P: Profile,
315{
316    /// Create a netstack with a given profile
317    pub const fn new_with_profile(p: P) -> Self {
318        Self {
319            sockets: List::new(),
320            profile: p,
321            seq_no: 0,
322            pcache_bits: 0,
323            pcache_start: 0,
324        }
325    }
326
327    /// Method that handles broadcast logic
328    ///
329    /// Takes closures for sending to a socket or sending to the manager to allow
330    /// for abstracting over send_raw/send_ty.
331    fn broadcast<SendSocket, SendMgr>(
332        sockets: &mut List<SocketHeader>,
333        hdr: &Header,
334        mut sskt: SendSocket,
335        smgr: SendMgr,
336    ) -> Result<(), NetStackSendError>
337    where
338        SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
339        SendMgr: FnOnce() -> bool,
340    {
341        trace!("Sending msg broadcast w/ header: {hdr:?}");
342        let res_lcl = {
343            let bcast_iter = Self::find_all_local(sockets, hdr)?;
344            let mut any_found = false;
345            for dst in bcast_iter {
346                let res = sskt(dst);
347                if res {
348                    debug!("delivered broadcast message locally");
349                }
350                any_found |= res;
351            }
352            any_found
353        };
354
355        let res_rmt = smgr();
356        if res_rmt {
357            debug!("delivered broadcast message remotely");
358        }
359
360        if res_lcl || res_rmt {
361            Ok(())
362        } else {
363            Err(NetStackSendError::NoRoute)
364        }
365    }
366
367    /// Method that handles unicast logic
368    ///
369    /// Takes closures for sending to a socket or sending to the manager to allow
370    /// for abstracting over send_raw/send_ty.
371    fn unicast<SendSocket, SendMgr>(
372        sockets: &mut List<SocketHeader>,
373        hdr: &Header,
374        sskt: SendSocket,
375        smgr: SendMgr,
376    ) -> Result<(), NetStackSendError>
377    where
378        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
379        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
380    {
381        trace!("Sending msg unicast w/ header: {hdr:?}");
382        // Can we assume the destination is local?
383        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
384
385        let res = if !local_bypass {
386            // Not local: offer to the interface manager to send
387            debug!("Offering msg externally unicast w/ header: {hdr:?}");
388            smgr()
389        } else {
390            // just skip to local sending
391            Err(InterfaceSendError::DestinationLocal)
392        };
393
394        match res {
395            Ok(()) => {
396                debug!("Externally routed msg unicast");
397                return Ok(());
398            }
399            Err(InterfaceSendError::DestinationLocal) => {
400                debug!("No external interest in msg unicast");
401            }
402            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
403        }
404
405        // It was a destination local error, try to honor that
406        let socket = if hdr.dst.port_id == 0 {
407            debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
408            Self::find_any_local(sockets, hdr)
409        } else {
410            debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
411            Self::find_one_local(sockets, hdr)
412        }?;
413
414        sskt(socket)
415    }
416
417    /// Method that handles unicast logic
418    ///
419    /// Takes closures for sending to a socket or sending to the manager to allow
420    /// for abstracting over send_raw/send_ty.
421    fn unicast_err<SendSocket, SendMgr>(
422        sockets: &mut List<SocketHeader>,
423        hdr: &Header,
424        sskt: SendSocket,
425        smgr: SendMgr,
426    ) -> Result<(), NetStackSendError>
427    where
428        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
429        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
430    {
431        trace!("Sending err unicast w/ header: {hdr:?}");
432        // Can we assume the destination is local?
433        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
434
435        let res = if !local_bypass {
436            // Not local: offer to the interface manager to send
437            debug!("Offering err externally unicast w/ header: {hdr:?}");
438            smgr()
439        } else {
440            // just skip to local sending
441            Err(InterfaceSendError::DestinationLocal)
442        };
443
444        match res {
445            Ok(()) => {
446                debug!("Externally routed err unicast");
447                return Ok(());
448            }
449            Err(InterfaceSendError::DestinationLocal) => {
450                debug!("No external interest in err unicast");
451            }
452            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
453        }
454
455        // It was a destination local error, try to honor that
456        let socket = Self::find_one_err_local(sockets, hdr)?;
457
458        sskt(socket)
459    }
460
461    /// Handle sending of a raw (serialized) message
462    fn send_raw(
463        &mut self,
464        hdr: &Header,
465        hdr_raw: &[u8],
466        body: &[u8],
467    ) -> Result<(), NetStackSendError> {
468        let Self {
469            sockets,
470            seq_no,
471            profile: manager,
472            ..
473        } = self;
474        trace!("Sending msg raw w/ header: {hdr:?}");
475
476        if hdr.kind == FrameKind::PROTOCOL_ERROR {
477            todo!("Don't do that");
478        }
479
480        // Is this a broadcast message?
481        if hdr.dst.port_id == 255 {
482            Self::broadcast(
483                sockets,
484                hdr,
485                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no).is_ok(),
486                || manager.send_raw(hdr, hdr_raw, body).is_ok(),
487            )
488        } else {
489            Self::unicast(
490                sockets,
491                hdr,
492                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no),
493                || manager.send_raw(hdr, hdr_raw, body),
494            )
495        }
496    }
497
498    /// Handle sending of a typed message
499    fn send_ty<T: 'static + Serialize + Clone>(
500        &mut self,
501        hdr: &Header,
502        t: &T,
503    ) -> Result<(), NetStackSendError> {
504        let Self {
505            sockets,
506            seq_no,
507            profile: manager,
508            ..
509        } = self;
510        trace!("Sending msg ty w/ header: {hdr:?}");
511
512        if hdr.kind == FrameKind::PROTOCOL_ERROR {
513            todo!("Don't do that");
514        }
515
516        // Is this a broadcast message?
517        if hdr.dst.port_id == 255 {
518            Self::broadcast(
519                sockets,
520                hdr,
521                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
522                || manager.send(hdr, t).is_ok(),
523            )
524        } else {
525            Self::unicast(
526                sockets,
527                hdr,
528                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
529                || manager.send(hdr, t),
530            )
531        }
532    }
533
534    /// Handle sending of a typed message
535    fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
536        let Self {
537            sockets,
538            seq_no,
539            profile: manager,
540            ..
541        } = self;
542        trace!("Sending msg ty w/ header: {hdr:?}");
543
544        if hdr.dst.port_id == 255 {
545            todo!("Don't do that");
546        }
547
548        // Is this a broadcast message?
549        Self::unicast_err(
550            sockets,
551            hdr,
552            |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
553            || manager.send_err(hdr, err),
554        )
555    }
556
557    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
558    /// the given header.
559    fn find_one_local(
560        sockets: &mut List<SocketHeader>,
561        hdr: &Header,
562    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
563        // Find the specific matching port
564        let mut iter = sockets.iter_raw();
565        let socket = loop {
566            let Some(skt) = iter.next() else {
567                return Err(NetStackSendError::NoRoute);
568            };
569            let skt_ref = unsafe { skt.as_ref() };
570            if skt_ref.port != hdr.dst.port_id {
571                continue;
572            }
573            if skt_ref.attrs.kind != hdr.kind {
574                return Err(NetStackSendError::WrongPortKind);
575            }
576            break skt;
577        };
578        Ok(socket)
579    }
580
581    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
582    /// the given header.
583    fn find_one_err_local(
584        sockets: &mut List<SocketHeader>,
585        hdr: &Header,
586    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
587        // Find the specific matching port
588        let mut iter = sockets.iter_raw();
589        let socket = loop {
590            let Some(skt) = iter.next() else {
591                return Err(NetStackSendError::NoRoute);
592            };
593            let skt_ref = unsafe { skt.as_ref() };
594            if skt_ref.port != hdr.dst.port_id {
595                continue;
596            }
597            break skt;
598        };
599        Ok(socket)
600    }
601
602    /// Find a wildcard (e.g. port_id == 0) destination port matching the given header.
603    ///
604    /// If more than one port matches the wildcard, an error is returned.
605    /// Does not match sockets that does not have the `discoverable` [`Attributes`].
606    fn find_any_local(
607        sockets: &mut List<SocketHeader>,
608        hdr: &Header,
609    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
610        // Find ONE specific matching port
611        let Some(apdx) = hdr.any_all.as_ref() else {
612            return Err(NetStackSendError::AnyPortMissingKey);
613        };
614        let mut iter = sockets.iter_raw();
615        let mut socket: Option<NonNull<SocketHeader>> = None;
616
617        loop {
618            let Some(skt) = iter.next() else {
619                break;
620            };
621            let skt_ref = unsafe { skt.as_ref() };
622
623            // Check for things that would disqualify a socket from being an
624            // "ANY" destination
625            let mut illegal = false;
626            illegal |= skt_ref.attrs.kind != hdr.kind;
627            illegal |= !skt_ref.attrs.discoverable;
628            illegal |= skt_ref.key != apdx.key;
629            if let Some(nash) = apdx.nash {
630                illegal |= Some(nash) != skt_ref.nash;
631            }
632
633            if illegal {
634                // Wait, that's illegal
635                continue;
636            }
637
638            // It's a match! Is it a second match?
639            if socket.is_some() {
640                return Err(NetStackSendError::AnyPortNotUnique);
641            }
642            // Nope! Store this one, then we keep going to ensure that no
643            // other socket matches this description.
644            socket = Some(skt);
645        }
646
647        socket.ok_or(NetStackSendError::NoRoute)
648    }
649
650    /// Find ALL broadcast (e.g. port_id == 255) sockets matching the given header.
651    ///
652    /// Returns an error if the header does not contain a Key. May return zero
653    /// matches.
654    fn find_all_local(
655        sockets: &mut List<SocketHeader>,
656        hdr: &Header,
657    ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
658        let Some(any_all) = hdr.any_all.as_ref() else {
659            return Err(NetStackSendError::AllPortMissingKey);
660        };
661        Ok(sockets.iter_raw().filter(move |socket| {
662            let skt_ref = unsafe { socket.as_ref() };
663            let bport = skt_ref.port == 255;
664            let dkind = skt_ref.attrs.kind == hdr.kind;
665            let dkey = skt_ref.key == any_all.key;
666
667            // If the any/all message DOES contain a name hash, then ONLY match
668            // sockets with the same name hash.
669            let name = if let Some(nash) = any_all.nash {
670                Some(nash) == skt_ref.nash
671            } else {
672                true
673            };
674            bport && dkind && dkey && name
675        }))
676    }
677
678    /// Helper method for sending a type to a given socket
679    fn send_ty_to_socket<T: 'static + Serialize + Clone>(
680        this: NonNull<SocketHeader>,
681        t: &T,
682        hdr: &Header,
683        seq_no: &mut u16,
684    ) -> Result<(), NetStackSendError> {
685        let vtable: &'static SocketVTable = {
686            let skt_ref = unsafe { this.as_ref() };
687            skt_ref.vtable
688        };
689
690        if let Some(f) = vtable.recv_owned {
691            let this: NonNull<()> = this.cast();
692            let that: NonNull<T> = NonNull::from(t);
693            let that: NonNull<()> = that.cast();
694            let hdr = hdr.to_headerseq_or_with_seq(|| {
695                let seq = *seq_no;
696                *seq_no = seq_no.wrapping_add(1);
697                seq
698            });
699            (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
700        } else if let Some(_f) = vtable.recv_bor {
701            // TODO: support send borrowed
702            todo!()
703        } else {
704            // todo: keep going? If we found the "right" destination and
705            // sending fails, then there's not much we can do. Probably: there
706            // is no case where a socket has NEITHER send_owned NOR send_bor,
707            // can we make this state impossible instead?
708            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
709        }
710    }
711
712    /// Helper method for sending a type to a given socket
713    fn send_err_to_socket(
714        this: NonNull<SocketHeader>,
715        err: ProtocolError,
716        hdr: &Header,
717        seq_no: &mut u16,
718    ) -> Result<(), NetStackSendError> {
719        let vtable: &'static SocketVTable = {
720            let skt_ref = unsafe { this.as_ref() };
721            skt_ref.vtable
722        };
723
724        if let Some(f) = vtable.recv_err {
725            let this: NonNull<()> = this.cast();
726            let hdr = hdr.to_headerseq_or_with_seq(|| {
727                let seq = *seq_no;
728                *seq_no = seq_no.wrapping_add(1);
729                seq
730            });
731            (f)(this, hdr, err);
732            Ok(())
733        } else {
734            // todo: keep going? If we found the "right" destination and
735            // sending fails, then there's not much we can do. Probably: there
736            // is no case where a socket has NEITHER send_owned NOR send_bor,
737            // can we make this state impossible instead?
738            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
739        }
740    }
741
742    /// Helper message for sending a raw message to a given socket
743    fn send_raw_to_socket(
744        this: NonNull<SocketHeader>,
745        body: &[u8],
746        hdr: &Header,
747        hdr_raw: &[u8],
748        seq_no: &mut u16,
749    ) -> Result<(), NetStackSendError> {
750        let vtable: &'static SocketVTable = {
751            let skt_ref = unsafe { this.as_ref() };
752            skt_ref.vtable
753        };
754        let f = vtable.recv_raw;
755
756        let this: NonNull<()> = this.cast();
757        let hdr = hdr.to_headerseq_or_with_seq(|| {
758            let seq = *seq_no;
759            *seq_no = seq_no.wrapping_add(1);
760            seq
761        });
762
763        (f)(this, body, hdr, hdr_raw).map_err(NetStackSendError::SocketSend)
764    }
765}
766
767impl<P> NetStackInner<P>
768where
769    P: Profile,
770{
771    /// Cache-based allocator inspired by littlefs2 ID allocator
772    ///
773    /// We remember 32 ports at a time, from the current base, which is always
774    /// a multiple of 32. Allocating from this range does not require moving thru
775    /// the socket lists.
776    ///
777    /// If the current 32 ports are all taken, we will start over from a base port
778    /// of 0, and attempt to
779    fn alloc_port(&mut self) -> Option<u8> {
780        // ports 0 is always taken (could be clear on first alloc)
781        self.pcache_bits |= (self.pcache_start == 0) as u32;
782
783        if self.pcache_bits != u32::MAX {
784            // We can allocate from the current slot
785            let ldg = self.pcache_bits.trailing_ones();
786            debug_assert!(ldg < 32);
787            self.pcache_bits |= 1 << ldg;
788            return Some(self.pcache_start + (ldg as u8));
789        }
790
791        // Nope, cache is all taken. try to find a base with available items.
792        // We always start from the bottom to keep ports small, but if we know
793        // we just exhausted a range, don't waste time checking that
794        let old_start = self.pcache_start;
795        for base in 0..8 {
796            let start = base * 32;
797            if start == old_start {
798                continue;
799            }
800            // Clear/reset cache
801            self.pcache_start = start;
802            self.pcache_bits = 0;
803            // port 0 is not allowed
804            self.pcache_bits |= (self.pcache_start == 0) as u32;
805            // port 255 is not allowed
806            self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
807
808            // TODO: If we trust that sockets are always sorted, we could early-return
809            // when we reach a `pupper > self.pcache_start`. We could also maybe be smart
810            // and iterate forwards for 0..4 and backwards for 4..8 (and switch the early
811            // return check to < instead). NOTE: We currently do NOT guarantee sockets are
812            // sorted!
813            self.sockets.iter().for_each(|s| {
814                if s.port == 255 {
815                    return;
816                }
817
818                // The upper 3 bits of the port
819                let pupper = s.port & !(32 - 1);
820                // The lower 5 bits of the port
821                let plower = s.port & (32 - 1);
822
823                if pupper == self.pcache_start {
824                    self.pcache_bits |= 1 << plower;
825                }
826            });
827
828            if self.pcache_bits != u32::MAX {
829                // We can allocate from the current slot
830                let ldg = self.pcache_bits.trailing_ones();
831                debug_assert!(ldg < 32);
832                self.pcache_bits |= 1 << ldg;
833                return Some(self.pcache_start + (ldg as u8));
834            }
835        }
836
837        // Nope, nothing found
838        None
839    }
840
841    fn free_port(&mut self, port: u8) {
842        debug_assert!(port != 255);
843        // The upper 3 bits of the port
844        let pupper = port & !(32 - 1);
845        // The lower 5 bits of the port
846        let plower = port & (32 - 1);
847
848        // TODO: If the freed port is in the 0..32 range, or just less than
849        // the current start range, maybe do an opportunistic re-look?
850        if pupper == self.pcache_start {
851            self.pcache_bits &= !(1 << plower);
852        }
853    }
854}
855
856impl NetStackSendError {
857    pub fn to_error(&self) -> ProtocolError {
858        match self {
859            NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
860            NetStackSendError::InterfaceSend(interface_send_error) => {
861                interface_send_error.to_error()
862            }
863            NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
864            NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
865            NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
866            NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
867            NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
868        }
869    }
870}
871
#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack,
        interface_manager::profiles::null::Null,
        socket::{Attributes, owned::single::Socket},
    };

    /// Exercise the port allocator: sequential allocation, reuse within the
    /// cached window, rediscovery of ports freed outside it, and exhaustion.
    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, Null> = NetStack::new();

        /// A socket held open on its own thread: its expected port, the
        /// thread handle, and the sender that releases (detaches) it.
        type Live = (u8, JoinHandle<()>, oneshot::Sender<()>);

        /// Attach a socket on a new thread, assert that it got port `expected`,
        /// and keep it attached until the returned sender fires.
        fn spawn_skt(expected: u8) -> Live {
            let (release_tx, release_rx) = oneshot::channel();
            let (ready_tx, ready_rx) = oneshot::channel();
            let handle = std::thread::spawn(move || {
                let skt = Socket::<u64, &_>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    Attributes {
                        kind: FrameKind::ENDPOINT_REQ,
                        discoverable: true,
                    },
                    None,
                );
                let skt = pin!(skt);
                let hdl = skt.attach();
                assert_eq!(hdl.port(), expected);
                ready_tx.send(()).unwrap();
                let _: () = release_rx.blocking_recv().unwrap();
            });
            // Wait until the socket is attached (or its thread has died).
            let _ = ready_rx.blocking_recv();
            (expected, handle, release_tx)
        }

        /// Release the socket expected on port `id` and wait for its thread.
        fn drop_skt(live: &mut Vec<Live>, id: u8) {
            let pos = live.iter().position(|(i, _, _)| *i == id).unwrap();
            let (_id, handle, release) = live.remove(pos);
            release.send(()).unwrap();
            handle.join().unwrap();
        }

        let mut live = Vec::new();

        // Ports 1..40 are handed out in order (port 0 is reserved).
        live.extend((1..40).map(spawn_skt));

        // 35 lies in the cached 32..64 window, so it is reused immediately.
        drop_skt(&mut live, 35);
        live.push(spawn_skt(35));

        // 4 lies outside the cached window, so allocation continues at 40.
        drop_skt(&mut live, 4);
        live.push(spawn_skt(40));

        // Fill the rest of the 32..64 window.
        live.extend((41..64).map(spawn_skt));

        // The window is full: rescanning from the bottom finds the freed 4.
        live.push(spawn_skt(4));

        // Everything up to 254; 255 is the broadcast port.
        live.extend((64..255).map(spawn_skt));

        // 212 is outside the current window and is found again by a rescan.
        drop_skt(&mut live, 212);
        live.push(spawn_skt(212));

        // Sockets exhausted (we never see 255)
        let exhausted = std::thread::spawn(move || {
            let skt = Socket::<u64, &_>::new(
                &STACK,
                Key(*b"TEST1234"),
                Attributes {
                    kind: FrameKind::ENDPOINT_REQ,
                    discoverable: true,
                },
                None,
            );
            let skt = pin!(skt);
            let hdl = skt.attach();
            println!("{}", hdl.port());
        });
        assert!(exhausted.join().is_err());

        for (_id, handle, release) in live.drain(..) {
            release.send(()).unwrap();
            handle.join().unwrap();
        }
    }
}