1use core::{any::TypeId, ptr::NonNull};
22
23use cordyceps::List;
24use log::{debug, trace};
25use mutex::{BlockingMutex, ConstInit, ScopedRawMutex};
26use serde::Serialize;
27
28use crate::{
29 FrameKind, Header, ProtocolError,
30 interface_manager::{self, InterfaceManager, InterfaceSendError},
31 socket::{SocketHeader, SocketSendError, SocketVTable},
32};
33
34pub struct NetStack<R: ScopedRawMutex, M: InterfaceManager> {
36 inner: BlockingMutex<R, NetStackInner<M>>,
37}
38
39pub(crate) struct NetStackInner<M: InterfaceManager> {
40 sockets: List<SocketHeader>,
41 manager: M,
42 pcache_bits: u32,
43 pcache_start: u8,
44 seq_no: u16,
45}
46
47#[derive(Debug, PartialEq, Eq)]
49#[non_exhaustive]
50pub enum NetStackSendError {
51 SocketSend(SocketSendError),
52 InterfaceSend(InterfaceSendError),
53 NoRoute,
54 AnyPortMissingKey,
55 WrongPortKind,
56 AnyPortNotUnique,
57 AllPortMissingKey,
58}
59
60impl<R, M> NetStack<R, M>
63where
64 R: ScopedRawMutex + ConstInit,
65 M: InterfaceManager + interface_manager::ConstInit,
66{
67 pub const fn new() -> Self {
83 Self {
84 inner: BlockingMutex::new(NetStackInner::new()),
85 }
86 }
87}
88
89impl<R, M> NetStack<R, M>
90where
91 R: ScopedRawMutex,
92 M: InterfaceManager,
93{
94 pub const fn const_new(r: R, m: M) -> Self {
102 Self {
103 inner: BlockingMutex::const_new(
104 r,
105 NetStackInner {
106 sockets: List::new(),
107 manager: m,
108 seq_no: 0,
109 pcache_start: 0,
110 pcache_bits: 0,
111 },
112 ),
113 }
114 }
115
116 pub fn with_interface_manager<F: FnOnce(&mut M) -> U, U>(&'static self, f: F) -> U {
144 self.inner.with_lock(|inner| f(&mut inner.manager))
145 }
146
147 pub fn send_raw(&'static self, hdr: &Header, body: &[u8]) -> Result<(), NetStackSendError> {
153 self.inner.with_lock(|inner| inner.send_raw(hdr, body))
154 }
155
156 pub fn send_ty<T: 'static + Serialize + Clone>(
158 &'static self,
159 hdr: &Header,
160 t: &T,
161 ) -> Result<(), NetStackSendError> {
162 self.inner.with_lock(|inner| inner.send_ty(hdr, t))
163 }
164
165 pub fn send_err(
166 &'static self,
167 hdr: &Header,
168 err: ProtocolError,
169 ) -> Result<(), NetStackSendError> {
170 self.inner.with_lock(|inner| inner.send_err(hdr, err))
171 }
172
173 pub(crate) unsafe fn try_attach_socket(
174 &'static self,
175 mut node: NonNull<SocketHeader>,
176 ) -> Option<u8> {
177 self.inner.with_lock(|inner| {
178 let new_port = inner.alloc_port()?;
179 unsafe {
180 node.as_mut().port = new_port;
181 }
182
183 inner.sockets.push_front(node);
184 Some(new_port)
185 })
186 }
187
188 pub(crate) unsafe fn attach_broadcast_socket(&'static self, mut node: NonNull<SocketHeader>) {
189 self.inner.with_lock(|inner| {
190 unsafe {
191 node.as_mut().port = 255;
192 }
193 inner.sockets.push_back(node);
194 });
195 }
196
197 pub(crate) unsafe fn attach_socket(&'static self, node: NonNull<SocketHeader>) -> u8 {
198 let res = unsafe { self.try_attach_socket(node) };
199 let Some(new_port) = res else {
200 panic!("exhausted all addrs");
201 };
202 new_port
203 }
204
205 pub(crate) unsafe fn detach_socket(&'static self, node: NonNull<SocketHeader>) {
206 self.inner.with_lock(|inner| unsafe {
207 let port = node.as_ref().port;
208 if port != 255 {
209 inner.free_port(port);
210 }
211 inner.sockets.remove(node)
212 });
213 }
214
215 pub(crate) unsafe fn with_lock<U, F: FnOnce() -> U>(&'static self, f: F) -> U {
216 self.inner.with_lock(|_inner| f())
217 }
218}
219
220impl<R, M> Default for NetStack<R, M>
221where
222 R: ScopedRawMutex + ConstInit,
223 M: InterfaceManager + interface_manager::ConstInit,
224{
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230impl<M> NetStackInner<M>
233where
234 M: InterfaceManager,
235 M: interface_manager::ConstInit,
236{
237 pub const fn new() -> Self {
238 Self {
239 sockets: List::new(),
240 manager: M::INIT,
241 seq_no: 0,
242 pcache_bits: 0,
243 pcache_start: 0,
244 }
245 }
246}
247
248impl<M> NetStackInner<M>
249where
250 M: InterfaceManager,
251{
252 fn broadcast<SendSocket, SendMgr>(
257 sockets: &mut List<SocketHeader>,
258 hdr: &Header,
259 mut sskt: SendSocket,
260 smgr: SendMgr,
261 ) -> Result<(), NetStackSendError>
262 where
263 SendSocket: FnMut(NonNull<SocketHeader>) -> bool,
264 SendMgr: FnOnce() -> bool,
265 {
266 trace!("Sending msg broadcast w/ header: {hdr:?}");
267 let res_lcl = {
268 let bcast_iter = Self::find_all_local(sockets, hdr)?;
269 let mut any_found = false;
270 for dst in bcast_iter {
271 let res = sskt(dst);
272 if res {
273 debug!("delivered broadcast message locally");
274 }
275 any_found |= res;
276 }
277 any_found
278 };
279
280 let res_rmt = smgr();
281 if res_rmt {
282 debug!("delivered broadcast message remotely");
283 }
284
285 if res_lcl || res_rmt {
286 Ok(())
287 } else {
288 Err(NetStackSendError::NoRoute)
289 }
290 }
291
292 fn unicast<SendSocket, SendMgr>(
297 sockets: &mut List<SocketHeader>,
298 hdr: &Header,
299 sskt: SendSocket,
300 smgr: SendMgr,
301 ) -> Result<(), NetStackSendError>
302 where
303 SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
304 SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
305 {
306 trace!("Sending msg unicast w/ header: {hdr:?}");
307 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
309
310 let res = if !local_bypass {
311 debug!("Offering msg externally unicast w/ header: {hdr:?}");
313 smgr()
314 } else {
315 Err(InterfaceSendError::DestinationLocal)
317 };
318
319 match res {
320 Ok(()) => {
321 debug!("Externally routed msg unicast");
322 return Ok(());
323 }
324 Err(InterfaceSendError::DestinationLocal) => {
325 debug!("No external interest in msg unicast");
326 }
327 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
328 }
329
330 let socket = if hdr.dst.port_id == 0 {
332 debug!("Sending ANY unicast msg locally w/ header: {hdr:?}");
333 Self::find_any_local(sockets, hdr)
334 } else {
335 debug!("Sending ONE unicast msg locally w/ header: {hdr:?}");
336 Self::find_one_local(sockets, hdr)
337 }?;
338
339 sskt(socket)
340 }
341
342 fn unicast_err<SendSocket, SendMgr>(
347 sockets: &mut List<SocketHeader>,
348 hdr: &Header,
349 sskt: SendSocket,
350 smgr: SendMgr,
351 ) -> Result<(), NetStackSendError>
352 where
353 SendSocket: FnOnce(NonNull<SocketHeader>) -> Result<(), NetStackSendError>,
354 SendMgr: FnOnce() -> Result<(), InterfaceSendError>,
355 {
356 trace!("Sending err unicast w/ header: {hdr:?}");
357 let local_bypass = hdr.src.net_node_any() && hdr.dst.net_node_any();
359
360 let res = if !local_bypass {
361 debug!("Offering err externally unicast w/ header: {hdr:?}");
363 smgr()
364 } else {
365 Err(InterfaceSendError::DestinationLocal)
367 };
368
369 match res {
370 Ok(()) => {
371 debug!("Externally routed err unicast");
372 return Ok(());
373 }
374 Err(InterfaceSendError::DestinationLocal) => {
375 debug!("No external interest in err unicast");
376 }
377 Err(e) => return Err(NetStackSendError::InterfaceSend(e)),
378 }
379
380 let socket = Self::find_one_err_local(sockets, hdr)?;
382
383 sskt(socket)
384 }
385
386 fn send_raw(&mut self, hdr: &Header, body: &[u8]) -> Result<(), NetStackSendError> {
388 let Self {
389 sockets,
390 seq_no,
391 manager,
392 ..
393 } = self;
394 trace!("Sending msg raw w/ header: {hdr:?}");
395
396 if hdr.kind == FrameKind::PROTOCOL_ERROR {
397 todo!("Don't do that");
398 }
399
400 if hdr.dst.port_id == 255 {
402 Self::broadcast(
403 sockets,
404 hdr,
405 |skt| Self::send_raw_to_socket(skt, body, hdr, seq_no).is_ok(),
406 || manager.send_raw(hdr, body).is_ok(),
407 )
408 } else {
409 Self::unicast(
410 sockets,
411 hdr,
412 |skt| Self::send_raw_to_socket(skt, body, hdr, seq_no),
413 || manager.send_raw(hdr, body),
414 )
415 }
416 }
417
418 fn send_ty<T: 'static + Serialize + Clone>(
420 &mut self,
421 hdr: &Header,
422 t: &T,
423 ) -> Result<(), NetStackSendError> {
424 let Self {
425 sockets,
426 seq_no,
427 manager,
428 ..
429 } = self;
430 trace!("Sending msg ty w/ header: {hdr:?}");
431
432 if hdr.kind == FrameKind::PROTOCOL_ERROR {
433 todo!("Don't do that");
434 }
435
436 if hdr.dst.port_id == 255 {
438 Self::broadcast(
439 sockets,
440 hdr,
441 |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no).is_ok(),
442 || manager.send(hdr, t).is_ok(),
443 )
444 } else {
445 Self::unicast(
446 sockets,
447 hdr,
448 |skt| Self::send_ty_to_socket(skt, t, hdr, seq_no),
449 || manager.send(hdr, t),
450 )
451 }
452 }
453
454 fn send_err(&mut self, hdr: &Header, err: ProtocolError) -> Result<(), NetStackSendError> {
456 let Self {
457 sockets,
458 seq_no,
459 manager,
460 ..
461 } = self;
462 trace!("Sending msg ty w/ header: {hdr:?}");
463
464 if hdr.dst.port_id == 255 {
465 todo!("Don't do that");
466 }
467
468 Self::unicast_err(
470 sockets,
471 hdr,
472 |skt| Self::send_err_to_socket(skt, err, hdr, seq_no),
473 || manager.send_err(hdr, err),
474 )
475 }
476
477 fn find_one_local(
480 sockets: &mut List<SocketHeader>,
481 hdr: &Header,
482 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
483 let mut iter = sockets.iter_raw();
485 let socket = loop {
486 let Some(skt) = iter.next() else {
487 return Err(NetStackSendError::NoRoute);
488 };
489 let skt_ref = unsafe { skt.as_ref() };
490 if skt_ref.port != hdr.dst.port_id {
491 continue;
492 }
493 if skt_ref.attrs.kind != hdr.kind {
494 return Err(NetStackSendError::WrongPortKind);
495 }
496 break skt;
497 };
498 Ok(socket)
499 }
500
501 fn find_one_err_local(
504 sockets: &mut List<SocketHeader>,
505 hdr: &Header,
506 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
507 let mut iter = sockets.iter_raw();
509 let socket = loop {
510 let Some(skt) = iter.next() else {
511 return Err(NetStackSendError::NoRoute);
512 };
513 let skt_ref = unsafe { skt.as_ref() };
514 if skt_ref.port != hdr.dst.port_id {
515 continue;
516 }
517 break skt;
518 };
519 Ok(socket)
520 }
521
522 fn find_any_local(
527 sockets: &mut List<SocketHeader>,
528 hdr: &Header,
529 ) -> Result<NonNull<SocketHeader>, NetStackSendError> {
530 let Some(key) = hdr.key.as_ref() else {
532 return Err(NetStackSendError::AnyPortMissingKey);
533 };
534 let mut iter = sockets.iter_raw();
535 let mut socket: Option<NonNull<SocketHeader>> = None;
536
537 loop {
538 let Some(skt) = iter.next() else {
539 break;
540 };
541 let skt_ref = unsafe { skt.as_ref() };
542
543 let mut illegal = false;
546 illegal |= skt_ref.attrs.kind != hdr.kind;
547 illegal |= !skt_ref.attrs.discoverable;
548 illegal |= &skt_ref.key != key;
549
550 if illegal {
551 continue;
553 }
554
555 if socket.is_some() {
557 return Err(NetStackSendError::AnyPortNotUnique);
558 }
559 socket = Some(skt);
562 }
563
564 socket.ok_or(NetStackSendError::NoRoute)
565 }
566
567 fn find_all_local(
572 sockets: &mut List<SocketHeader>,
573 hdr: &Header,
574 ) -> Result<impl Iterator<Item = NonNull<SocketHeader>>, NetStackSendError> {
575 let Some(key) = hdr.key.as_ref() else {
576 return Err(NetStackSendError::AllPortMissingKey);
577 };
578 Ok(sockets.iter_raw().filter(move |socket| {
579 let skt_ref = unsafe { socket.as_ref() };
580 let bport = skt_ref.port == 255;
581 let dkind = skt_ref.attrs.kind == hdr.kind;
582 let dkey = &skt_ref.key == key;
583 bport && dkind && dkey
584 }))
585 }
586
587 fn send_ty_to_socket<T: 'static + Serialize + Clone>(
589 this: NonNull<SocketHeader>,
590 t: &T,
591 hdr: &Header,
592 seq_no: &mut u16,
593 ) -> Result<(), NetStackSendError> {
594 let vtable: &'static SocketVTable = {
595 let skt_ref = unsafe { this.as_ref() };
596 skt_ref.vtable
597 };
598
599 if let Some(f) = vtable.recv_owned {
600 let this: NonNull<()> = this.cast();
601 let that: NonNull<T> = NonNull::from(t);
602 let that: NonNull<()> = that.cast();
603 let hdr = hdr.to_headerseq_or_with_seq(|| {
604 let seq = *seq_no;
605 *seq_no = seq_no.wrapping_add(1);
606 seq
607 });
608 (f)(this, that, hdr, &TypeId::of::<T>()).map_err(NetStackSendError::SocketSend)
609 } else if let Some(_f) = vtable.recv_bor {
610 todo!()
612 } else {
613 Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
618 }
619 }
620
621 fn send_err_to_socket(
623 this: NonNull<SocketHeader>,
624 err: ProtocolError,
625 hdr: &Header,
626 seq_no: &mut u16,
627 ) -> Result<(), NetStackSendError> {
628 let vtable: &'static SocketVTable = {
629 let skt_ref = unsafe { this.as_ref() };
630 skt_ref.vtable
631 };
632
633 if let Some(f) = vtable.recv_err {
634 let this: NonNull<()> = this.cast();
635 let hdr = hdr.to_headerseq_or_with_seq(|| {
636 let seq = *seq_no;
637 *seq_no = seq_no.wrapping_add(1);
638 seq
639 });
640 (f)(this, hdr, err);
641 Ok(())
642 } else {
643 Err(NetStackSendError::SocketSend(SocketSendError::WhatTheHell))
648 }
649 }
650
651 fn send_raw_to_socket(
676 this: NonNull<SocketHeader>,
677 body: &[u8],
678 hdr: &Header,
679 seq_no: &mut u16,
680 ) -> Result<(), NetStackSendError> {
681 let vtable: &'static SocketVTable = {
682 let skt_ref = unsafe { this.as_ref() };
683 skt_ref.vtable
684 };
685 let f = vtable.recv_raw;
686
687 let this: NonNull<()> = this.cast();
688 let hdr = hdr.to_headerseq_or_with_seq(|| {
689 let seq = *seq_no;
690 *seq_no = seq_no.wrapping_add(1);
691 seq
692 });
693
694 (f)(this, body, hdr).map_err(NetStackSendError::SocketSend)
695 }
696}
697
698impl<M> NetStackInner<M>
699where
700 M: InterfaceManager,
701{
702 fn alloc_port(&mut self) -> Option<u8> {
711 self.pcache_bits |= (self.pcache_start == 0) as u32;
713
714 if self.pcache_bits != u32::MAX {
715 let ldg = self.pcache_bits.trailing_ones();
717 debug_assert!(ldg < 32);
718 self.pcache_bits |= 1 << ldg;
719 return Some(self.pcache_start + (ldg as u8));
720 }
721
722 let old_start = self.pcache_start;
726 for base in 0..8 {
727 let start = base * 32;
728 if start == old_start {
729 continue;
730 }
731 self.pcache_start = start;
733 self.pcache_bits = 0;
734 self.pcache_bits |= (self.pcache_start == 0) as u32;
736 self.pcache_bits |= ((self.pcache_start == 0b111_00000) as u32) << 31;
738
739 self.sockets.iter().for_each(|s| {
745 if s.port == 255 {
746 return;
747 }
748
749 let pupper = s.port & !(32 - 1);
751 let plower = s.port & (32 - 1);
753
754 if pupper == self.pcache_start {
755 self.pcache_bits |= 1 << plower;
756 }
757 });
758
759 if self.pcache_bits != u32::MAX {
760 let ldg = self.pcache_bits.trailing_ones();
762 debug_assert!(ldg < 32);
763 self.pcache_bits |= 1 << ldg;
764 return Some(self.pcache_start + (ldg as u8));
765 }
766 }
767
768 None
770 }
771
772 fn free_port(&mut self, port: u8) {
773 debug_assert!(port != 255);
774 let pupper = port & !(32 - 1);
776 let plower = port & (32 - 1);
778
779 if pupper == self.pcache_start {
782 self.pcache_bits &= !(1 << plower);
783 }
784 }
785}
786
787impl NetStackSendError {
788 pub fn to_error(&self) -> ProtocolError {
789 match self {
790 NetStackSendError::SocketSend(socket_send_error) => socket_send_error.to_error(),
791 NetStackSendError::InterfaceSend(interface_send_error) => {
792 interface_send_error.to_error()
793 }
794 NetStackSendError::NoRoute => ProtocolError::NSSE_NO_ROUTE,
795 NetStackSendError::AnyPortMissingKey => ProtocolError::NSSE_ANY_PORT_MISSING_KEY,
796 NetStackSendError::WrongPortKind => ProtocolError::NSSE_WRONG_PORT_KIND,
797 NetStackSendError::AnyPortNotUnique => ProtocolError::NSSE_ANY_PORT_NOT_UNIQUE,
798 NetStackSendError::AllPortMissingKey => ProtocolError::NSSE_ALL_PORT_MISSING_KEY,
799 }
800 }
801}
802
#[cfg(test)]
mod test {
    use core::pin::pin;
    use mutex::raw_impls::cs::CriticalSectionRawMutex;
    use std::thread::JoinHandle;
    use tokio::sync::oneshot;

    use crate::{
        FrameKind, Key, NetStack,
        interface_manager::null::NullInterfaceManager,
        socket::{Attributes, single::Socket},
    };

    /// A socket held open on its own thread: (expected port, thread handle,
    /// release signal).
    type Held = (u8, JoinHandle<()>, oneshot::Sender<()>);

    /// Walks the port allocator through several phases: filling windows,
    /// reusing freed ports inside and outside the cached window, and finally
    /// running out of ports.
    #[test]
    fn port_alloc() {
        static STACK: NetStack<CriticalSectionRawMutex, NullInterfaceManager> = NetStack::new();

        /// Attach a socket on a new thread and assert it received port `id`.
        ///
        /// The socket stays attached until the returned sender fires or is
        /// dropped. If the thread panics before reporting (for example, on a
        /// port mismatch), its panic is re-raised here. That fails the test at
        /// the offending allocation instead of cascading into later ones.
        fn spawn_skt(id: u8) -> Held {
            let (txdone, rxdone) = oneshot::channel();
            let (txwait, rxwait) = oneshot::channel();
            let hdl = std::thread::spawn(move || {
                let skt = Socket::<u64, _, _>::new(
                    &STACK,
                    Key(*b"TEST1234"),
                    Attributes {
                        kind: FrameKind::ENDPOINT_REQ,
                        discoverable: true,
                    },
                );
                let skt = pin!(skt);
                let hdl = skt.attach();
                assert_eq!(hdl.port(), id);
                txwait.send(()).unwrap();
                let _: () = rxdone.blocking_recv().unwrap();
            });
            if rxwait.blocking_recv().is_err() {
                let payload = hdl
                    .join()
                    .expect_err("socket thread exited without reporting its port");
                std::panic::resume_unwind(payload);
            }
            (id, hdl, txdone)
        }

        /// Detach the socket holding port `id` and wait for its thread to exit.
        fn release(held: &mut Vec<Held>, id: u8) {
            let pos = held.iter().position(|(port, _, _)| *port == id).unwrap();
            let (_port, hdl, tx) = held.remove(pos);
            tx.send(()).unwrap();
            hdl.join().unwrap();
        }

        let mut held: Vec<Held> = vec![];

        // Port 0 is reserved, so the first window yields 1..=31.
        held.extend((1..32).map(spawn_skt));
        // The first window is full, so allocation moves on to 32..64.
        held.extend((32..40).map(spawn_skt));

        // 35 is in the cached window, so it is handed out again immediately.
        release(&mut held, 35);
        held.push(spawn_skt(35));

        // 4 is outside the cached window (32..64), so that window is filled
        // first...
        release(&mut held, 4);
        held.push(spawn_skt(40));
        held.extend((41..64).map(spawn_skt));

        // ...and the first window is rescanned (finding 4) only once it is full.
        held.push(spawn_skt(4));

        // Fill every remaining window. Port 255 is reserved for broadcast.
        held.extend((64..255).map(spawn_skt));

        // A port freed outside the cached window is found again by the rescan.
        release(&mut held, 212);
        held.push(spawn_skt(212));

        // Every port is taken, so attaching must panic ("exhausted all addrs").
        let hdl = std::thread::spawn(move || {
            let skt = Socket::<u64, _, _>::new(
                &STACK,
                Key(*b"TEST1234"),
                Attributes {
                    kind: FrameKind::ENDPOINT_REQ,
                    discoverable: true,
                },
            );
            let skt = pin!(skt);
            let hdl = skt.attach();
            println!("{}", hdl.port());
        });
        assert!(hdl.join().is_err());

        for (_port, hdl, tx) in held.drain(..) {
            tx.send(()).unwrap();
            hdl.join().unwrap();
        }
    }
}