1use core::{any::TypeId, ops::Deref, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29 FrameKind, Header, ProtocolError,
30 interface_manager::{self, InterfaceManager, InterfaceSendError},
31 socket::{SocketHeader, SocketSendError, SocketVTable},
32};
33
/// A network stack: the locally attached sockets plus an interface manager,
/// all kept behind a single blocking mutex of raw type `R`.
pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
    /// All mutable stack state; every method on `NetStack` accesses it via
    /// `with_lock`.
    inner: BlockingMutex<R, NetStackInner<M>>,
}
38
/// A type that can hand out a handle to a [`NetStack`].
///
/// The handle (`Target`) must be `Clone` and must dereference to the stack.
/// `Mutex` and `Interface` name the stack's raw mutex and interface manager
/// type parameters. For `&NetStack` (implemented in this file) the handle is
/// the reference itself.
pub trait NetStackHandle
where
    Self: Sized,
{
    /// Cloneable handle that dereferences to the underlying stack.
    type Target: Deref<Target = NetStack<Self::Mutex, Self::Interface>> + Clone;
    /// Raw mutex type guarding the stack.
    type Mutex: ScopedRawMutex;
    /// Interface manager type used by the stack.
    type Interface: InterfaceManager;
    /// Obtain a handle to the stack.
    fn stack(&self) -> Self::Target;
}
48
/// Lock-protected state of a [`NetStack`].
pub(crate) struct NetStackInner<M: InterfaceManager> {
    /// Intrusive list of attached sockets. Sockets with an allocated port are
    /// pushed to the front; broadcast sockets (port 255) to the back.
    sockets: List<SocketHeader>,
    /// Interface manager used to send frames that are not delivered locally.
    manager: M,
    /// Port allocation cache: bit `n` set means port `pcache_start + n` is in use.
    pcache_bits: u32,
    /// First port (always a multiple of 32) of the window cached in `pcache_bits`.
    pcache_start: u8,
    /// Wrapping counter passed to `Header::to_headerseq_or_with_seq` when
    /// delivering to local sockets (presumably only consumed when the header
    /// has no sequence number of its own — confirm).
    seq_no: u16,
}
56
/// Errors that can occur when sending through a [`NetStack`].
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum NetStackSendError {
    /// The local destination socket rejected the message.
    SocketSend(SocketSendError),
    /// The interface manager failed with an error other than "destination is local".
    InterfaceSend(InterfaceSendError),
    /// No local socket (and, for broadcasts, no interface) accepted the message.
    NoRoute,
    /// An ANY-port (port 0) unicast was sent without an `any_all` key appendix.
    AnyPortMissingKey,
    /// The socket bound to the destination port expects a different frame kind.
    WrongPortKind,
    /// More than one local socket matched an ANY-port unicast.
    AnyPortNotUnique,
    /// A broadcast (port 255) was sent without an `any_all` key appendix.
    AllPortMissingKey,
}
69
70impl<R, M> NetStackHandle for &'_ NetStack<R, M>
74where
75 R: ScopedRawMutex,
76 M: InterfaceManager,
77{
78 type Mutex = R;
79 type Interface = M;
80 type Target = Self;
81
82 fn stack(&self) -> Self::Target {
83 self
84 }
85}
86
87impl<R, M> NetStack<R, M>
88where
89 R: ScopedRawMutex + ConstInit,
90 M: InterfaceManager + interface_manager::ConstInit,
91{
92 pub const fn new() -> Self {
108 Self {
109 inner: BlockingMutex::new(NetStackInner::new()),
110 }
111 }
112}
113
114impl<R, M> NetStack<R, M>
115where
116 R: ScopedRawMutex,
117 M: InterfaceManager,
118{
119 pub const fn const_new(r: R, m: M) -> Self {
127 Self {
128 inner: BlockingMutex::const_new(
129 r,
130 NetStackInner {
131 sockets: List::new(),
132 manager: m,
133 seq_no: 0,
134 pcache_start: 0,
135 pcache_bits: 0,
136 },
137 ),
138 }
139 }
140
141 pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&self, f: F) -> U {
169 self.inner.with_lock(|inner| f(&mut inner.manager))
170 }
171
172 pub fn send_raw(
178 &self,
179 hdr: &Header,
180 hdr_raw: &[u8],
181 body: &[u8],
182 ) -> Result<(), NetStackSendError> {
183 self.inner
184 .with_lock(|inner| inner.send_raw(hdr, hdr_raw, body))
185 }
186
187 pub fn send_ty<T: 'static + Serialize + Clone>(
189 &self,
190 hdr: &Header,
191 t: &T,
192 ) -> Result<(), NetStackSendError> {
193 self.inner.with_lock(|inner| inner.send_ty(hdr, t))
194 }
195
196 pub fn send_err(&self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
197 self.inner.with_lock(|inner| inner.send_err(hdr, err))
198 }
199
200 pub(crate) unsafe fn try_attach_socket(&self, mut node: NonNull<SocketHeader>) -> Option<u8> {
201 self.inner.with_lock(|inner| {
202 let new_port = inner.alloc_port()?;
203 unsafe {
204 node.as_mut().port = new_port;
205 }
206
207 inner.sockets.push_front(node);
208 Some(new_port)
209 })
210 }
211
212 pub(crate) unsafe fn attach_broadcast_socket(&self, mut node: NonNull<SocketHeader>) {
213 self.inner.with_lock(|inner| {
214 unsafe {
215 node.as_mut().port = 255;
216 }
217 inner.sockets.push_back(node);
218 });
219 }
220
221 pub(crate) unsafe fn attach_socket(&self, node: NonNull<SocketHeader>) -> u8 {
222 let res = unsafe { self.try_attach_socket(node) };
223 let Some(new_port) = res else {
224 panic!("exhausted all addrs");
225 };
226 new_port
227 }
228
229 pub(crate) unsafe fn detach_socket(&self, node: NonNull<SocketHeader>) {
230 self.inner.with_lock(|inner| unsafe {
231 let port = node.as_ref().port;
232 if port != 255 {
233 inner.free_port(port);
234 }
235 inner.sockets.remove(node)
236 });
237 }
238
239 pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&self, f: F) -> U {
240 self.inner.with_lock(|_inner| f())
241 }
242}
243
244impl<R, M> Default for NetStack<R, M>
245where
246 R: ScopedRawMutex + ConstInit,
247 M: InterfaceManager + interface_manager::ConstInit,
248{
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254impl<M> NetStackInner<M>
257where
258 M: InterfaceManager,
259 M: interface_manager::ConstInit,
260{
261 pub const fn new() -> Self {
262 Self {
263 sockets: List::new(),
264 manager: M::INIT,
265 seq_no: 0,
266 pcache_bits: 0,
267 pcache_start: 0,
268 }
269 }
270}
271
272impl<M> NetStackInner<M>
273where
274 M: InterfaceManager,
275{
276 fn broadcast<SendSocket, SendMgr>(
281 sockets: &mut List<SocketHeader>,
282 hdr: &Header,
283 mut sskt: SendSocket,
284 smgr: SendMgr,
285 ) -> Result<(), NetStackSendError>
286 where
287 SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
288 SendMgr: FnOnce() -> bool,
289 {
290 trace!("Sending msg broadcast w/ header: {hdr:?}");
291 let res_lcl = {
292 let bcast_iter = Self::find_all_local(sockets, hdr)?;
293 let mut any_found = false;
294 for dst in bcast_iter {
295 let res = sskt(dst);
296 if res {
297 debug!("delivered broadcast message locally");
298 }
299 any_found |= res;
300 }
301 any_found
302 };
303
304 let res_rmt = smgr();
305 if res_rmt {
306 debug!("delivered broadcast message remotely");
307 }
308
309 if res_lcl || res_rmt {
310 Ok(())
311 } else {
312 Err(NetStackSendError::NoRoute)
313 }
314 }
315
316 fn unicast<SendSocket, SendMgr>(
321 sockets: &mut List<SocketHeader>,
322 hdr: &Header,
323 sskt: SendSocket,
324 smgr: SendMgr,
325 ) -> Result<(), NetStackSendError>
326 where
327 SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
328 SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
329 {
330 trace!("Sending msg unicast w/ header: {hdr:?}");
331 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
333
334 let res = if !local_bypass {
335 debug!("Offering msg externally unicast w/ header: {hdr:?}");
337 smgr()
338 } else {
339 Err(InterfaceSendError::DestinationLocal)
341 };
342
343 match res {
344 Ok(()) => {
345 debug!("Externally routed msg unicast");
346 return Ok(());
347 }
348 Err(InterfaceSendError::DestinationLocal) => {
349 debug!("No external interest in msg unicast");
350 }
351 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
352 }
353
354 let socket = if hdr.dst.port_id == 0 {
356 debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
357 Self::find_any_local(sockets, hdr)
358 } else {
359 debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
360 Self::find_one_local(sockets, hdr)
361 }?;
362
363 sskt(socket)
364 }
365
366 fn unicast_err<SendSocket, SendMgr>(
371 sockets: &mut List<SocketHeader>,
372 hdr: &Header,
373 sskt: SendSocket,
374 smgr: SendMgr,
375 ) -> Result<(), NetStackSendError>
376 where
377 SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
378 SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
379 {
380 trace!("Sending err unicast w/ header: {hdr:?}");
381 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
383
384 let res = if !local_bypass {
385 debug!("Offering err externally unicast w/ header: {hdr:?}");
387 smgr()
388 } else {
389 Err(InterfaceSendError::DestinationLocal)
391 };
392
393 match res {
394 Ok(()) => {
395 debug!("Externally routed err unicast");
396 return Ok(());
397 }
398 Err(InterfaceSendError::DestinationLocal) => {
399 debug!("No external interest in err unicast");
400 }
401 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
402 }
403
404 let socket = Self::find_one_err_local(sockets, hdr)?;
406
407 sskt(socket)
408 }
409
410 fn send_raw(
412 &mut self,
413 hdr: &Header,
414 hdr_raw: &[u8],
415 body: &[u8],
416 ) -> Result<(), NetStackSendError> {
417 let Self {
418 sockets,
419 seq_no,
420 manager,
421 ..
422 } = self;
423 trace!("Sending msg raw w/ header: {hdr:?}");
424
425 if hdr.kind == FrameKind::PROTOCOL_ERROR {
426 todo!("Don't do that");
427 }
428
429 if hdr.dst.port_id == 255 {
431 Self::broadcast(
432 sockets,
433 hdr,
434 |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no).is_ok(),
435 || manager.send_raw(hdr, hdr_raw, body).is_ok(),
436 )
437 } else {
438 Self::unicast(
439 sockets,
440 hdr,
441 |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no),
442 || manager.send_raw(hdr, hdr_raw, body),
443 )
444 }
445 }
446
447 fn send_ty<T: 'static + Serialize + Clone>(
449 &mut self,
450 hdr: &Header,
451 t: &T,
452 ) -> Result<(), NetStackSendError> {
453 let Self {
454 sockets,
455 seq_no,
456 manager,
457 ..
458 } = self;
459 trace!("Sending msg ty w/ header: {hdr:?}");
460
461 if hdr.kind == FrameKind::PROTOCOL_ERROR {
462 todo!("Don't do that");
463 }
464
465 if hdr.dst.port_id == 255 {
467 Self::broadcast(
468 sockets,
469 hdr,
470 |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
471 || manager.send(hdr, t).is_ok(),
472 )
473 } else {
474 Self::unicast(
475 sockets,
476 hdr,
477 |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
478 || manager.send(hdr, t),
479 )
480 }
481 }
482
483 fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
485 let Self {
486 sockets,
487 seq_no,
488 manager,
489 ..
490 } = self;
491 trace!("Sending msg ty w/ header: {hdr:?}");
492
493 if hdr.dst.port_id == 255 {
494 todo!("Don't do that");
495 }
496
497 Self::unicast_err(
499 sockets,
500 hdr,
501 |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
502 || manager.send_err(hdr, err),
503 )
504 }
505
506 fn find_one_local(
509 sockets: &mut List<SocketHeader>,
510 hdr: &Header,
511 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
512 let mut iter = sockets.iter_raw();
514 let socket = loop {
515 let Some(skt) = iter.next() else {
516 return Err(NetStackSendError::NoRoute);
517 };
518 let skt_ref = unsafe { skt.as_ref() };
519 if skt_ref.port != hdr.dst.port_id {
520 continue;
521 }
522 if skt_ref.attrs.kind != hdr.kind {
523 return Err(NetStackSendError::WrongPortKind);
524 }
525 break skt;
526 };
527 Ok(socket)
528 }
529
530 fn find_one_err_local(
533 sockets: &mut List<SocketHeader>,
534 hdr: &Header,
535 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
536 let mut iter = sockets.iter_raw();
538 let socket = loop {
539 let Some(skt) = iter.next() else {
540 return Err(NetStackSendError::NoRoute);
541 };
542 let skt_ref = unsafe { skt.as_ref() };
543 if skt_ref.port != hdr.dst.port_id {
544 continue;
545 }
546 break skt;
547 };
548 Ok(socket)
549 }
550
551 fn find_any_local(
556 sockets: &mut List<SocketHeader>,
557 hdr: &Header,
558 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
559 let Some(apdx) = hdr.any_all.as_ref() else {
561 return Err(NetStackSendError::AnyPortMissingKey);
562 };
563 let mut iter = sockets.iter_raw();
564 let mut socket: Option<NonNull<SocketHeader>> = None;
565
566 loop {
567 let Some(skt) = iter.next() else {
568 break;
569 };
570 let skt_ref = unsafe { skt.as_ref() };
571
572 let mut illegal = false;
575 illegal |= skt_ref.attrs.kind != hdr.kind;
576 illegal |= !skt_ref.attrs.discoverable;
577 illegal |= skt_ref.key != apdx.key;
578 if let Some(nash) = apdx.nash {
579 illegal |= Some(nash) != skt_ref.nash;
580 }
581
582 if illegal {
583 continue;
585 }
586
587 if socket.is_some() {
589 return Err(NetStackSendError::AnyPortNotUnique);
590 }
591 socket = Some(skt);
594 }
595
596 socket.ok_or(NetStackSendError::NoRoute)
597 }
598
599 fn find_all_local(
604 sockets: &mut List<SocketHeader>,
605 hdr: &Header,
606 ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
607 let Some(any_all) = hdr.any_all.as_ref() else {
608 return Err(NetStackSendError::AllPortMissingKey);
609 };
610 Ok(sockets.iter_raw().filter(move |socket| {
611 let skt_ref = unsafe { socket.as_ref() };
612 let bport = skt_ref.port == 255;
613 let dkind = skt_ref.attrs.kind == hdr.kind;
614 let dkey = skt_ref.key == any_all.key;
615
616 let name = if let Some(nash) = any_all.nash {
619 Some(nash) == skt_ref.nash
620 } else {
621 true
622 };
623 bport && dkind && dkey && name
624 }))
625 }
626
627 fn send_ty_to_socket<T: 'static + Serialize + Clone>(
629 this: NonNull<SocketHeader>,
630 t: &T,
631 hdr: &Header,
632 seq_no: &mut u16,
633 ) -> Result<(), NetStackSendError> {
634 let vtable: &'static SocketVTable = {
635 let skt_ref = unsafe { this.as_ref() };
636 skt_ref.vtable
637 };
638
639 if let Some(f) = vtable.recv_owned {
640 let this: NonNull<()> = this.cast();
641 let that: NonNull<T> = NonNull::from(t);
642 let that: NonNull<()> = that.cast();
643 let hdr = hdr.to_headerseq_or_with_seq(|| {
644 let seq = *seq_no;
645 *seq_no = seq_no.wrapping_add(1);
646 seq
647 });
648 (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
649 } else if let Some(_f) = vtable.recv_bor {
650 todo!()
652 } else {
653 Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
658 }
659 }
660
661 fn send_err_to_socket(
663 this: NonNull<SocketHeader>,
664 err: ProtocolError,
665 hdr: &Header,
666 seq_no: &mut u16,
667 ) -> Result<(), NetStackSendError> {
668 let vtable: &'static SocketVTable = {
669 let skt_ref = unsafe { this.as_ref() };
670 skt_ref.vtable
671 };
672
673 if let Some(f) = vtable.recv_err {
674 let this: NonNull<()> = this.cast();
675 let hdr = hdr.to_headerseq_or_with_seq(|| {
676 let seq = *seq_no;
677 *seq_no = seq_no.wrapping_add(1);
678 seq
679 });
680 (f)(this, hdr, err);
681 Ok(())
682 } else {
683 Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
688 }
689 }
690
691 fn send_raw_to_socket(
716 this: NonNull<SocketHeader>,
717 body: &[u8],
718 hdr: &Header,
719 hdr_raw: &[u8],
720 seq_no: &mut u16,
721 ) -> Result<(), NetStackSendError> {
722 let vtable: &'static SocketVTable = {
723 let skt_ref = unsafe { this.as_ref() };
724 skt_ref.vtable
725 };
726 let f = vtable.recv_raw;
727
728 let this: NonNull<()> = this.cast();
729 let hdr = hdr.to_headerseq_or_with_seq(|| {
730 let seq = *seq_no;
731 *seq_no = seq_no.wrapping_add(1);
732 seq
733 });
734
735 (f)(this, body, hdr, hdr_raw).map_err(NetStackSendError::SocketSend)
736 }
737}
738
739impl<M> NetStackInner<M>
740where
741 M: InterfaceManager,
742{
743 fn alloc_port(&mut self) -> Option<u8> {
752 self.pcache_bits |= (self.pcache_start == 0) as u32;
754
755 if self.pcache_bits != u32::MAX {
756 let ldg = self.pcache_bits.trailing_ones();
758 debug_assert!(ldg < 32);
759 self.pcache_bits |= 1 << ldg;
760 return Some(self.pcache_start + (ldg as u8));
761 }
762
763 let old_start = self.pcache_start;
767 for base in 0..8 {
768 let start = base * 32;
769 if start == old_start {
770 continue;
771 }
772 self.pcache_start = start;
774 self.pcache_bits = 0;
775 self.pcache_bits |= (self.pcache_start == 0) as u32;
777 self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
779
780 self.sockets.iter().for_each(|s| {
786 if s.port == 255 {
787 return;
788 }
789
790 let pupper = s.port & !(32 - 1);
792 let plower = s.port & (32 - 1);
794
795 if pupper == self.pcache_start {
796 self.pcache_bits |= 1 << plower;
797 }
798 });
799
800 if self.pcache_bits != u32::MAX {
801 let ldg = self.pcache_bits.trailing_ones();
803 debug_assert!(ldg < 32);
804 self.pcache_bits |= 1 << ldg;
805 return Some(self.pcache_start + (ldg as u8));
806 }
807 }
808
809 None
811 }
812
813 fn free_port(&mut self, port: u8) {
814 debug_assert!(port != 255);
815 let pupper = port & !(32 - 1);
817 let plower = port & (32 - 1);
819
820 if pupper == self.pcache_start {
823 self.pcache_bits &= !(1 << plower);
824 }
825 }
826}
827
828impl NetStackSendError {
829 pub fn to_error(&self) -> ProtocolError {
830 match self {
831 NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
832 NetStackSendError::InterfaceSend(interface_send_error) => {
833 interface_send_error.to_error()
834 }
835 NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
836 NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
837 NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
838 NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
839 NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
840 }
841 }
842}
843
#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack,
        interface_manager::null::NullInterfaceManager,
        socket::{Attributes, owned::single::Socket},
    };

    /// One socket-holding thread: the port it asserted it received, its join
    /// handle, and the sender that lets it detach and exit.
    type Slot = (u8, JoinHandle<()>, oneshot::Sender<()>);

    /// Exercises the port allocator:
    /// - sequential allocation,
    /// - reuse of freed ports inside and outside the cached 32-port window,
    /// - exhaustion.
    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();

        /// Spawn a thread that attaches a socket to `STACK`, asserts it got
        /// port `id`, and keeps it attached until released.
        ///
        /// Returns once the socket is attached.
        fn spawn_skt(id: u8) -> Slot {
            let (release_tx, release_rx) = oneshot::channel();
            let (ready_tx, ready_rx) = oneshot::channel();
            let handle = std::thread::spawn(move || {
                let skt = Socket::<u64, &_>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    Attributes {
                        kind: FrameKind::ENDPOINT_REQ,
                        discoverable: true,
                    },
                    None,
                );
                let skt = pin!(skt);
                let hdl = skt.attach();
                assert_eq!(hdl.port(), id);
                ready_tx.send(()).unwrap();
                let _: () = release_rx.blocking_recv().unwrap();
            });
            let _ = ready_rx.blocking_recv();
            (id, handle, release_tx)
        }

        /// Release the socket holding `port` and wait for its thread to exit.
        fn retire(slots: &mut Vec<Slot>, port: u8) {
            let idx = slots.iter().position(|(p, _, _)| *p == port).unwrap();
            let (_, handle, release_tx) = slots.remove(idx);
            release_tx.send(()).unwrap();
            handle.join().unwrap();
        }

        let mut slots: Vec<Slot> = Vec::new();

        // Fill window 0 (port 0 is reserved) and the start of window 32.
        slots.extend((1..40).map(spawn_skt));

        // A port freed inside the cached window is handed out again right away.
        retire(&mut slots, 35);
        slots.push(spawn_skt(35));

        // A port freed outside the cached window is only reused once the
        // cached window is exhausted and a rescan finds it.
        retire(&mut slots, 4);
        slots.push(spawn_skt(40));
        slots.extend((41..64).map(spawn_skt));
        slots.push(spawn_skt(4));

        // Take every remaining port (255 is reserved for broadcast).
        slots.extend((64..255).map(spawn_skt));

        // A port freed in a non-cached window becomes available again.
        retire(&mut slots, 212);
        slots.push(spawn_skt(212));

        // With all ports taken, attaching one more socket must panic.
        let overflow = std::thread::spawn(move || {
            let skt = Socket::<u64, &_>::new(
                &STACK,
                Key(*b"TEST1234"),
                Attributes {
                    kind: FrameKind::ENDPOINT_REQ,
                    discoverable: true,
                },
                None,
            );
            let skt = pin!(skt);
            let hdl = skt.attach();
            println!("{}", hdl.port());
        });
        assert!(overflow.join().is_err());

        for (_, handle, release_tx) in slots.drain(..) {
            release_tx.send(()).unwrap();
            handle.join().unwrap();
        }
    }
}