#[allow(unused)]
use crate::fmt::{debug, error, info, trace, warn};
use core::cell::RefCell;
use core::future::{poll_fn, Future};
use core::pin::pin;
use core::task::Poll;
use crate::reassemble::Reassembler;
use crate::{
AppCookie, Fragmenter, ReceiveHandle, SendOutput, Stack, MAX_MTU,
MAX_PAYLOAD,
};
use mctp::{Eid, Error, MsgIC, MsgType, Result, Tag, TagValue};
use embassy_sync::waitqueue::{MultiWakerRegistration, WakerRegistration};
use embassy_sync::zerocopy_channel::{Channel, Receiver, Sender};
use heapless::Vec;
// Capacities of the router's fixed-size tables.

/// Number of listener slots; `Router::app_bind` allows one slot per message
/// type.
const MAX_LISTENERS: usize = 20;
/// Capacity of the waker set shared by request channels waiting on responses.
const MAX_RECEIVERS: usize = 50;

/// Raw mutex flavour shared by every lock in this module.
type RawMutex = embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
/// Async mutex; its guard may be held across `.await` points.
type AsyncMutex<T> = embassy_sync::mutex::Mutex<RawMutex, T>;
/// Blocking mutex around a `RefCell`, for short synchronous critical sections.
type BlockingMutex<T> =
    embassy_sync::blocking_mutex::Mutex<RawMutex, RefCell<T>>;
/// Raw mutex used by the per-port packet channels (same type as `RawMutex`).
type PortRawMutex = RawMutex;
/// Index of a port within the `ports` slice given to `Router::new`.
///
/// Returned by `PortLookup::by_eid` and passed to `Router::inbound` to say
/// which port a packet arrived on.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PortId(pub u8);
/// Route table consulted by `Router` to choose the port for a destination EID.
pub trait PortLookup: Send {
    /// Return the port on which packets for `eid` should be sent, or `None`
    /// if there is no route.
    ///
    /// `source_port` is the port a forwarded packet arrived on (the router
    /// passes `Some(port)` when forwarding), or `None` for messages sent by
    /// local applications.
    fn by_eid(
        &mut self,
        eid: Eid,
        source_port: Option<PortId>,
    ) -> Option<PortId>;
}
struct PktBuf {
data: [u8; MAX_MTU],
len: usize,
dest: Eid,
}
impl PktBuf {
const fn new() -> Self {
Self {
data: [0u8; MAX_MTU],
len: 0,
dest: Eid(0),
}
}
fn set(&mut self, data: &[u8]) -> Result<()> {
let hdr = Reassembler::header(data);
debug_assert!(hdr.is_ok());
let hdr = hdr?;
let dst = self.data.get_mut(..data.len()).ok_or(Error::NoSpace)?;
dst.copy_from_slice(data);
self.len = data.len();
self.dest = Eid(hdr.dest_endpoint_id());
Ok(())
}
}
impl core::ops::Deref for PktBuf {
type Target = [u8];
fn deref(&self) -> &[u8] {
&self.data[..self.len]
}
}
/// Router-facing half of a port.
///
/// Packets are written into the port's fixed-size queue and drained by the
/// matching `PortBottom`.
pub struct PortTop<'a> {
    /// Producer side of the outbound packet queue.
    packets: AsyncMutex<Sender<'a, PortRawMutex, PktBuf>>,
    /// Scratch buffer for concatenating vectored payloads before fragmenting.
    message: AsyncMutex<Vec<u8, MAX_PAYLOAD>>,
    /// Largest packet this port accepts (checked against `MAX_MTU` when the
    /// port is built).
    mtu: usize,
}

impl PortTop<'_> {
    /// Queue an already-formed packet for transmission on this port.
    ///
    /// Best-effort: if the queue is full the packet is dropped rather than
    /// waiting for space.
    ///
    /// # Errors
    /// - `Error::NoSpace` if `pkt` is larger than this port's MTU.
    /// - `Error::TxFailure` if the queue has no free slot.
    /// - Any error from parsing the packet header.
    async fn forward_packet(&self, pkt: &[u8]) -> Result<()> {
        debug_assert!(Reassembler::header(pkt).is_ok());
        // The size check doesn't need the queue, so do it before locking.
        if pkt.len() > self.mtu {
            debug!("Forward packet too large");
            return Err(Error::NoSpace);
        }
        let mut sender = self.packets.lock().await;
        let slot = sender.try_send().ok_or_else(|| {
            debug!("Dropped forward packet");
            Error::TxFailure
        })?;
        // If `set` fails the slot is never committed, so nothing malformed
        // reaches the bottom half (previously this panicked via `unwrap`).
        slot.set(pkt)?;
        sender.send_done();
        Ok(())
    }

    /// Fragment a message with `fragmenter` and queue every packet on this
    /// port. The payload is the concatenation of `pkt`. Waits for free queue
    /// slots as needed.
    ///
    /// Returns the message tag from the fragmenter.
    ///
    /// # Errors
    /// - `Error::NoSpace` if a multi-part payload exceeds `MAX_PAYLOAD`.
    /// - The fragmenter's error if packetising fails. Packets already queued
    ///   before the failure are not withdrawn.
    async fn send_message(
        &self,
        fragmenter: &mut Fragmenter,
        pkt: &[&[u8]],
    ) -> Result<Tag> {
        trace!("send_message");
        let mut msg;
        let payload = if pkt.len() == 1 {
            // Single buffer: fragment it directly without copying.
            pkt[0]
        } else {
            // Vectored payload: join into the shared scratch buffer. The
            // guard stays held until the whole message has been fragmented.
            msg = self.message.lock().await;
            msg.clear();
            for p in pkt {
                msg.extend_from_slice(p).map_err(|_| {
                    debug!("Message too large");
                    Error::NoSpace
                })?;
            }
            &msg
        };
        loop {
            // The queue lock is taken per packet, so other users of this
            // port (e.g. forwarding) may interleave packets between fragments.
            let mut sender = self.packets.lock().await;
            let qpkt = sender.send().await;
            qpkt.len = 0;
            qpkt.dest = fragmenter.dest();
            let r = fragmenter.fragment(payload, &mut qpkt.data);
            match r {
                SendOutput::Packet(p) => {
                    qpkt.len = p.len();
                    sender.send_done();
                    if fragmenter.is_done() {
                        break Ok(fragmenter.tag());
                    }
                }
                SendOutput::Error { err, .. } => {
                    debug!("Error packetising");
                    // Do NOT call `send_done()`. The slot holds no valid
                    // packet, and committing it would hand a zero-length
                    // packet to the bottom half. An uncommitted slot is
                    // simply reused by the next `send()`.
                    break Err(err);
                }
                SendOutput::Complete { .. } => unreachable!(),
            }
        }
    }
}
pub struct PortBottom<'a> {
packets: Receiver<'a, PortRawMutex, PktBuf>,
}
impl PortBottom<'_> {
pub async fn outbound(&mut self) -> (&[u8], Eid) {
if self.packets.len() > 1 {
trace!("packets avail {}", self.packets.len());
}
let pkt = self.packets.receive().await;
(pkt, pkt.dest)
}
pub fn try_outbound(&mut self) -> Option<(&[u8], Eid)> {
trace!("packets avail {} try", self.packets.len());
self.packets.try_receive().map(|pkt| (&**pkt, pkt.dest))
}
pub fn outbound_done(&mut self) {
self.packets.receive_done()
}
}
/// Backing memory for one port's outbound packet queue.
///
/// Holds `FORWARD_QUEUE` packet buffers (4 by default). Pass it to
/// `PortBuilder::new`.
pub struct PortStorage<const FORWARD_QUEUE: usize = 4> {
    packets: [PktBuf; FORWARD_QUEUE],
}

impl<const FORWARD_QUEUE: usize> PortStorage<FORWARD_QUEUE> {
    /// Create storage with every packet buffer empty.
    ///
    /// This is a `const fn`, so the storage can be built in const contexts
    /// such as a `static` initializer. This matters because each buffer is
    /// `MAX_MTU` bytes and may be too large for the stack.
    pub const fn new() -> Self {
        Self {
            // `PktBuf` is not `Copy`, so an inline const block is needed for
            // the array repeat expression.
            packets: [const { PktBuf::new() }; FORWARD_QUEUE],
        }
    }
}

impl<const FORWARD_QUEUE: usize> Default for PortStorage<FORWARD_QUEUE> {
    fn default() -> Self {
        Self::new()
    }
}
/// Creates the `PortTop`/`PortBottom` pair for a port from caller-provided
/// `PortStorage`.
pub struct PortBuilder<'a> {
    /// Channel over the storage's packet buffers.
    packets: Channel<'a, PortRawMutex, PktBuf>,
}

impl<'a> PortBuilder<'a> {
    /// Wrap `storage` in a packet channel; the storage stays borrowed for
    /// `'a`.
    pub fn new<const FORWARD_QUEUE: usize>(
        storage: &'a mut PortStorage<FORWARD_QUEUE>,
    ) -> Self {
        let slots = &mut storage.packets[..];
        Self {
            packets: Channel::new(slots),
        }
    }

    /// Split the channel into the two halves of a port with the given `mtu`.
    ///
    /// # Errors
    /// `Error::BadArgument` if `mtu` is larger than `MAX_MTU`.
    pub fn build(
        &mut self,
        mtu: usize,
    ) -> Result<(PortTop<'_>, PortBottom<'_>)> {
        if mtu > MAX_MTU {
            debug!("port mtu {} > MAX_MTU {}", mtu, MAX_MTU);
            return Err(Error::BadArgument);
        }
        let (sender, receiver) = self.packets.split();
        let top = PortTop {
            packets: AsyncMutex::new(sender),
            message: AsyncMutex::new(Vec::new()),
            mtu,
        };
        let bottom = PortBottom { packets: receiver };
        Ok((top, bottom))
    }
}
/// Routes MCTP packets between the local `Stack` and a set of ports.
///
/// Inbound packets for the local EID go to the stack. Completed messages are
/// delivered to listeners or waiting request channels. Other packets are
/// forwarded to the port selected by the `PortLookup`. Application messages
/// are fragmented onto the port the lookup returns for their destination.
pub struct Router<'r> {
    /// Stack, response wakers and route lookup, behind one async mutex.
    inner: AsyncMutex<RouterInner<'r>>,
    /// Ports, indexed by `PortId`.
    ports: &'r [PortTop<'r>],
    /// Listener table. A slot's index is that listener's `AppCookie`; the
    /// slot stores the bound message type and the listener's waker.
    app_listeners:
        BlockingMutex<[Option<(MsgType, WakerRegistration)>; MAX_LISTENERS]>,
}

/// Router state guarded by `Router`'s async mutex.
pub struct RouterInner<'r> {
    /// The local MCTP stack.
    stack: Stack,
    /// Wakers of request channels waiting for a response.
    app_receive_wakers: MultiWakerRegistration<MAX_RECEIVERS>,
    /// Route lookup for forwarded and locally originated messages.
    lookup: &'r mut dyn PortLookup,
}

impl<'r> Router<'r> {
    /// Create a router over `stack`, sending via `ports` (indexed by
    /// `PortId`) and choosing routes with `lookup`.
    pub fn new(
        stack: Stack,
        ports: &'r [PortTop<'r>],
        lookup: &'r mut dyn PortLookup,
    ) -> Self {
        let inner = RouterInner {
            stack,
            app_receive_wakers: MultiWakerRegistration::new(),
            lookup,
        };
        Self {
            inner: AsyncMutex::new(inner),
            app_listeners: BlockingMutex::new(RefCell::new(
                [const { None }; MAX_LISTENERS],
            )),
            ports,
        }
    }

    /// Pass the current time `now_millis` to the stack.
    ///
    /// Returns the `next` value reported by `Stack::update`. This is
    /// presumably the time until `update_time` should next be called;
    /// NOTE(review): confirm against `Stack::update`.
    pub async fn update_time(&self, now_millis: u64) -> Result<u64> {
        let mut inner = self.inner.lock().await;
        let (next, expired) = inner.stack.update(now_millis)?;
        if expired {
            // Something timed out: wake request channels so they re-poll.
            inner.app_receive_wakers.wake();
        }
        Ok(next)
    }

    /// Handle one packet `pkt` received on `port`.
    ///
    /// Packets addressed to the local stack are passed to it, and a
    /// completed message is handed to a listener or request channel. Any
    /// other packet is forwarded to the port chosen by the lookup. Forwarding
    /// is best-effort and its errors are ignored.
    ///
    /// Returns the packet's source EID, or `None` if the header is invalid.
    pub async fn inbound(&self, pkt: &[u8], port: PortId) -> Option<Eid> {
        let mut inner = self.inner.lock().await;
        let Ok(header) = Reassembler::header(pkt) else {
            return None;
        };
        let ret_src = Some(Eid(header.source_endpoint_id()));
        if inner.stack.is_local_dest(pkt) {
            match inner.stack.receive(pkt) {
                Ok(Some((msg, handle))) => {
                    let typ = msg.typ;
                    let tag = msg.tag;
                    // incoming_local() takes the lock again itself.
                    drop(inner);
                    self.incoming_local(tag, typ, handle).await;
                    return ret_src;
                }
                Ok(None) => {
                    // Packet accepted; message not yet complete.
                    return ret_src;
                }
                Err(e) => {
                    debug!("Dropped local recv packet. {}", e);
                    return ret_src;
                }
            }
        }
        let dest_eid = Eid(header.dest_endpoint_id());
        let Some(p) = inner.lookup.by_eid(dest_eid, Some(port)) else {
            debug!("No route for recv {}", dest_eid);
            return ret_src;
        };
        // Don't hold the stack lock while waiting on the port queue.
        drop(inner);
        let Some(top) = self.ports.get(p.0 as usize) else {
            debug!("Bad port ID from lookup");
            return ret_src;
        };
        let _ = top.forward_packet(pkt).await;
        ret_src
    }

    /// Dispatch a completed local message.
    ///
    /// Messages whose tag has the owner flag go to listeners. All others are
    /// treated as responses for request channels.
    async fn incoming_local(
        &self,
        tag: Tag,
        typ: MsgType,
        handle: ReceiveHandle,
    ) {
        trace!("incoming local, type {}", typ.0);
        if tag.is_owner() {
            self.incoming_listener(typ, handle).await
        } else {
            self.incoming_response(tag, handle).await
        }
    }

    /// Hand `handle` to the listener bound to `typ`.
    ///
    /// The message is tagged with the listener's cookie, returned to the
    /// stack and the listener is woken. If nothing is bound to `typ`, the
    /// message is released with `finished_receive`.
    async fn incoming_listener(&self, typ: MsgType, handle: ReceiveHandle) {
        let mut inner = self.inner.lock().await;
        let mut handle = Some(handle);
        self.app_listeners.lock(|a| {
            let mut a = a.borrow_mut();
            for (cookie, entry) in a.iter_mut().enumerate() {
                if let Some((t, waker)) = entry {
                    trace!("entry. {} vs {}", t.0, typ.0);
                    if *t == typ {
                        let handle = handle.take().unwrap();
                        inner
                            .stack
                            .set_cookie(&handle, Some(AppCookie(cookie)));
                        inner.stack.return_handle(handle);
                        waker.wake();
                        trace!("listener match");
                        break;
                    }
                }
            }
        });
        if let Some(handle) = handle.take() {
            trace!("listener no match");
            inner.stack.finished_receive(handle);
        }
    }

    /// Return a response `handle` to the stack and wake every waiting
    /// request channel. Each channel re-checks for its own tag and EID.
    async fn incoming_response(&self, _tag: Tag, handle: ReceiveHandle) {
        let mut inner = self.inner.lock().await;
        inner.stack.return_handle(handle);
        inner.app_receive_wakers.wake();
    }

    /// Bind a listener slot for message type `typ`.
    ///
    /// # Errors
    /// - `Error::AddrInUse` if `typ` is already bound.
    /// - `Error::NoSpace` if all `MAX_LISTENERS` slots are taken.
    fn app_bind(&self, typ: MsgType) -> Result<AppCookie> {
        self.app_listeners.lock(|a| {
            let mut a = a.borrow_mut();
            for bind in a.iter() {
                if bind.as_ref().is_some_and(|(t, _)| *t == typ) {
                    return Err(Error::AddrInUse);
                }
            }
            if let Some((i, bind)) =
                a.iter_mut().enumerate().find(|(_i, bind)| bind.is_none())
            {
                *bind = Some((typ, WakerRegistration::new()));
                return Ok(AppCookie(i));
            }
            Err(Error::NoSpace)
        })
    }

    /// Free the listener slot identified by `cookie`.
    ///
    /// # Errors
    /// `Error::BadArgument` if `cookie` is out of range or not bound.
    fn app_unbind(&self, cookie: AppCookie) -> Result<()> {
        self.app_listeners.lock(|a| {
            let mut a = a.borrow_mut();
            let bind = a.get_mut(cookie.0).ok_or(Error::BadArgument)?;
            if bind.is_none() {
                return Err(Error::BadArgument);
            }
            *bind = None;
            Ok(())
        })
    }

    /// Wait for a message and copy its payload into `buf`.
    ///
    /// Exactly one selector must be `Some`:
    /// - `cookie`: a listener's cookie; any message delivered to it matches.
    /// - `tag_eid`: a request channel's expected response tag and peer EID.
    ///
    /// Returns the filled prefix of `buf` plus the source EID, type, tag and
    /// integrity-check flag. If the payload doesn't fit, the message is still
    /// consumed and `Error::NoSpace` is returned.
    async fn app_recv_message<'f>(
        &self,
        cookie: Option<AppCookie>,
        tag_eid: Option<(Tag, Eid)>,
        buf: &'f mut [u8],
    ) -> Result<(&'f mut [u8], Eid, MsgType, Tag, MsgIC)> {
        let mut buf = Some(buf);
        poll_fn(|cx| {
            // A fresh lock future is created on every poll; if it is
            // pending, the mutex wakes this task once it is unlocked.
            let l = self.inner.lock();
            let l = pin!(l);
            let mut inner = match l.poll(cx) {
                Poll::Ready(i) => i,
                Poll::Pending => return Poll::Pending,
            };
            trace!("poll recv message");
            let handle = match (cookie, tag_eid) {
                (Some(cookie), None) => {
                    inner.stack.get_deferred_bycookie(&[cookie])
                }
                (None, Some((tag, eid))) => inner.stack.get_deferred(eid, tag),
                _ => unreachable!(),
            };
            let Some(handle) = handle else {
                // Nothing yet: register a waker in the place the matching
                // delivery path (listener or response) will wake.
                if let Some(cookie) = cookie {
                    trace!("listener, cookie index {}", cookie.0);
                    self.app_listeners.lock(|a| {
                        let mut a = a.borrow_mut();
                        let Some(bind) = a.get_mut(cookie.0) else {
                            debug_assert!(false, "recv bad cookie");
                            return;
                        };
                        let Some((_typ, waker)) = bind else {
                            debug_assert!(false, "recv no listener");
                            return;
                        };
                        waker.register(cx.waker());
                    });
                } else {
                    trace!("other recv");
                    inner.app_receive_wakers.register(cx.waker());
                }
                trace!("pending");
                return Poll::Pending;
            };
            trace!("got handle");
            let msg = inner.stack.fetch_message(&handle);
            // Only reached once: this poll returns Ready below.
            let buf = buf.take().unwrap();
            let res = if msg.payload.len() > buf.len() {
                trace!("no space");
                Err(Error::NoSpace)
            } else {
                trace!("good len {}", msg.payload.len());
                let buf = &mut buf[..msg.payload.len()];
                buf.copy_from_slice(msg.payload);
                Ok((buf, msg.source, msg.typ, msg.tag, msg.ic))
            };
            inner.stack.finished_receive(handle);
            Poll::Ready(res)
        })
        .await
    }

    /// Fragment and queue a message to `eid` on the port the lookup chooses.
    ///
    /// `tag`, `tag_expires`, `integrity_check` and `cookie` are passed
    /// through to `Stack::start_send`. The port's MTU limits fragment size.
    /// Returns the tag used for the message.
    ///
    /// # Errors
    /// - `Error::TxFailure` if there is no route or the lookup returns an
    ///   invalid port.
    /// - Errors from `Stack::start_send` or from sending on the port.
    async fn app_send_message(
        &self,
        eid: Eid,
        typ: MsgType,
        tag: Option<Tag>,
        tag_expires: bool,
        integrity_check: MsgIC,
        buf: &[&[u8]],
        cookie: Option<AppCookie>,
    ) -> Result<Tag> {
        let mut inner = self.inner.lock().await;
        let Some(p) = inner.lookup.by_eid(eid, None) else {
            // This is the send path (previously mislabelled "recv").
            debug!("No route for send {}", eid);
            return Err(Error::TxFailure);
        };
        let Some(top) = self.ports.get(p.0 as usize) else {
            debug!("Bad port ID from lookup");
            return Err(Error::TxFailure);
        };
        let mtu = top.mtu;
        let mut fragmenter = inner
            .stack
            .start_send(
                eid,
                typ,
                tag,
                tag_expires,
                integrity_check,
                Some(mtu),
                cookie,
            )
            .inspect_err(|e| trace!("error fragmenter {}", e))?;
        // Release the stack lock before waiting for queue space on the port.
        drop(inner);
        top.send_message(&mut fragmenter, buf).await
    }

    /// Cancel the stack's flow for owned `tag` with `eid`. A failure is only
    /// logged.
    ///
    /// # Panics
    /// If `tag` is not `Tag::Owned`.
    async fn app_release_tag(&self, eid: Eid, tag: Tag) {
        let Tag::Owned(tv) = tag else { unreachable!() };
        let mut inner = self.inner.lock().await;
        if let Err(e) = inner.stack.cancel_flow(eid, tv) {
            warn!("flow cancel failed {}", e);
        }
    }

    /// Create a request channel to `eid`.
    pub fn req(&'r self, eid: Eid) -> RouterAsyncReqChannel<'r> {
        RouterAsyncReqChannel::new(eid, self)
    }

    /// Bind a listener for incoming requests of type `typ`.
    ///
    /// # Errors
    /// `Error::AddrInUse` if `typ` is already bound, or `Error::NoSpace` if
    /// the listener table is full.
    pub fn listener(&'r self, typ: MsgType) -> Result<RouterAsyncListener<'r>> {
        let cookie = self.app_bind(typ)?;
        Ok(RouterAsyncListener {
            cookie,
            router: self,
        })
    }

    /// The stack's own EID.
    pub async fn get_eid(&self) -> Eid {
        let inner = self.inner.lock().await;
        inner.stack.own_eid
    }

    /// Set the stack's own EID.
    pub async fn set_eid(&self, eid: Eid) -> mctp::Result<()> {
        let mut inner = self.inner.lock().await;
        inner.stack.set_eid(eid.0)
    }
}
/// Request channel to one remote EID, obtained from `Router::req`.
///
/// The tag returned by the first send is stored. It is passed to later sends
/// and used to match responses.
pub struct RouterAsyncReqChannel<'r> {
    /// Remote endpoint.
    eid: Eid,
    /// Tag of the first successful send, if any.
    sent_tag: Option<Tag>,
    router: &'r Router<'r>,
    /// When `false`, the tag must be released with `async_drop`.
    tag_expires: bool,
}

impl<'r> RouterAsyncReqChannel<'r> {
    fn new(eid: Eid, router: &'r Router<'r>) -> Self {
        RouterAsyncReqChannel {
            eid,
            sent_tag: None,
            tag_expires: true,
            router,
        }
    }

    /// Ask for the tag not to expire.
    ///
    /// Must be called before the first send. The channel should then be
    /// finished with [`async_drop`](Self::async_drop) so the tag is released.
    ///
    /// # Errors
    /// `Error::BadArgument` if a message has already been sent.
    pub fn tag_noexpire(&mut self) -> Result<()> {
        if self.sent_tag.is_some() {
            return Err(Error::BadArgument);
        }
        self.tag_expires = false;
        Ok(())
    }

    /// Release a non-expiring tag (if one was sent), then drop the channel.
    pub async fn async_drop(mut self) {
        if !self.tag_expires {
            if let Some(tag) = self.sent_tag {
                self.router.app_release_tag(self.eid, tag).await;
                // Mark the tag as released so `Drop` (which runs when `self`
                // goes out of scope here) doesn't warn spuriously. It is
                // cleared only after the await, so a cancelled release still
                // triggers the warning.
                self.sent_tag = None;
            }
        }
    }
}

impl Drop for RouterAsyncReqChannel<'_> {
    /// A non-expiring tag can only be released asynchronously, so a plain
    /// drop can only warn that it was not released.
    fn drop(&mut self) {
        if !self.tag_expires && self.sent_tag.is_some() {
            warn!("Didn't call async_drop()");
        }
    }
}
impl mctp::AsyncReqChannel for RouterAsyncReqChannel<'_> {
    /// Send a message whose payload is the concatenation of `bufs`.
    ///
    /// The tag returned by the router is stored and reused by later sends
    /// and by `recv`.
    async fn send_vectored(
        &mut self,
        typ: MsgType,
        integrity_check: MsgIC,
        bufs: &[&[u8]],
    ) -> Result<()> {
        let (eid, prev_tag, expires) =
            (self.eid, self.sent_tag, self.tag_expires);
        let used = self
            .router
            .app_send_message(
                eid,
                typ,
                prev_tag,
                expires,
                integrity_check,
                bufs,
                None,
            )
            .await?;
        debug_assert!(matches!(used, Tag::Owned(_)));
        self.sent_tag = Some(used);
        Ok(())
    }

    /// Wait for a response matching this channel's tag and EID, copying the
    /// payload into `buf`.
    ///
    /// # Errors
    /// `Error::BadArgument` if nothing has been sent yet.
    async fn recv<'f>(
        &mut self,
        buf: &'f mut [u8],
    ) -> Result<(MsgType, MsgIC, &'f mut [u8])> {
        let tv = match self.sent_tag {
            Some(Tag::Owned(tv)) => tv,
            _ => {
                debug!("recv without send");
                return Err(Error::BadArgument);
            }
        };
        // Responses carry the request's tag value as an unowned tag.
        let expected = Tag::Unowned(tv);
        let (payload, src, typ, got, ic) = self
            .router
            .app_recv_message(None, Some((expected, self.eid)), buf)
            .await?;
        debug_assert_eq!(got, expected);
        debug_assert_eq!(src, self.eid);
        Ok((typ, ic, payload))
    }

    fn remote_eid(&self) -> Eid {
        self.eid
    }
}
/// Reply channel for a request received by a `RouterAsyncListener`.
///
/// Replies go to the requester's EID with the request's message type and
/// `Tag::Unowned` of the request's tag value.
pub struct RouterAsyncRespChannel<'r> {
    /// Requester EID.
    eid: Eid,
    /// Tag value of the request being answered.
    tv: TagValue,
    router: &'r Router<'r>,
    /// Message type of the request.
    typ: MsgType,
}

impl<'r> mctp::AsyncRespChannel for RouterAsyncRespChannel<'r> {
    type ReqChannel<'a>
        = RouterAsyncReqChannel<'r>
    where
        Self: 'a;

    /// Send a reply whose payload is the concatenation of `bufs`.
    async fn send_vectored(
        &mut self,
        integrity_check: MsgIC,
        bufs: &[&[u8]],
    ) -> Result<()> {
        let reply_tag = Tag::Unowned(self.tv);
        self.router
            .app_send_message(
                self.eid,
                self.typ,
                Some(reply_tag),
                false,
                integrity_check,
                bufs,
                None,
            )
            .await
            .map(|_tag| ())
    }

    fn remote_eid(&self) -> Eid {
        self.eid
    }

    /// Open a new request channel back to the requester.
    fn req_channel(&self) -> mctp::Result<Self::ReqChannel<'_>> {
        let chan = RouterAsyncReqChannel::new(self.eid, self.router);
        Ok(chan)
    }
}
/// Listener for incoming requests of a single message type, created by
/// `Router::listener`. Dropping it unbinds the type.
pub struct RouterAsyncListener<'r> {
    router: &'r Router<'r>,
    /// Index of this listener's slot in the router's listener table.
    cookie: AppCookie,
}

impl<'r> mctp::AsyncListener for RouterAsyncListener<'r> {
    type RespChannel<'a>
        = RouterAsyncRespChannel<'r>
    where
        Self: 'a;

    /// Wait for a request, copy its payload into `buf`, and return a channel
    /// for replying to the sender.
    async fn recv<'f>(
        &mut self,
        buf: &'f mut [u8],
    ) -> mctp::Result<(MsgType, MsgIC, &'f mut [u8], Self::RespChannel<'_>)>
    {
        let received = self
            .router
            .app_recv_message(Some(self.cookie), None, buf)
            .await?;
        let (payload, src, typ, tag, ic) = received;
        let tv = match tag {
            Tag::Owned(tv) => tv,
            _ => {
                debug_assert!(false, "listeners only accept owned tags");
                return Err(Error::InternalError);
            }
        };
        let reply = RouterAsyncRespChannel {
            eid: src,
            tv,
            router: self.router,
            typ,
        };
        Ok((typ, ic, payload, reply))
    }
}

impl Drop for RouterAsyncListener<'_> {
    /// Free this listener's slot so the message type can be bound again.
    fn drop(&mut self) {
        let unbound = self.router.app_unbind(self.cookie);
        debug_assert!(unbound.is_ok(), "bad unbind");
    }
}