ergot_base/
net_stack.rs

1//! The Ergot NetStack
2//!
3//! The [`NetStack`] is the core of Ergot. It is intended to be placed
4//! in a `static` variable for the duration of your application.
5//!
//! The [`NetStack`] is used directly for three main responsibilities:
7//!
8//! 1. Sending a message, either from user code, or to deliver/forward messages
9//!    received from an interface
10//! 2. Attaching a socket, allowing the NetStack to route messages to it
11//! 3. Interacting with the [interface manager], in order to add/remove
12//!    interfaces, or obtain other information
13//!
14//! [interface manager]: crate::interface_manager
15//!
//! In general, interacting with anything contained by the [`NetStack`] requires
//! locking the [`BlockingMutex`] which protects the inner contents. The lock
//! serves two purposes: it allows the inner contents to be shared, and it
//! allows `Drop` impls to remove themselves from the stack in a blocking manner.
20
21use core::{any::TypeId, ops::Deref, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29    FrameKind, Header, ProtocolError,
30    interface_manager::{self, InterfaceSendError, Profile},
31    socket::{SocketHeader, SocketSendError, SocketVTable, borser},
32};
33
/// The Ergot Netstack
///
/// Keeps all stack state behind a single [`BlockingMutex`]; the methods on
/// this type lock that mutex and then operate on the inner state.
pub struct NetStack<R: ScopedRawMutex, P: Profile> {
    // Socket list, profile, port-allocation cache and sequence counter,
    // all guarded by the one mutex.
    inner: BlockingMutex<R, NetStackInner<P>>,
}
38
/// A cloneable handle through which a [`NetStack`] can be reached.
///
/// This file implements it for `&NetStack` and, when the `std` feature is
/// enabled, for `std::sync::Arc<NetStack>`.
pub trait NetStackHandle
where
    Self: Sized + Clone,
{
    /// The value returned by [`stack`](Self::stack); it dereferences to the
    /// underlying [`NetStack`].
    type Target: Deref<Target = NetStack<Self::Mutex, Self::Profile>> + Clone;
    /// The raw mutex type of the referenced [`NetStack`].
    type Mutex: ScopedRawMutex;
    /// The [`Profile`] type of the referenced [`NetStack`].
    type Profile: Profile;
    /// Obtain something that dereferences to the [`NetStack`].
    fn stack(&self) -> Self::Target;
}
48
/// The state protected by the [`NetStack`] mutex.
pub(crate) struct NetStackInner<P: Profile> {
    /// Intrusive list of every attached socket. Sockets with an allocated
    /// port are pushed to the front, broadcast (port 255) sockets to the back.
    sockets: List<SocketHeader>,
    /// The interface profile, used to hand messages to remote destinations.
    profile: P,
    /// Port allocation cache: bit `n` set means port `pcache_start + n` is
    /// treated as taken.
    pcache_bits: u32,
    /// First port of the 32-port window tracked by `pcache_bits`.
    pcache_start: u8,
    /// Counter offered to `Header::to_headerseq_or_with_seq` by the socket
    /// delivery helpers; post-incremented (wrapping) each time it is consumed.
    seq_no: u16,
}
56
/// An error from calling a [`NetStack`] "send" method
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum NetStackSendError {
    /// Delivery to the selected local socket failed, or the socket has no
    /// receive function suitable for this kind of send.
    SocketSend(SocketSendError),
    /// The interface profile rejected the message with an error other than
    /// "destination is local".
    InterfaceSend(InterfaceSendError),
    /// No destination was found: no local socket matched and, for
    /// broadcasts, the profile did not accept the message either.
    NoRoute,
    /// A wildcard (port 0) message carried no key to match sockets against.
    AnyPortMissingKey,
    /// A socket owns the destination port, but its frame kind differs from
    /// the kind in the header.
    WrongPortKind,
    /// More than one socket matched a wildcard (port 0) message.
    AnyPortNotUnique,
    /// A broadcast (port 255) message carried no key to match sockets against.
    AllPortMissingKey,
}
69
70// ---- impl NetStack ----
71
// NOTE: the `Arc` handle impl lives further down, behind the `std` feature.
73impl<R, P> NetStackHandle for &'_ NetStack<R, P>
74where
75    R: ScopedRawMutex,
76    P: Profile,
77{
78    type Mutex = R;
79    type Profile = P;
80    type Target = Self;
81
82    fn stack(&self) -> Self::Target {
83        self
84    }
85}
86
#[cfg(feature = "std")]
impl<R, P> NetStackHandle for std::sync::Arc<NetStack<R, P>>
where
    R: ScopedRawMutex,
    P: Profile,
{
    type Target = Self;
    type Mutex = R;
    type Profile = P;

    /// Produce another owning handle by bumping the reference count.
    fn stack(&self) -> Self::Target {
        std::sync::Arc::clone(self)
    }
}
101
102impl<R, P> NetStack<R, P>
103where
104    R: ScopedRawMutex + ConstInit,
105    P: Profile + interface_manager::ConstInit,
106{
107    /// Create a new, uninitialized [`NetStack`].
108    ///
109    /// Requires that the [`ScopedRawMutex`] implements the [`mutex::ConstInit`]
110    /// trait, and the [`Profile`] implements the
111    /// [`interface_manager::ConstInit`] trait.
112    ///
113    /// ## Example
114    ///
115    /// ```rust
116    /// use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
117    /// use ergot_base::NetStack;
118    /// use ergot_base::interface_manager::profiles::null::Null;
119    ///
120    /// static STACK: NetStack<CSRMutex, Null> = NetStack::new();
121    /// ```
122    pub const fn new() -> Self {
123        Self {
124            inner: BlockingMutex::new(NetStackInner::new()),
125        }
126    }
127}
128
129impl<R, P> NetStack<R, P>
130where
131    R: ScopedRawMutex + ConstInit,
132    P: Profile,
133{
134    pub const fn new_with_profile(p: P) -> Self {
135        Self {
136            inner: BlockingMutex::new(NetStackInner::new_with_profile(p)),
137        }
138    }
139}
140
#[cfg(feature = "std")]
impl<R, P> NetStack<R, P>
where
    R: ScopedRawMutex + ConstInit,
    P: Profile,
{
    /// Build an empty [`NetStack`] around the given [`Profile`] and place it
    /// in an [`Arc`](std::sync::Arc).
    pub fn new_arc(p: P) -> std::sync::Arc<Self> {
        let stack = Self::new_with_profile(p);
        std::sync::Arc::new(stack)
    }
}
153
154impl<R, P> NetStack<R, P>
155where
156    R: ScopedRawMutex,
157    P: Profile,
158{
159    /// Manually create a new, uninitialized [`NetStack`].
160    ///
161    /// This method is useful if your [`ScopedRawMutex`] or [`Profile`]
162    /// do not implement their corresponding `ConstInit` trait.
163    ///
164    /// In general, this is most often only needed for `loom` testing, and
165    /// [`NetStack::new()`] should be used when possible.
166    pub const fn const_new(r: R, p: P) -> Self {
167        Self {
168            inner: BlockingMutex::const_new(
169                r,
170                NetStackInner {
171                    sockets: List::new(),
172                    profile: p,
173                    seq_no: 0,
174                    pcache_start: 0,
175                    pcache_bits: 0,
176                },
177            ),
178        }
179    }
180
181    /// Access the contained [`Profile`].
182    ///
183    /// Access to the [`Profile`] is made via the provided closure.
184    /// The [`BlockingMutex`] is locked for the duration of this access,
185    /// inhibiting all other usage of this [`NetStack`].
186    ///
187    /// This can be used to add new interfaces, obtain metadata, or other
188    /// actions supported by the chosen [`Profile`].
189    ///
190    /// ## Example
191    ///
192    /// ```rust
193    /// # use mutex::raw_impls::cs::CriticalSectionRawMutex as CSRMutex;
194    /// # use ergot_base::NetStack;
195    /// # use ergot_base::interface_manager::profiles::null::Null;
196    /// #
197    /// static STACK: NetStack<CSRMutex, Null> = NetStack::new();
198    ///
199    /// let res = STACK.manage_profile(|im| {
200    ///    // The mutex is locked for the full duration of this closure.
201    ///    # _ = im;
202    ///    // We can return whatever we want from this context, though not
203    ///    // anything borrowed from `im`.
204    ///    42
205    /// });
206    /// assert_eq!(res, 42);
207    /// ```
208    pub fn manage_profile<F: FnOnce(&mut P) -> U, U>(&self, f: F) -> U {
209        self.inner.with_lock(|inner| f(&mut inner.profile))
210    }
211
212    /// Send a raw (pre-serialized) message.
213    ///
214    /// This interface should almost never be used by end-users, and is instead
215    /// typically used by interfaces to feed received messages into the
216    /// [`NetStack`].
217    pub fn send_raw(
218        &self,
219        hdr: &Header,
220        hdr_raw: &[u8],
221        body: &[u8],
222    ) -> Result<(), NetStackSendError> {
223        self.inner
224            .with_lock(|inner| inner.send_raw(hdr, hdr_raw, body))
225    }
226
227    /// Send a typed message
228    pub fn send_ty<T: 'static + Serialize + Clone>(
229        &self,
230        hdr: &Header,
231        t: &T,
232    ) -> Result<(), NetStackSendError> {
233        self.inner.with_lock(|inner| inner.send_ty(hdr, t))
234    }
235
236    pub fn send_bor<T: Serialize>(&self, hdr: &Header, t: &T) -> Result<(), NetStackSendError> {
237        self.inner.with_lock(|inner| inner.send_bor(hdr, t))
238    }
239
240    pub fn send_err(&self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
241        self.inner.with_lock(|inner| inner.send_err(hdr, err))
242    }
243
244    pub(crate) unsafe fn try_attach_socket(&self, mut node: NonNull<SocketHeader>) -> Option<u8> {
245        self.inner.with_lock(|inner| {
246            let new_port = inner.alloc_port()?;
247            unsafe {
248                node.as_mut().port = new_port;
249            }
250
251            inner.sockets.push_front(node);
252            Some(new_port)
253        })
254    }
255
256    pub(crate) unsafe fn attach_broadcast_socket(&self, mut node: NonNull<SocketHeader>) {
257        self.inner.with_lock(|inner| {
258            unsafe {
259                node.as_mut().port = 255;
260            }
261            inner.sockets.push_back(node);
262        });
263    }
264
265    pub(crate) unsafe fn attach_socket(&self, node: NonNull<SocketHeader>) -> u8 {
266        let res = unsafe { self.try_attach_socket(node) };
267        let Some(new_port) = res else {
268            panic!("exhausted all addrs");
269        };
270        new_port
271    }
272
273    pub(crate) unsafe fn detach_socket(&self, node: NonNull<SocketHeader>) {
274        self.inner.with_lock(|inner| unsafe {
275            let port = node.as_ref().port;
276            if port != 255 {
277                inner.free_port(port);
278            }
279            inner.sockets.remove(node)
280        });
281    }
282
283    pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&self, f: F) -> U {
284        self.inner.with_lock(|_inner| f())
285    }
286}
287
288impl<R, P> Default for NetStack<R, P>
289where
290    R: ScopedRawMutex + ConstInit,
291    P: Profile + interface_manager::ConstInit,
292{
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
298// ---- impl NetStackInner ----
299
300impl<P> NetStackInner<P>
301where
302    P: Profile,
303    P: interface_manager::ConstInit,
304{
305    pub const fn new() -> Self {
306        Self {
307            sockets: List::new(),
308            profile: P::INIT,
309            seq_no: 0,
310            pcache_bits: 0,
311            pcache_start: 0,
312        }
313    }
314}
315
316impl<P> NetStackInner<P>
317where
318    P: Profile,
319{
320    /// Create a netstack with a given profile
321    pub const fn new_with_profile(p: P) -> Self {
322        Self {
323            sockets: List::new(),
324            profile: p,
325            seq_no: 0,
326            pcache_bits: 0,
327            pcache_start: 0,
328        }
329    }
330
331    /// Method that handles broadcast logic
332    ///
333    /// Takes closures for sending to a socket or sending to the manager to allow
334    /// for abstracting over send_raw/send_ty.
335    fn broadcast<SendSocket, SendMgr>(
336        sockets: &mut List<SocketHeader>,
337        hdr: &Header,
338        mut sskt: SendSocket,
339        smgr: SendMgr,
340    ) -> Result<(), NetStackSendError>
341    where
342        SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
343        SendMgr: FnOnce() -> bool,
344    {
345        trace!("Sending msg broadcast w/ header: {hdr:?}");
346        let res_lcl = {
347            let bcast_iter = Self::find_all_local(sockets, hdr)?;
348            let mut any_found = false;
349            for dst in bcast_iter {
350                let res = sskt(dst);
351                if res {
352                    debug!("delivered broadcast message locally");
353                }
354                any_found |= res;
355            }
356            any_found
357        };
358
359        let res_rmt = smgr();
360        if res_rmt {
361            debug!("delivered broadcast message remotely");
362        }
363
364        if res_lcl || res_rmt {
365            Ok(())
366        } else {
367            Err(NetStackSendError::NoRoute)
368        }
369    }
370
371    /// Method that handles unicast logic
372    ///
373    /// Takes closures for sending to a socket or sending to the manager to allow
374    /// for abstracting over send_raw/send_ty.
375    fn unicast<SendSocket, SendMgr>(
376        sockets: &mut List<SocketHeader>,
377        hdr: &Header,
378        sskt: SendSocket,
379        smgr: SendMgr,
380    ) -> Result<(), NetStackSendError>
381    where
382        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
383        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
384    {
385        trace!("Sending msg unicast w/ header: {hdr:?}");
386        // Can we assume the destination is local?
387        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
388
389        let res = if !local_bypass {
390            // Not local: offer to the interface manager to send
391            debug!("Offering msg externally unicast w/ header: {hdr:?}");
392            smgr()
393        } else {
394            // just skip to local sending
395            Err(InterfaceSendError::DestinationLocal)
396        };
397
398        match res {
399            Ok(()) => {
400                debug!("Externally routed msg unicast");
401                return Ok(());
402            }
403            Err(InterfaceSendError::DestinationLocal) => {
404                debug!("No external interest in msg unicast");
405            }
406            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
407        }
408
409        // It was a destination local error, try to honor that
410        let socket = if hdr.dst.port_id == 0 {
411            debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
412            Self::find_any_local(sockets, hdr)
413        } else {
414            debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
415            Self::find_one_local(sockets, hdr)
416        }?;
417
418        sskt(socket)
419    }
420
421    /// Method that handles unicast logic
422    ///
423    /// Takes closures for sending to a socket or sending to the manager to allow
424    /// for abstracting over send_raw/send_ty.
425    fn unicast_err<SendSocket, SendMgr>(
426        sockets: &mut List<SocketHeader>,
427        hdr: &Header,
428        sskt: SendSocket,
429        smgr: SendMgr,
430    ) -> Result<(), NetStackSendError>
431    where
432        SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
433        SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
434    {
435        trace!("Sending err unicast w/ header: {hdr:?}");
436        // Can we assume the destination is local?
437        let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
438
439        let res = if !local_bypass {
440            // Not local: offer to the interface manager to send
441            debug!("Offering err externally unicast w/ header: {hdr:?}");
442            smgr()
443        } else {
444            // just skip to local sending
445            Err(InterfaceSendError::DestinationLocal)
446        };
447
448        match res {
449            Ok(()) => {
450                debug!("Externally routed err unicast");
451                return Ok(());
452            }
453            Err(InterfaceSendError::DestinationLocal) => {
454                debug!("No external interest in err unicast");
455            }
456            Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
457        }
458
459        // It was a destination local error, try to honor that
460        let socket = Self::find_one_err_local(sockets, hdr)?;
461
462        sskt(socket)
463    }
464
465    /// Handle sending of a raw (serialized) message
466    fn send_raw(
467        &mut self,
468        hdr: &Header,
469        hdr_raw: &[u8],
470        body: &[u8],
471    ) -> Result<(), NetStackSendError> {
472        let Self {
473            sockets,
474            seq_no,
475            profile: manager,
476            ..
477        } = self;
478        trace!("Sending msg raw w/ header: {hdr:?}");
479
480        if hdr.kind == FrameKind::PROTOCOL_ERROR {
481            todo!("Don't do that");
482        }
483
484        // Is this a broadcast message?
485        if hdr.dst.port_id == 255 {
486            Self::broadcast(
487                sockets,
488                hdr,
489                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no).is_ok(),
490                || manager.send_raw(hdr, hdr_raw, body).is_ok(),
491            )
492        } else {
493            Self::unicast(
494                sockets,
495                hdr,
496                |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no),
497                || manager.send_raw(hdr, hdr_raw, body),
498            )
499        }
500    }
501
502    /// Handle sending of a typed message
503    fn send_ty<T: 'static + Serialize + Clone>(
504        &mut self,
505        hdr: &Header,
506        t: &T,
507    ) -> Result<(), NetStackSendError> {
508        let Self {
509            sockets,
510            seq_no,
511            profile: manager,
512            ..
513        } = self;
514        trace!("Sending msg ty w/ header: {hdr:?}");
515
516        if hdr.kind == FrameKind::PROTOCOL_ERROR {
517            todo!("Don't do that");
518        }
519
520        // Is this a broadcast message?
521        if hdr.dst.port_id == 255 {
522            Self::broadcast(
523                sockets,
524                hdr,
525                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
526                || manager.send(hdr, t).is_ok(),
527            )
528        } else {
529            Self::unicast(
530                sockets,
531                hdr,
532                |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
533                || manager.send(hdr, t),
534            )
535        }
536    }
537
538    /// Handle sending a borrowed message
539    fn send_bor<T: Serialize>(&mut self, hdr: &Header, t: &T) -> Result<(), NetStackSendError> {
540        let Self {
541            sockets,
542            seq_no,
543            profile: manager,
544            ..
545        } = self;
546        trace!("Sending msg bor w/ header: {hdr:?}");
547
548        if hdr.kind == FrameKind::PROTOCOL_ERROR {
549            todo!("Don't do that");
550        }
551
552        // Is this a broadcast message?
553        if hdr.dst.port_id == 255 {
554            Self::broadcast(
555                sockets,
556                hdr,
557                |skt| Self::send_bor_to_socket(skt, t, hdr, seq_no).is_ok(),
558                || manager.send(hdr, t).is_ok(),
559            )
560        } else {
561            Self::unicast(
562                sockets,
563                hdr,
564                |skt| Self::send_bor_to_socket(skt, t, hdr, seq_no),
565                || manager.send(hdr, t),
566            )
567        }
568    }
569
570    /// Handle sending of a typed message
571    fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
572        let Self {
573            sockets,
574            seq_no,
575            profile: manager,
576            ..
577        } = self;
578        trace!("Sending msg err w/ header: {hdr:?}");
579
580        if hdr.dst.port_id == 255 {
581            todo!("Don't do that");
582        }
583
584        Self::unicast_err(
585            sockets,
586            hdr,
587            |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
588            || manager.send_err(hdr, err),
589        )
590    }
591
592    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
593    /// the given header.
594    fn find_one_local(
595        sockets: &mut List<SocketHeader>,
596        hdr: &Header,
597    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
598        // Find the specific matching port
599        let mut iter = sockets.iter_raw();
600        let socket = loop {
601            let Some(skt) = iter.next() else {
602                return Err(NetStackSendError::NoRoute);
603            };
604            let skt_ref = unsafe { skt.as_ref() };
605            if skt_ref.port != hdr.dst.port_id {
606                continue;
607            }
608            if skt_ref.attrs.kind != hdr.kind {
609                return Err(NetStackSendError::WrongPortKind);
610            }
611            break skt;
612        };
613        Ok(socket)
614    }
615
616    /// Find a specific (e.g. port_id not 0 or 255) destination port matching
617    /// the given header.
618    fn find_one_err_local(
619        sockets: &mut List<SocketHeader>,
620        hdr: &Header,
621    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
622        // Find the specific matching port
623        let mut iter = sockets.iter_raw();
624        let socket = loop {
625            let Some(skt) = iter.next() else {
626                return Err(NetStackSendError::NoRoute);
627            };
628            let skt_ref = unsafe { skt.as_ref() };
629            if skt_ref.port != hdr.dst.port_id {
630                continue;
631            }
632            break skt;
633        };
634        Ok(socket)
635    }
636
637    /// Find a wildcard (e.g. port_id == 0) destination port matching the given header.
638    ///
639    /// If more than one port matches the wildcard, an error is returned.
640    /// Does not match sockets that does not have the `discoverable` [`Attributes`].
641    fn find_any_local(
642        sockets: &mut List<SocketHeader>,
643        hdr: &Header,
644    ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
645        // Find ONE specific matching port
646        let Some(apdx) = hdr.any_all.as_ref() else {
647            return Err(NetStackSendError::AnyPortMissingKey);
648        };
649        let mut iter = sockets.iter_raw();
650        let mut socket: Option<NonNull<SocketHeader>> = None;
651
652        loop {
653            let Some(skt) = iter.next() else {
654                break;
655            };
656            let skt_ref = unsafe { skt.as_ref() };
657
658            // Check for things that would disqualify a socket from being an
659            // "ANY" destination
660            let mut illegal = false;
661            illegal |= skt_ref.attrs.kind != hdr.kind;
662            illegal |= !skt_ref.attrs.discoverable;
663            illegal |= skt_ref.key != apdx.key;
664            if let Some(nash) = apdx.nash {
665                illegal |= Some(nash) != skt_ref.nash;
666            }
667
668            if illegal {
669                // Wait, that's illegal
670                continue;
671            }
672
673            // It's a match! Is it a second match?
674            if socket.is_some() {
675                return Err(NetStackSendError::AnyPortNotUnique);
676            }
677            // Nope! Store this one, then we keep going to ensure that no
678            // other socket matches this description.
679            socket = Some(skt);
680        }
681
682        socket.ok_or(NetStackSendError::NoRoute)
683    }
684
685    /// Find ALL broadcast (e.g. port_id == 255) sockets matching the given header.
686    ///
687    /// Returns an error if the header does not contain a Key. May return zero
688    /// matches.
689    fn find_all_local(
690        sockets: &mut List<SocketHeader>,
691        hdr: &Header,
692    ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
693        let Some(any_all) = hdr.any_all.as_ref() else {
694            return Err(NetStackSendError::AllPortMissingKey);
695        };
696        Ok(sockets.iter_raw().filter(move |socket| {
697            let skt_ref = unsafe { socket.as_ref() };
698            let bport = skt_ref.port == 255;
699            let dkind = skt_ref.attrs.kind == hdr.kind;
700            let dkey = skt_ref.key == any_all.key;
701
702            // If the any/all message DOES contain a name hash, then ONLY match
703            // sockets with the same name hash.
704            let name = if let Some(nash) = any_all.nash {
705                Some(nash) == skt_ref.nash
706            } else {
707                true
708            };
709            bport && dkind && dkey && name
710        }))
711    }
712
713    /// Helper method for sending a type to a given socket
714    fn send_ty_to_socket<T: 'static + Serialize + Clone>(
715        this: NonNull<SocketHeader>,
716        t: &T,
717        hdr: &Header,
718        seq_no: &mut u16,
719    ) -> Result<(), NetStackSendError> {
720        let vtable: &'static SocketVTable = {
721            let skt_ref = unsafe { this.as_ref() };
722            skt_ref.vtable
723        };
724
725        if let Some(f) = vtable.recv_owned {
726            let this: NonNull<()> = this.cast();
727            let that: NonNull<T> = NonNull::from(t);
728            let that: NonNull<()> = that.cast();
729            let hdr = hdr.to_headerseq_or_with_seq(|| {
730                let seq = *seq_no;
731                *seq_no = seq_no.wrapping_add(1);
732                seq
733            });
734            (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
735        } else if let Some(_f) = vtable.recv_bor {
736            // TODO: support send borrowed
737            todo!()
738        } else {
739            // todo: keep going? If we found the "right" destination and
740            // sending fails, then there's not much we can do. Probably: there
741            // is no case where a socket has NEITHER send_owned NOR send_bor,
742            // can we make this state impossible instead?
743            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
744        }
745    }
746
747    /// Helper method for sending a type to a given socket
748    fn send_bor_to_socket<T: Serialize>(
749        this: NonNull<SocketHeader>,
750        t: &T,
751        hdr: &Header,
752        seq_no: &mut u16,
753    ) -> Result<(), NetStackSendError> {
754        let vtable: &'static SocketVTable = {
755            let skt_ref = unsafe { this.as_ref() };
756            skt_ref.vtable
757        };
758
759        if let Some(f) = vtable.recv_bor {
760            let this: NonNull<()> = this.cast();
761            let that: NonNull<T> = NonNull::from(t);
762            let that: NonNull<()> = that.cast();
763            let hdr = hdr.to_headerseq_or_with_seq(|| {
764                let seq = *seq_no;
765                *seq_no = seq_no.wrapping_add(1);
766                seq
767            });
768            let func = borser::<T>;
769            let res = (f)(this, that, hdr, func).map_err(NetStackSendError::SocketSend);
770            res
771        } else {
772            // todo: keep going? If we found the "right" destination and
773            // sending fails, then there's not much we can do. Probably: there
774            // is no case where a socket has NEITHER send_owned NOR send_bor,
775            // can we make this state impossible instead?
776            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
777        }
778    }
779
780    /// Helper method for sending a type to a given socket
781    fn send_err_to_socket(
782        this: NonNull<SocketHeader>,
783        err: ProtocolError,
784        hdr: &Header,
785        seq_no: &mut u16,
786    ) -> Result<(), NetStackSendError> {
787        let vtable: &'static SocketVTable = {
788            let skt_ref = unsafe { this.as_ref() };
789            skt_ref.vtable
790        };
791
792        if let Some(f) = vtable.recv_err {
793            let this: NonNull<()> = this.cast();
794            let hdr = hdr.to_headerseq_or_with_seq(|| {
795                let seq = *seq_no;
796                *seq_no = seq_no.wrapping_add(1);
797                seq
798            });
799            (f)(this, hdr, err);
800            Ok(())
801        } else {
802            // todo: keep going? If we found the "right" destination and
803            // sending fails, then there's not much we can do. Probably: there
804            // is no case where a socket has NEITHER send_owned NOR send_bor,
805            // can we make this state impossible instead?
806            Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
807        }
808    }
809
810    /// Helper message for sending a raw message to a given socket
811    fn send_raw_to_socket(
812        this: NonNull<SocketHeader>,
813        body: &[u8],
814        hdr: &Header,
815        hdr_raw: &[u8],
816        seq_no: &mut u16,
817    ) -> Result<(), NetStackSendError> {
818        let vtable: &'static SocketVTable = {
819            let skt_ref = unsafe { this.as_ref() };
820            skt_ref.vtable
821        };
822        let f = vtable.recv_raw;
823
824        let this: NonNull<()> = this.cast();
825        let hdr = hdr.to_headerseq_or_with_seq(|| {
826            let seq = *seq_no;
827            *seq_no = seq_no.wrapping_add(1);
828            seq
829        });
830
831        (f)(this, body, hdr, hdr_raw).map_err(NetStackSendError::SocketSend)
832    }
833}
834
835impl<P> NetStackInner<P>
836where
837    P: Profile,
838{
839    /// Cache-based allocator inspired by littlefs2 ID allocator
840    ///
841    /// We remember 32 ports at a time, from the current base, which is always
842    /// a multiple of 32. Allocating from this range does not require moving thru
843    /// the socket lists.
844    ///
845    /// If the current 32 ports are all taken, we will start over from a base port
846    /// of 0, and attempt to
847    fn alloc_port(&mut self) -> Option<u8> {
848        // ports 0 is always taken (could be clear on first alloc)
849        self.pcache_bits |= (self.pcache_start == 0) as u32;
850
851        if self.pcache_bits != u32::MAX {
852            // We can allocate from the current slot
853            let ldg = self.pcache_bits.trailing_ones();
854            debug_assert!(ldg < 32);
855            self.pcache_bits |= 1 << ldg;
856            return Some(self.pcache_start + (ldg as u8));
857        }
858
859        // Nope, cache is all taken. try to find a base with available items.
860        // We always start from the bottom to keep ports small, but if we know
861        // we just exhausted a range, don't waste time checking that
862        let old_start = self.pcache_start;
863        for base in 0..8 {
864            let start = base * 32;
865            if start == old_start {
866                continue;
867            }
868            // Clear/reset cache
869            self.pcache_start = start;
870            self.pcache_bits = 0;
871            // port 0 is not allowed
872            self.pcache_bits |= (self.pcache_start == 0) as u32;
873            // port 255 is not allowed
874            self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
875
876            // TODO: If we trust that sockets are always sorted, we could early-return
877            // when we reach a `pupper > self.pcache_start`. We could also maybe be smart
878            // and iterate forwards for 0..4 and backwards for 4..8 (and switch the early
879            // return check to < instead). NOTE: We currently do NOT guarantee sockets are
880            // sorted!
881            self.sockets.iter().for_each(|s| {
882                if s.port == 255 {
883                    return;
884                }
885
886                // The upper 3 bits of the port
887                let pupper = s.port & !(32 - 1);
888                // The lower 5 bits of the port
889                let plower = s.port & (32 - 1);
890
891                if pupper == self.pcache_start {
892                    self.pcache_bits |= 1 << plower;
893                }
894            });
895
896            if self.pcache_bits != u32::MAX {
897                // We can allocate from the current slot
898                let ldg = self.pcache_bits.trailing_ones();
899                debug_assert!(ldg < 32);
900                self.pcache_bits |= 1 << ldg;
901                return Some(self.pcache_start + (ldg as u8));
902            }
903        }
904
905        // Nope, nothing found
906        None
907    }
908
909    fn free_port(&mut self, port: u8) {
910        debug_assert!(port != 255);
911        // The upper 3 bits of the port
912        let pupper = port & !(32 - 1);
913        // The lower 5 bits of the port
914        let plower = port & (32 - 1);
915
916        // TODO: If the freed port is in the 0..32 range, or just less than
917        // the current start range, maybe do an opportunistic re-look?
918        if pupper == self.pcache_start {
919            self.pcache_bits &= !(1 << plower);
920        }
921    }
922}
923
924impl NetStackSendError {
925    pub fn to_error(&self) -> ProtocolError {
926        match self {
927            NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
928            NetStackSendError::InterfaceSend(interface_send_error) => {
929                interface_send_error.to_error()
930            }
931            NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
932            NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
933            NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
934            NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
935            NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
936        }
937    }
938}
939
#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack,
        interface_manager::profiles::null::Null,
        socket::{Attributes, owned::single::Socket},
    };

    /// Walks the port allocator through filling, freeing and re-filling ports,
    /// checking the exact port number handed to every newly attached socket,
    /// and finally checks that attaching one socket too many fails.
    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, Null> = NetStack::new();

        // One live socket: its port, the thread keeping it attached, and the
        // sender used to tell that thread to finish.
        type Live = (u8, JoinHandle<()>, oneshot::Sender<()>);

        // Attaches a socket on its own thread, asserting (on that thread) that
        // it received port `id`. Returns once the thread has either signalled
        // that the socket is attached or gone away.
        fn spawn_skt(id: u8) -> Live {
            let (txdone, rxdone) = oneshot::channel::<()>();
            let (txwait, rxwait) = oneshot::channel::<()>();
            let worker = std::thread::spawn(move || {
                let skt = pin!(Socket::<u64, &_>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    Attributes {
                        kind: FrameKind::ENDPOINT_REQ,
                        discoverable: true,
                    },
                    None,
                ));
                let attached = skt.attach();
                assert_eq!(attached.port(), id);
                txwait.send(()).unwrap();
                // Keep the socket attached until the test asks us to stop.
                rxdone.blocking_recv().unwrap();
            });
            let _ = rxwait.blocking_recv();
            (id, worker, txdone)
        }

        // Stops the thread holding `port` and waits for it to exit, which
        // releases the socket it was keeping attached.
        fn drop_skt(live: &mut Vec<Live>, port: u8) {
            let idx = live.iter().position(|(id, _, _)| *id == port).unwrap();
            let (_, worker, txdone) = live.remove(idx);
            txdone.send(()).unwrap();
            worker.join().unwrap();
        }

        let mut v: Vec<Live> = Vec::new();

        // make sockets 1..32
        v.extend((1..32).map(spawn_skt));

        // make sockets 32..40
        v.extend((32..40).map(spawn_skt));

        // drop socket 35, then make a new socket: it should be 35
        drop_skt(&mut v, 35);
        v.push(spawn_skt(35));

        // drop socket 4, then make a new socket: it should be 40
        drop_skt(&mut v, 4);
        v.push(spawn_skt(40));

        // make sockets 41..64
        v.extend((41..64).map(spawn_skt));

        // make a new socket, it should be 4
        v.push(spawn_skt(4));

        // make sockets 64..255
        v.extend((64..255).map(spawn_skt));

        // drop socket 212, then make a new socket: it should be 212
        drop_skt(&mut v, 212);
        v.push(spawn_skt(212));

        // Sockets exhausted (we never see 255): the attaching thread is
        // expected to panic, which surfaces as an `Err` from `join`.
        let overflow = std::thread::spawn(move || {
            let skt = pin!(Socket::<u64, &_>::new(
                &STACK,
                Key(*b"TEST1234"),
                Attributes {
                    kind: FrameKind::ENDPOINT_REQ,
                    discoverable: true,
                },
                None,
            ));
            let attached = skt.attach();
            println!("{}", attached.port());
        });
        assert!(overflow.join().is_err());

        // Shut down every remaining socket thread.
        for (_, worker, txdone) in v.drain(..) {
            txdone.send(()).unwrap();
            worker.join().unwrap();
        }
    }
}