1use core::{any::TypeId, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29 FrameKind, Header, ProtocolError,
30 interface_manager::{self, InterfaceManager, InterfaceSendError},
31 socket::{SocketHeader, SocketSendError, SocketVTable},
32};
33
34pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
36 inner: BlockingMutex<R, NetStackInner<M>>,
37}
38
39pub(crate) struct NetStackInner<M: InterfaceManager> {
40 sockets: List<SocketHeader>,
41 manager: M,
42 pcache_bits: u32,
43 pcache_start: u8,
44 seq_no: u16,
45}
46
47#[derive(Debug, PartialEq, Eq)]
49#[non_exhaustive]
50pub enum NetStackSendError {
51 SocketSend(SocketSendError),
52 InterfaceSend(InterfaceSendError),
53 NoRoute,
54 AnyPortMissingKey,
55 WrongPortKind,
56 AnyPortNotUnique,
57 AllPortMissingKey,
58}
59
60impl<R, M> NetStack<R, M>
63where
64 R: ScopedRawMutex + ConstInit,
65 M: InterfaceManager + interface_manager::ConstInit,
66{
67 pub const fn new() -> Self {
83 Self {
84 inner: BlockingMutex::new(NetStackInner::new()),
85 }
86 }
87}
88
89impl<R, M> NetStack<R, M>
90where
91 R: ScopedRawMutex,
92 M: InterfaceManager,
93{
94 pub const fn const_new(r: R, m: M) -> Self {
102 Self {
103 inner: BlockingMutex::const_new(
104 r,
105 NetStackInner {
106 sockets: List::new(),
107 manager: m,
108 seq_no: 0,
109 pcache_start: 0,
110 pcache_bits: 0,
111 },
112 ),
113 }
114 }
115
116 pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&'static self, f: F) -> U {
144 self.inner.with_lock(|inner| f(&mut inner.manager))
145 }
146
147 pub fn send_raw(
153 &'static self,
154 hdr: &Header,
155 hdr_raw: &[u8],
156 body: &[u8],
157 ) -> Result<(), NetStackSendError> {
158 self.inner
159 .with_lock(|inner| inner.send_raw(hdr, hdr_raw, body))
160 }
161
162 pub fn send_ty<T: 'static + Serialize + Clone>(
164 &'static self,
165 hdr: &Header,
166 t: &T,
167 ) -> Result<(), NetStackSendError> {
168 self.inner.with_lock(|inner| inner.send_ty(hdr, t))
169 }
170
171 pub fn send_err(
172 &'static self,
173 hdr: &Header,
174 err: ProtocolError,
175 ) -> Result<(), NetStackSendError> {
176 self.inner.with_lock(|inner| inner.send_err(hdr, err))
177 }
178
179 pub(crate) unsafe fn try_attach_socket(
180 &'static self,
181 mut node: NonNull<SocketHeader>,
182 ) -> Option<u8> {
183 self.inner.with_lock(|inner| {
184 let new_port = inner.alloc_port()?;
185 unsafe {
186 node.as_mut().port = new_port;
187 }
188
189 inner.sockets.push_front(node);
190 Some(new_port)
191 })
192 }
193
194 pub(crate) unsafe fn attach_broadcast_socket(&'static self, mut node: NonNull<SocketHeader>) {
195 self.inner.with_lock(|inner| {
196 unsafe {
197 node.as_mut().port = 255;
198 }
199 inner.sockets.push_back(node);
200 });
201 }
202
203 pub(crate) unsafe fn attach_socket(&'static self, node: NonNull<SocketHeader>) -> u8 {
204 let res = unsafe { self.try_attach_socket(node) };
205 let Some(new_port) = res else {
206 panic!("exhausted all addrs");
207 };
208 new_port
209 }
210
211 pub(crate) unsafe fn detach_socket(&'static self, node: NonNull<SocketHeader>) {
212 self.inner.with_lock(|inner| unsafe {
213 let port = node.as_ref().port;
214 if port != 255 {
215 inner.free_port(port);
216 }
217 inner.sockets.remove(node)
218 });
219 }
220
221 pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&'static self, f: F) -> U {
222 self.inner.with_lock(|_inner| f())
223 }
224}
225
226impl<R, M> Default for NetStack<R, M>
227where
228 R: ScopedRawMutex + ConstInit,
229 M: InterfaceManager + interface_manager::ConstInit,
230{
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
236impl<M> NetStackInner<M>
239where
240 M: InterfaceManager,
241 M: interface_manager::ConstInit,
242{
243 pub const fn new() -> Self {
244 Self {
245 sockets: List::new(),
246 manager: M::INIT,
247 seq_no: 0,
248 pcache_bits: 0,
249 pcache_start: 0,
250 }
251 }
252}
253
254impl<M> NetStackInner<M>
255where
256 M: InterfaceManager,
257{
258 fn broadcast<SendSocket, SendMgr>(
263 sockets: &mut List<SocketHeader>,
264 hdr: &Header,
265 mut sskt: SendSocket,
266 smgr: SendMgr,
267 ) -> Result<(), NetStackSendError>
268 where
269 SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
270 SendMgr: FnOnce() -> bool,
271 {
272 trace!("Sending msg broadcast w/ header: {hdr:?}");
273 let res_lcl = {
274 let bcast_iter = Self::find_all_local(sockets, hdr)?;
275 let mut any_found = false;
276 for dst in bcast_iter {
277 let res = sskt(dst);
278 if res {
279 debug!("delivered broadcast message locally");
280 }
281 any_found |= res;
282 }
283 any_found
284 };
285
286 let res_rmt = smgr();
287 if res_rmt {
288 debug!("delivered broadcast message remotely");
289 }
290
291 if res_lcl || res_rmt {
292 Ok(())
293 } else {
294 Err(NetStackSendError::NoRoute)
295 }
296 }
297
298 fn unicast<SendSocket, SendMgr>(
303 sockets: &mut List<SocketHeader>,
304 hdr: &Header,
305 sskt: SendSocket,
306 smgr: SendMgr,
307 ) -> Result<(), NetStackSendError>
308 where
309 SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
310 SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
311 {
312 trace!("Sending msg unicast w/ header: {hdr:?}");
313 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
315
316 let res = if !local_bypass {
317 debug!("Offering msg externally unicast w/ header: {hdr:?}");
319 smgr()
320 } else {
321 Err(InterfaceSendError::DestinationLocal)
323 };
324
325 match res {
326 Ok(()) => {
327 debug!("Externally routed msg unicast");
328 return Ok(());
329 }
330 Err(InterfaceSendError::DestinationLocal) => {
331 debug!("No external interest in msg unicast");
332 }
333 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
334 }
335
336 let socket = if hdr.dst.port_id == 0 {
338 debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
339 Self::find_any_local(sockets, hdr)
340 } else {
341 debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
342 Self::find_one_local(sockets, hdr)
343 }?;
344
345 sskt(socket)
346 }
347
348 fn unicast_err<SendSocket, SendMgr>(
353 sockets: &mut List<SocketHeader>,
354 hdr: &Header,
355 sskt: SendSocket,
356 smgr: SendMgr,
357 ) -> Result<(), NetStackSendError>
358 where
359 SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
360 SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
361 {
362 trace!("Sending err unicast w/ header: {hdr:?}");
363 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
365
366 let res = if !local_bypass {
367 debug!("Offering err externally unicast w/ header: {hdr:?}");
369 smgr()
370 } else {
371 Err(InterfaceSendError::DestinationLocal)
373 };
374
375 match res {
376 Ok(()) => {
377 debug!("Externally routed err unicast");
378 return Ok(());
379 }
380 Err(InterfaceSendError::DestinationLocal) => {
381 debug!("No external interest in err unicast");
382 }
383 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
384 }
385
386 let socket = Self::find_one_err_local(sockets, hdr)?;
388
389 sskt(socket)
390 }
391
392 fn send_raw(
394 &mut self,
395 hdr: &Header,
396 hdr_raw: &[u8],
397 body: &[u8],
398 ) -> Result<(), NetStackSendError> {
399 let Self {
400 sockets,
401 seq_no,
402 manager,
403 ..
404 } = self;
405 trace!("Sending msg raw w/ header: {hdr:?}");
406
407 if hdr.kind == FrameKind::PROTOCOL_ERROR {
408 todo!("Don't do that");
409 }
410
411 if hdr.dst.port_id == 255 {
413 Self::broadcast(
414 sockets,
415 hdr,
416 |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no).is_ok(),
417 || manager.send_raw(hdr, hdr_raw, body).is_ok(),
418 )
419 } else {
420 Self::unicast(
421 sockets,
422 hdr,
423 |skt| Self::send_raw_to_socket(skt, body, hdr, hdr_raw, seq_no),
424 || manager.send_raw(hdr, hdr_raw, body),
425 )
426 }
427 }
428
429 fn send_ty<T: 'static + Serialize + Clone>(
431 &mut self,
432 hdr: &Header,
433 t: &T,
434 ) -> Result<(), NetStackSendError> {
435 let Self {
436 sockets,
437 seq_no,
438 manager,
439 ..
440 } = self;
441 trace!("Sending msg ty w/ header: {hdr:?}");
442
443 if hdr.kind == FrameKind::PROTOCOL_ERROR {
444 todo!("Don't do that");
445 }
446
447 if hdr.dst.port_id == 255 {
449 Self::broadcast(
450 sockets,
451 hdr,
452 |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
453 || manager.send(hdr, t).is_ok(),
454 )
455 } else {
456 Self::unicast(
457 sockets,
458 hdr,
459 |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
460 || manager.send(hdr, t),
461 )
462 }
463 }
464
465 fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
467 let Self {
468 sockets,
469 seq_no,
470 manager,
471 ..
472 } = self;
473 trace!("Sending msg ty w/ header: {hdr:?}");
474
475 if hdr.dst.port_id == 255 {
476 todo!("Don't do that");
477 }
478
479 Self::unicast_err(
481 sockets,
482 hdr,
483 |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
484 || manager.send_err(hdr, err),
485 )
486 }
487
488 fn find_one_local(
491 sockets: &mut List<SocketHeader>,
492 hdr: &Header,
493 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
494 let mut iter = sockets.iter_raw();
496 let socket = loop {
497 let Some(skt) = iter.next() else {
498 return Err(NetStackSendError::NoRoute);
499 };
500 let skt_ref = unsafe { skt.as_ref() };
501 if skt_ref.port != hdr.dst.port_id {
502 continue;
503 }
504 if skt_ref.attrs.kind != hdr.kind {
505 return Err(NetStackSendError::WrongPortKind);
506 }
507 break skt;
508 };
509 Ok(socket)
510 }
511
512 fn find_one_err_local(
515 sockets: &mut List<SocketHeader>,
516 hdr: &Header,
517 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
518 let mut iter = sockets.iter_raw();
520 let socket = loop {
521 let Some(skt) = iter.next() else {
522 return Err(NetStackSendError::NoRoute);
523 };
524 let skt_ref = unsafe { skt.as_ref() };
525 if skt_ref.port != hdr.dst.port_id {
526 continue;
527 }
528 break skt;
529 };
530 Ok(socket)
531 }
532
533 fn find_any_local(
538 sockets: &mut List<SocketHeader>,
539 hdr: &Header,
540 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
541 let Some(apdx) = hdr.any_all.as_ref() else {
543 return Err(NetStackSendError::AnyPortMissingKey);
544 };
545 let mut iter = sockets.iter_raw();
546 let mut socket: Option<NonNull<SocketHeader>> = None;
547
548 loop {
549 let Some(skt) = iter.next() else {
550 break;
551 };
552 let skt_ref = unsafe { skt.as_ref() };
553
554 let mut illegal = false;
557 illegal |= skt_ref.attrs.kind != hdr.kind;
558 illegal |= !skt_ref.attrs.discoverable;
559 illegal |= skt_ref.key != apdx.key;
560 if let Some(nash) = apdx.nash {
561 illegal |= Some(nash) != skt_ref.nash;
562 }
563
564 if illegal {
565 continue;
567 }
568
569 if socket.is_some() {
571 return Err(NetStackSendError::AnyPortNotUnique);
572 }
573 socket = Some(skt);
576 }
577
578 socket.ok_or(NetStackSendError::NoRoute)
579 }
580
581 fn find_all_local(
586 sockets: &mut List<SocketHeader>,
587 hdr: &Header,
588 ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
589 let Some(any_all) = hdr.any_all.as_ref() else {
590 return Err(NetStackSendError::AllPortMissingKey);
591 };
592 Ok(sockets.iter_raw().filter(move |socket| {
593 let skt_ref = unsafe { socket.as_ref() };
594 let bport = skt_ref.port == 255;
595 let dkind = skt_ref.attrs.kind == hdr.kind;
596 let dkey = skt_ref.key == any_all.key;
597
598 let name = if let Some(nash) = any_all.nash {
601 Some(nash) == skt_ref.nash
602 } else {
603 true
604 };
605 bport && dkind && dkey && name
606 }))
607 }
608
609 fn send_ty_to_socket<T: 'static + Serialize + Clone>(
611 this: NonNull<SocketHeader>,
612 t: &T,
613 hdr: &Header,
614 seq_no: &mut u16,
615 ) -> Result<(), NetStackSendError> {
616 let vtable: &'static SocketVTable = {
617 let skt_ref = unsafe { this.as_ref() };
618 skt_ref.vtable
619 };
620
621 if let Some(f) = vtable.recv_owned {
622 let this: NonNull<()> = this.cast();
623 let that: NonNull<T> = NonNull::from(t);
624 let that: NonNull<()> = that.cast();
625 let hdr = hdr.to_headerseq_or_with_seq(|| {
626 let seq = *seq_no;
627 *seq_no = seq_no.wrapping_add(1);
628 seq
629 });
630 (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
631 } else if let Some(_f) = vtable.recv_bor {
632 todo!()
634 } else {
635 Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
640 }
641 }
642
643 fn send_err_to_socket(
645 this: NonNull<SocketHeader>,
646 err: ProtocolError,
647 hdr: &Header,
648 seq_no: &mut u16,
649 ) -> Result<(), NetStackSendError> {
650 let vtable: &'static SocketVTable = {
651 let skt_ref = unsafe { this.as_ref() };
652 skt_ref.vtable
653 };
654
655 if let Some(f) = vtable.recv_err {
656 let this: NonNull<()> = this.cast();
657 let hdr = hdr.to_headerseq_or_with_seq(|| {
658 let seq = *seq_no;
659 *seq_no = seq_no.wrapping_add(1);
660 seq
661 });
662 (f)(this, hdr, err);
663 Ok(())
664 } else {
665 Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
670 }
671 }
672
673 fn send_raw_to_socket(
698 this: NonNull<SocketHeader>,
699 body: &[u8],
700 hdr: &Header,
701 hdr_raw: &[u8],
702 seq_no: &mut u16,
703 ) -> Result<(), NetStackSendError> {
704 let vtable: &'static SocketVTable = {
705 let skt_ref = unsafe { this.as_ref() };
706 skt_ref.vtable
707 };
708 let f = vtable.recv_raw;
709
710 let this: NonNull<()> = this.cast();
711 let hdr = hdr.to_headerseq_or_with_seq(|| {
712 let seq = *seq_no;
713 *seq_no = seq_no.wrapping_add(1);
714 seq
715 });
716
717 (f)(this, body, hdr, hdr_raw).map_err(NetStackSendError::SocketSend)
718 }
719}
720
721impl<M> NetStackInner<M>
722where
723 M: InterfaceManager,
724{
725 fn alloc_port(&mut self) -> Option<u8> {
734 self.pcache_bits |= (self.pcache_start == 0) as u32;
736
737 if self.pcache_bits != u32::MAX {
738 let ldg = self.pcache_bits.trailing_ones();
740 debug_assert!(ldg < 32);
741 self.pcache_bits |= 1 << ldg;
742 return Some(self.pcache_start + (ldg as u8));
743 }
744
745 let old_start = self.pcache_start;
749 for base in 0..8 {
750 let start = base * 32;
751 if start == old_start {
752 continue;
753 }
754 self.pcache_start = start;
756 self.pcache_bits = 0;
757 self.pcache_bits |= (self.pcache_start == 0) as u32;
759 self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
761
762 self.sockets.iter().for_each(|s| {
768 if s.port == 255 {
769 return;
770 }
771
772 let pupper = s.port & !(32 - 1);
774 let plower = s.port & (32 - 1);
776
777 if pupper == self.pcache_start {
778 self.pcache_bits |= 1 << plower;
779 }
780 });
781
782 if self.pcache_bits != u32::MAX {
783 let ldg = self.pcache_bits.trailing_ones();
785 debug_assert!(ldg < 32);
786 self.pcache_bits |= 1 << ldg;
787 return Some(self.pcache_start + (ldg as u8));
788 }
789 }
790
791 None
793 }
794
795 fn free_port(&mut self, port: u8) {
796 debug_assert!(port != 255);
797 let pupper = port & !(32 - 1);
799 let plower = port & (32 - 1);
801
802 if pupper == self.pcache_start {
805 self.pcache_bits &= !(1 << plower);
806 }
807 }
808}
809
810impl NetStackSendError {
811 pub fn to_error(&self) -> ProtocolError {
812 match self {
813 NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
814 NetStackSendError::InterfaceSend(interface_send_error) => {
815 interface_send_error.to_error()
816 }
817 NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
818 NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
819 NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
820 NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
821 NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
822 }
823 }
824}
825
826#[cfg(test)]
827mod test {
828 use core::pin::pin;
829 use mutex::raw_impls::cs::CriticalSectionRawMutex;
830 use std::thread::JoinHandle;
831 use tokio::sync::oneshot;
832
833 use crate::{
834 FrameKind, Key, NetStack,
835 interface_manager::null::NullInterfaceManager,
836 socket::{Attributes, owned::single::Socket},
837 };
838
839 #[test]
840 fn port_alloc() {
841 static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();
842
843 let mut v = vec![];
844
845 fn spawn_skt(id: u8) -> (u8, JoinHandle<()>, oneshot::Sender<()>) {
846 let (txdone, rxdone) = oneshot::channel();
847 let (txwait, rxwait) = oneshot::channel();
848 let hdl = std::thread::spawn(move || {
849 let skt = Socket::<u64, _, _>::new(
850 &STACK,
851 Key(*b"TEST1234"),
852 Attributes {
853 kind: FrameKind::ENDPOINT_REQ,
854 discoverable: true,
855 },
856 None,
857 );
858 let skt = pin!(skt);
859 let hdl = skt.attach();
860 assert_eq!(hdl.port(), id);
861 txwait.send(()).unwrap();
862 let _: () = rxdone.blocking_recv().unwrap();
863 });
864 let _ = rxwait.blocking_recv();
865 (id, hdl, txdone)
866 }
867
868 for i in 1..32 {
870 v.push(spawn_skt(i));
871 }
872
873 for i in 32..40 {
875 v.push(spawn_skt(i));
876 }
877
878 let pos = v.iter().position(|(i, _, _)| *i == 35).unwrap();
880 let (_i, hdl, tx) = v.remove(pos);
881 tx.send(()).unwrap();
882 hdl.join().unwrap();
883
884 v.push(spawn_skt(35));
886
887 let pos = v.iter().position(|(i, _, _)| *i == 4).unwrap();
889 let (_i, hdl, tx) = v.remove(pos);
890 tx.send(()).unwrap();
891 hdl.join().unwrap();
892
893 v.push(spawn_skt(40));
895
896 for i in 41..64 {
898 v.push(spawn_skt(i));
899 }
900
901 v.push(spawn_skt(4));
903
904 for i in 64..255 {
906 v.push(spawn_skt(i));
907 }
908
909 let pos = v.iter().position(|(i, _, _)| *i == 212).unwrap();
911 let (_i, hdl, tx) = v.remove(pos);
912 tx.send(()).unwrap();
913 hdl.join().unwrap();
914
915 v.push(spawn_skt(212));
917
918 let hdl = std::thread::spawn(move || {
920 let skt = Socket::<u64, _, _>::new(
921 &STACK,
922 Key(*b"TEST1234"),
923 Attributes {
924 kind: FrameKind::ENDPOINT_REQ,
925 discoverable: true,
926 },
927 None,
928 );
929 let skt = pin!(skt);
930 let hdl = skt.attach();
931 println!("{}", hdl.port());
932 });
933 assert!(hdl.join().is_err());
934
935 for (_i, hdl, tx) in v.drain(..) {
936 tx.send(()).unwrap();
937 hdl.join().unwrap();
938 }
939 }
940}