1#[allow(unused)]
9use crate::fmt::{debug, error, info, trace, warn};
10
11use core::cell::RefCell;
12use core::future::{poll_fn, Future};
13use core::pin::pin;
14use core::task::Poll;
15
16use crate::reassemble::Reassembler;
17use crate::{
18 AppCookie, Fragmenter, ReceiveHandle, SendOutput, Stack, MAX_MTU,
19 MAX_PAYLOAD,
20};
21use mctp::{Eid, Error, MsgIC, MsgType, Result, Tag, TagValue};
22
23use embassy_sync::waitqueue::{MultiWakerRegistration, WakerRegistration};
24use embassy_sync::zerocopy_channel::{Channel, Receiver, Sender};
25
26use heapless::Vec;
27
28const MAX_LISTENERS: usize = 20;
30const MAX_RECEIVERS: usize = 50;
31
32type RawMutex = embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
34type AsyncMutex<T> = embassy_sync::mutex::Mutex<RawMutex, T>;
35type BlockingMutex<T> =
36 embassy_sync::blocking_mutex::Mutex<RawMutex, RefCell<T>>;
37
38type PortRawMutex = embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
/// Identifies a port attached to a [`Router`].
///
/// The wrapped value is used as an index into the router's port list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PortId(pub u8);
44
45pub trait PortLookup: Send {
47 fn by_eid(
58 &mut self,
59 eid: Eid,
60 source_port: Option<PortId>,
61 ) -> Option<PortId>;
62}
63
64struct PktBuf {
67 data: [u8; MAX_MTU],
68 len: usize,
69 dest: Eid,
70}
71
72impl PktBuf {
73 const fn new() -> Self {
74 Self {
75 data: [0u8; MAX_MTU],
76 len: 0,
77 dest: Eid(0),
78 }
79 }
80
81 fn set(&mut self, data: &[u8]) -> Result<()> {
82 let hdr = Reassembler::header(data);
83 debug_assert!(hdr.is_ok());
84 let hdr = hdr?;
85 let dst = self.data.get_mut(..data.len()).ok_or(Error::NoSpace)?;
86 dst.copy_from_slice(data);
87 self.len = data.len();
88 self.dest = Eid(hdr.dest_endpoint_id());
89 Ok(())
90 }
91}
92
93impl core::ops::Deref for PktBuf {
94 type Target = [u8];
95
96 fn deref(&self) -> &[u8] {
97 &self.data[..self.len]
98 }
99}
100
101pub struct PortTop<'a> {
105 packets: AsyncMutex<Sender<'a, PortRawMutex, PktBuf>>,
108
109 message: AsyncMutex<Vec<u8, MAX_PAYLOAD>>,
112
113 mtu: usize,
114}
115
116impl PortTop<'_> {
117 async fn forward_packet(&self, pkt: &[u8]) -> Result<()> {
123 debug_assert!(Reassembler::header(pkt).is_ok());
124
125 let mut sender = self.packets.lock().await;
126 if pkt.len() > self.mtu {
130 debug!("Forward packet too large");
131 return Err(Error::NoSpace);
132 }
133
134 let slot = sender.try_send().ok_or_else(|| {
136 debug!("Dropped forward packet");
137 Error::TxFailure
138 })?;
139
140 slot.set(pkt).unwrap();
143 sender.send_done();
144 Ok(())
145 }
146
147 async fn send_message(
152 &self,
153 fragmenter: &mut Fragmenter,
154 pkt: &[&[u8]],
155 ) -> Result<Tag> {
156 trace!("send_message");
157 let mut msg;
158 let payload = if pkt.len() == 1 {
159 pkt[0]
161 } else {
162 msg = self.message.lock().await;
163 msg.clear();
164 for p in pkt {
165 msg.extend_from_slice(p).map_err(|_| {
166 debug!("Message too large");
167 Error::NoSpace
168 })?;
169 }
170 &msg
171 };
172
173 loop {
174 let mut sender = self.packets.lock().await;
175
176 let qpkt = sender.send().await;
177 qpkt.len = 0;
178 qpkt.dest = fragmenter.dest();
179 let r = fragmenter.fragment(payload, &mut qpkt.data);
180 match r {
181 SendOutput::Packet(p) => {
182 qpkt.len = p.len();
183 sender.send_done();
184 if fragmenter.is_done() {
185 break Ok(fragmenter.tag());
186 }
187 }
188 SendOutput::Error { err, .. } => {
189 debug!("Error packetising");
190 sender.send_done();
191 break Err(err);
192 }
193 SendOutput::Complete { .. } => unreachable!(),
194 }
195 }
196 }
197}
198
199pub struct PortBottom<'a> {
204 packets: Receiver<'a, PortRawMutex, PktBuf>,
206}
207
208impl PortBottom<'_> {
209 pub async fn outbound(&mut self) -> (&[u8], Eid) {
216 if self.packets.len() > 1 {
217 trace!("packets avail {}", self.packets.len());
218 }
219 let pkt = self.packets.receive().await;
220 (pkt, pkt.dest)
221 }
222
223 pub fn try_outbound(&mut self) -> Option<(&[u8], Eid)> {
232 trace!("packets avail {} try", self.packets.len());
233 self.packets.try_receive().map(|pkt| (&**pkt, pkt.dest))
234 }
235
236 pub fn outbound_done(&mut self) {
238 self.packets.receive_done()
239 }
240}
241
242pub struct PortStorage<const FORWARD_QUEUE: usize = 4> {
247 packets: [PktBuf; FORWARD_QUEUE],
249}
250
251impl<const FORWARD_QUEUE: usize> PortStorage<FORWARD_QUEUE> {
252 pub fn new() -> Self {
253 Self {
254 packets: [const { PktBuf::new() }; FORWARD_QUEUE],
255 }
256 }
257}
258
259impl<const FORWARD_QUEUE: usize> Default for PortStorage<FORWARD_QUEUE> {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265pub struct PortBuilder<'a> {
266 packets: Channel<'a, PortRawMutex, PktBuf>,
268}
269
270impl<'a> PortBuilder<'a> {
271 pub fn new<const FORWARD_QUEUE: usize>(
272 storage: &'a mut PortStorage<FORWARD_QUEUE>,
273 ) -> Self {
274 Self {
277 packets: Channel::new(storage.packets.as_mut_slice()),
278 }
279 }
280
281 pub fn build(
282 &mut self,
283 mtu: usize,
284 ) -> Result<(PortTop<'_>, PortBottom<'_>)> {
285 if mtu > MAX_MTU {
286 debug!("port mtu {} > MAX_MTU {}", mtu, MAX_MTU);
287 return Err(Error::BadArgument);
288 }
289
290 let (ps, pr) = self.packets.split();
291
292 let t = PortTop {
293 message: AsyncMutex::new(Vec::new()),
294 packets: AsyncMutex::new(ps),
295 mtu,
296 };
297 let b = PortBottom { packets: pr };
298 Ok((t, b))
299 }
300}
301
302pub struct Router<'r> {
322 inner: AsyncMutex<RouterInner<'r>>,
323 ports: &'r [PortTop<'r>],
324
325 app_listeners:
329 BlockingMutex<[Option<(MsgType, WakerRegistration)>; MAX_LISTENERS]>,
330}
331
332pub struct RouterInner<'r> {
333 stack: Stack,
335
336 app_receive_wakers: MultiWakerRegistration<MAX_RECEIVERS>,
338
339 lookup: &'r mut dyn PortLookup,
340}
341
342impl<'r> Router<'r> {
343 pub fn new(
352 stack: Stack,
353 ports: &'r [PortTop<'r>],
354 lookup: &'r mut dyn PortLookup,
355 ) -> Self {
356 let inner = RouterInner {
357 stack,
358 app_receive_wakers: MultiWakerRegistration::new(),
359 lookup,
360 };
361
362 Self {
363 inner: AsyncMutex::new(inner),
364 app_listeners: BlockingMutex::new(RefCell::new(
365 [const { None }; MAX_LISTENERS],
366 )),
367 ports,
368 }
369 }
370
371 pub async fn update_time(&self, now_millis: u64) -> Result<u64> {
376 let mut inner = self.inner.lock().await;
377 let (next, expired) = inner.stack.update(now_millis)?;
378 if expired {
379 inner.app_receive_wakers.wake();
382 }
383 Ok(next)
384 }
385
386 pub async fn inbound(&self, pkt: &[u8], port: PortId) -> Option<Eid> {
393 let mut inner = self.inner.lock().await;
394
395 let Ok(header) = Reassembler::header(pkt) else {
396 return None;
397 };
398 let ret_src = Some(Eid(header.source_endpoint_id()));
400
401 if inner.stack.is_local_dest(pkt) {
403 match inner.stack.receive(pkt) {
404 Ok(Some((msg, handle))) => {
406 let typ = msg.typ;
407 let tag = msg.tag;
408 drop(inner);
409 self.incoming_local(tag, typ, handle).await;
410 return ret_src;
411 }
412 Ok(None) => {
414 return ret_src;
415 }
416 Err(e) => {
417 debug!("Dropped local recv packet. {}", e);
418 return ret_src;
419 }
420 }
421 }
422
423 let dest_eid = Eid(header.dest_endpoint_id());
425
426 let Some(p) = inner.lookup.by_eid(dest_eid, Some(port)) else {
427 debug!("No route for recv {}", dest_eid);
428 return ret_src;
429 };
430 drop(inner);
431
432 let Some(top) = self.ports.get(p.0 as usize) else {
433 debug!("Bad port ID from lookup");
434 return ret_src;
435 };
436
437 let _ = top.forward_packet(pkt).await;
438 ret_src
439 }
440
441 async fn incoming_local(
442 &self,
443 tag: Tag,
444 typ: MsgType,
445 handle: ReceiveHandle,
446 ) {
447 trace!("incoming local, type {}", typ.0);
448 if tag.is_owner() {
449 self.incoming_listener(typ, handle).await
450 } else {
451 self.incoming_response(tag, handle).await
452 }
453 }
454
455 async fn incoming_listener(&self, typ: MsgType, handle: ReceiveHandle) {
456 let mut inner = self.inner.lock().await;
457 let mut handle = Some(handle);
458
459 self.app_listeners.lock(|a| {
461 let mut a = a.borrow_mut();
462 for (cookie, entry) in a.iter_mut().enumerate() {
464 if let Some((t, waker)) = entry {
465 trace!("entry. {} vs {}", t.0, typ.0);
466 if *t == typ {
467 let handle = handle.take().unwrap();
469 inner
470 .stack
471 .set_cookie(&handle, Some(AppCookie(cookie)));
472 inner.stack.return_handle(handle);
473 waker.wake();
474 trace!("listener match");
475 break;
476 }
477 }
478 }
479 });
480
481 if let Some(handle) = handle.take() {
482 trace!("listener no match");
483 inner.stack.finished_receive(handle);
484 }
485 }
486
487 async fn incoming_response(&self, _tag: Tag, handle: ReceiveHandle) {
488 let mut inner = self.inner.lock().await;
489 inner.stack.return_handle(handle);
490 inner.app_receive_wakers.wake();
493 }
494
495 fn app_bind(&self, typ: MsgType) -> Result<AppCookie> {
496 self.app_listeners.lock(|a| {
497 let mut a = a.borrow_mut();
498
499 for bind in a.iter() {
501 if bind.as_ref().is_some_and(|(t, _)| *t == typ) {
502 return Err(Error::AddrInUse);
503 }
504 }
505
506 if let Some((i, bind)) =
508 a.iter_mut().enumerate().find(|(_i, bind)| bind.is_none())
509 {
510 *bind = Some((typ, WakerRegistration::new()));
511 return Ok(AppCookie(i));
512 }
513
514 Err(Error::NoSpace)
515 })
516 }
517
518 fn app_unbind(&self, cookie: AppCookie) -> Result<()> {
519 self.app_listeners.lock(|a| {
520 let mut a = a.borrow_mut();
521 let bind = a.get_mut(cookie.0).ok_or(Error::BadArgument)?;
522
523 if bind.is_none() {
524 return Err(Error::BadArgument);
525 }
526
527 *bind = None;
529 Ok(())
532 })
533 }
534
535 async fn app_recv_message<'f>(
540 &self,
541 cookie: Option<AppCookie>,
542 tag_eid: Option<(Tag, Eid)>,
543 buf: &'f mut [u8],
544 ) -> Result<(&'f mut [u8], Eid, MsgType, Tag, MsgIC)> {
545 let mut buf = Some(buf);
547
548 poll_fn(|cx| {
549 let l = self.inner.lock();
551 let l = pin!(l);
552 let mut inner = match l.poll(cx) {
553 Poll::Ready(i) => i,
554 Poll::Pending => return Poll::Pending,
555 };
556
557 trace!("poll recv message");
558
559 let handle = match (cookie, tag_eid) {
562 (Some(cookie), None) => {
564 inner.stack.get_deferred_bycookie(&[cookie])
565 }
566 (None, Some((tag, eid))) => inner.stack.get_deferred(eid, tag),
568 _ => unreachable!(),
570 };
571
572 let Some(handle) = handle else {
573 if let Some(cookie) = cookie {
577 trace!("listener, cookie index {}", cookie.0);
579 self.app_listeners.lock(|a| {
580 let mut a = a.borrow_mut();
581 let Some(bind) = a.get_mut(cookie.0) else {
582 debug_assert!(false, "recv bad cookie");
583 return;
584 };
585 let Some((_typ, waker)) = bind else {
586 debug_assert!(false, "recv no listener");
587 return;
588 };
589 waker.register(cx.waker());
590 });
591 } else {
592 trace!("other recv");
594 inner.app_receive_wakers.register(cx.waker());
595 }
596 trace!("pending");
597 return Poll::Pending;
598 };
599
600 trace!("got handle");
603
604 let msg = inner.stack.fetch_message(&handle);
605
606 let buf = buf.take().unwrap();
608 let res = if msg.payload.len() > buf.len() {
609 trace!("no space");
610 Err(Error::NoSpace)
611 } else {
612 trace!("good len {}", msg.payload.len());
613 let buf = &mut buf[..msg.payload.len()];
614 buf.copy_from_slice(msg.payload);
615 Ok((buf, msg.source, msg.typ, msg.tag, msg.ic))
616 };
617
618 inner.stack.finished_receive(handle);
619 Poll::Ready(res)
620 })
621 .await
622 }
623
624 async fn app_send_message(
628 &self,
629 eid: Eid,
630 typ: MsgType,
631 tag: Option<Tag>,
632 tag_expires: bool,
633 integrity_check: MsgIC,
634 buf: &[&[u8]],
635 cookie: Option<AppCookie>,
636 ) -> Result<Tag> {
637 let mut inner = self.inner.lock().await;
638
639 let Some(p) = inner.lookup.by_eid(eid, None) else {
640 debug!("No route for recv {}", eid);
641 return Err(Error::TxFailure);
642 };
643
644 let Some(top) = self.ports.get(p.0 as usize) else {
645 debug!("Bad port ID from lookup");
646 return Err(Error::TxFailure);
647 };
648
649 let mtu = top.mtu;
650 let mut fragmenter = inner
651 .stack
652 .start_send(
653 eid,
654 typ,
655 tag,
656 tag_expires,
657 integrity_check,
658 Some(mtu),
659 cookie,
660 )
661 .inspect_err(|e| trace!("error fragmenter {}", e))?;
662 drop(inner);
664
665 top.send_message(&mut fragmenter, buf).await
666 }
667
668 async fn app_release_tag(&self, eid: Eid, tag: Tag) {
672 let Tag::Owned(tv) = tag else { unreachable!() };
673 let mut inner = self.inner.lock().await;
674
675 if let Err(e) = inner.stack.cancel_flow(eid, tv) {
676 warn!("flow cancel failed {}", e);
677 }
678 }
679
680 pub fn req(&'r self, eid: Eid) -> RouterAsyncReqChannel<'r> {
682 RouterAsyncReqChannel::new(eid, self)
683 }
684
685 pub fn listener(&'r self, typ: MsgType) -> Result<RouterAsyncListener<'r>> {
689 let cookie = self.app_bind(typ)?;
690 Ok(RouterAsyncListener {
691 cookie,
692 router: self,
693 })
694 }
695
696 pub async fn get_eid(&self) -> Eid {
698 let inner = self.inner.lock().await;
699 inner.stack.own_eid
700 }
701
702 pub async fn set_eid(&self, eid: Eid) -> mctp::Result<()> {
704 let mut inner = self.inner.lock().await;
705 inner.stack.set_eid(eid.0)
706 }
707}
708
709pub struct RouterAsyncReqChannel<'r> {
711 eid: Eid,
712 sent_tag: Option<Tag>,
713 router: &'r Router<'r>,
714 tag_expires: bool,
715}
716
717impl<'r> RouterAsyncReqChannel<'r> {
718 fn new(eid: Eid, router: &'r Router<'r>) -> Self {
719 RouterAsyncReqChannel {
720 eid,
721 sent_tag: None,
722 tag_expires: true,
723 router,
724 }
725 }
726
727 pub fn tag_noexpire(&mut self) -> Result<()> {
731 if self.sent_tag.is_some() {
732 return Err(Error::BadArgument);
733 }
734 self.tag_expires = false;
735 Ok(())
736 }
737
738 pub async fn async_drop(self) {
743 if !self.tag_expires {
744 if let Some(tag) = self.sent_tag {
745 self.router.app_release_tag(self.eid, tag).await;
746 }
747 }
748 }
749}
750
751impl Drop for RouterAsyncReqChannel<'_> {
752 fn drop(&mut self) {
753 if !self.tag_expires && self.sent_tag.is_some() {
754 warn!("Didn't call async_drop()");
755 }
756 }
757}
758
759impl mctp::AsyncReqChannel for RouterAsyncReqChannel<'_> {
763 async fn send_vectored(
771 &mut self,
772 typ: MsgType,
773 integrity_check: MsgIC,
774 bufs: &[&[u8]],
775 ) -> Result<()> {
776 let tag = self
779 .router
780 .app_send_message(
781 self.eid,
782 typ,
783 self.sent_tag,
784 self.tag_expires,
785 integrity_check,
786 bufs,
787 None,
788 )
789 .await?;
790 debug_assert!(matches!(tag, Tag::Owned(_)));
791 self.sent_tag = Some(tag);
792 Ok(())
793 }
794
795 async fn recv<'f>(
796 &mut self,
797 buf: &'f mut [u8],
798 ) -> Result<(MsgType, MsgIC, &'f mut [u8])> {
799 let Some(Tag::Owned(tv)) = self.sent_tag else {
800 debug!("recv without send");
801 return Err(Error::BadArgument);
802 };
803 let recv_tag = Tag::Unowned(tv);
804 let (buf, eid, typ, tag, ic) = self
805 .router
806 .app_recv_message(None, Some((recv_tag, self.eid)), buf)
807 .await?;
808 debug_assert_eq!(tag, recv_tag);
809 debug_assert_eq!(eid, self.eid);
810 Ok((typ, ic, buf))
811 }
812
813 fn remote_eid(&self) -> Eid {
814 self.eid
815 }
816}
817
818pub struct RouterAsyncRespChannel<'r> {
822 eid: Eid,
823 tv: TagValue,
824 router: &'r Router<'r>,
825 typ: MsgType,
826}
827
828impl<'r> mctp::AsyncRespChannel for RouterAsyncRespChannel<'r> {
829 type ReqChannel<'a>
830 = RouterAsyncReqChannel<'r>
831 where
832 Self: 'a;
833
834 async fn send_vectored(
838 &mut self,
839 integrity_check: MsgIC,
840 bufs: &[&[u8]],
841 ) -> Result<()> {
842 let tag = Some(Tag::Unowned(self.tv));
843 self.router
844 .app_send_message(
845 self.eid,
846 self.typ,
847 tag,
848 false,
849 integrity_check,
850 bufs,
851 None,
852 )
853 .await?;
854 Ok(())
855 }
856
857 fn remote_eid(&self) -> Eid {
858 self.eid
859 }
860
861 fn req_channel(&self) -> mctp::Result<Self::ReqChannel<'_>> {
862 Ok(RouterAsyncReqChannel::new(self.eid, self.router))
863 }
864}
865
866pub struct RouterAsyncListener<'r> {
870 router: &'r Router<'r>,
871 cookie: AppCookie,
872}
873
874impl<'r> mctp::AsyncListener for RouterAsyncListener<'r> {
875 type RespChannel<'a>
877 = RouterAsyncRespChannel<'r>
878 where
879 Self: 'a;
880
881 async fn recv<'f>(
882 &mut self,
883 buf: &'f mut [u8],
884 ) -> mctp::Result<(MsgType, MsgIC, &'f mut [u8], Self::RespChannel<'_>)>
885 {
886 let (msg, eid, typ, tag, ic) = self
887 .router
888 .app_recv_message(Some(self.cookie), None, buf)
889 .await?;
890
891 let Tag::Owned(tv) = tag else {
892 debug_assert!(false, "listeners only accept owned tags");
893 return Err(Error::InternalError);
894 };
895
896 let resp = RouterAsyncRespChannel {
897 eid,
898 tv,
899 router: self.router,
900 typ,
901 };
902 Ok((typ, ic, msg, resp))
903 }
904}
905
906impl Drop for RouterAsyncListener<'_> {
907 fn drop(&mut self) {
908 if self.router.app_unbind(self.cookie).is_err() {
909 debug_assert!(false, "bad unbind");
911 }
912 }
913}