use core::{
any::TypeId,
cell::UnsafeCell,
marker::PhantomData,
ops::Deref,
pin::Pin,
ptr::{NonNull, addr_of},
task::{Context, Poll, Waker},
};
use bbq2::{
prod_cons::framed::{FramedConsumer, FramedGrantR},
traits::bbqhdl::BbqHandle,
};
use cordyceps::list::Links;
use postcard::{
Serializer,
ser_flavors::{self, Flavor, Slice},
};
use serde::{Deserialize, Serialize};
use crate::{
HeaderSeq, Key, ProtocolError,
nash::NameHash,
net_stack::NetStackHandle,
socket::{
Attributes, BorSerFn, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable,
},
wire_frames::{self, BorrowedFrame, MAX_HDR_ENCODED_SIZE, de_frame, encode_frame_hdr},
};
/// A socket whose receive queue is a bbq2 framed queue: deliveries are stored
/// as encoded frames and only deserialized when the reader calls
/// `ResponseGrant::try_access`.
///
/// Layout: `#[repr(C)]` with `hdr` as the FIRST field. `attach` and
/// `attach_broadcast` hand the net stack a `NonNull<Self>` cast to
/// `NonNull<SocketHeader>`, which is only valid because the header sits at
/// offset 0. Do not reorder the fields.
#[repr(C)]
pub struct Socket<Q, T, N>
where
Q: BbqHandle,
T: Serialize,
N: NetStackHandle,
{
/// Header registered with the net stack (intrusive links, vtable, port,
/// attributes, key, optional name hash). Must stay the first field.
hdr: SocketHeader,
/// Handle of the net stack; cloned by `stack()` and used to attach/detach.
pub(crate) net: N::Target,
/// Receive queue plus the waker of a pending `Recv`. It is mutated through
/// the `UnsafeCell` from the vtable callbacks and from `Recv::poll`.
inner: UnsafeCell<QueueBox<Q>>,
/// Size of the write grant requested for each typed, borrowed or error frame.
mtu: u16,
/// Associates the message type `T` with the socket without storing a `T`.
_pd: PhantomData<fn() -> T>,
}
/// Handle to a socket attached with `Socket::attach` or
/// `Socket::attach_broadcast`. It keeps the pinned socket mutably borrowed
/// for `'a`.
pub struct SocketHdl<'a, Q: BbqHandle, T: Serialize, N: NetStackHandle> {
    /// Raw pointer to the attached socket, taken from the pinned `&'a mut` borrow.
    pub(crate) ptr: NonNull<Socket<Q, T, N>>,
    /// Carries the `'a` pinned mutable borrow the pointer was derived from.
    _lt: PhantomData<Pin<&'a mut Socket<Q, T, N>>>,
    /// Port returned at attach time (255 for broadcast attachment).
    port: u8,
}
/// Future returned by `SocketHdl::recv`. It resolves to the next frame
/// available in the socket's receive queue.
pub struct Recv<'a, 'b, Q: BbqHandle, T: Serialize, N: NetStackHandle> {
    /// Exclusive borrow of the handle for the lifetime of the future.
    hdl: &'a mut SocketHdl<'b, Q, T, N>,
}
/// A received frame: its header plus either a held read grant over the
/// still-encoded body, or the protocol error carried by an error frame.
pub struct ResponseGrant<Q, T>
where
    Q: BbqHandle,
{
    /// Header of the received frame.
    pub hdr: HeaderSeq,
    /// Grant over the encoded body, or the decoded error.
    inner: ResponseGrantInner<Q, T>,
}
/// Receive queue of a socket together with the waker of the task currently
/// waiting in `Recv::poll`, if any.
struct QueueBox<Q>
where
    Q: BbqHandle,
{
    /// bbq2 queue handle holding encoded frames.
    q: Q,
    /// Waker stored by the last pending poll; taken and woken on enqueue.
    waker: Option<Waker>,
}
/// Payload of a `ResponseGrant`.
enum ResponseGrantInner<Q, T>
where
    Q: BbqHandle,
{
    /// Data frame: the read grant stays held; the serialized body starts at
    /// `offset` bytes into it and is deserialized as `T` on access.
    Ok {
        grant: FramedGrantR<Q, u16>,
        offset: usize,
        deser_erased: PhantomData<fn() -> T>,
    },
    /// Error frame: the grant was already released; only the error remains.
    Err(ProtocolError),
}
impl<Q, T, N> Socket<Q, T, N>
where
Q: BbqHandle,
T: Serialize,
N: NetStackHandle,
{
pub const fn new(
net: N::Target,
key: Key,
attrs: Attributes,
sto: Q,
mtu: u16,
name: Option<&str>,
) -> Self {
Self {
hdr: SocketHeader {
links: Links::new(),
vtable: const { &Self::vtable() },
port: 0,
attrs,
key,
nash: if let Some(n) = name {
Some(NameHash::new(n))
} else {
None
},
},
inner: UnsafeCell::new(QueueBox {
q: sto,
waker: None,
}),
net,
_pd: PhantomData,
mtu,
}
}
pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
let stack = self.net.clone();
let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
let port = unsafe { stack.attach_socket(ptr_erase) };
SocketHdl {
ptr: ptr_self,
_lt: PhantomData,
port,
}
}
pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
let stack = self.net.clone();
let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
unsafe { stack.attach_broadcast_socket(ptr_erase) };
SocketHdl {
ptr: ptr_self,
_lt: PhantomData,
port: 255,
}
}
const fn vtable() -> SocketVTable {
SocketVTable {
recv_owned: Some(Self::recv_owned),
recv_bor: Some(Self::recv_bor),
recv_raw: Self::recv_raw,
recv_err: Some(Self::recv_err),
}
}
pub fn stack(&self) -> N::Target {
self.net.clone()
}
fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
let this: NonNull<Self> = this.cast();
let this: &Self = unsafe { this.as_ref() };
let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
let qref = qbox.q.bbq_ref();
let prod = qref.framed_producer();
let Ok(mut wgr) = prod.grant(this.mtu) else {
return;
};
let ser = ser_flavors::Slice::new(&mut wgr);
if let Ok(used) = wire_frames::encode_frame_err(ser, &hdr, err) {
let len = used.len() as u16;
wgr.commit(len);
if let Some(wake) = qbox.waker.take() {
wake.wake();
}
}
}
fn recv_owned(
this: NonNull<()>,
that: NonNull<()>,
hdr: HeaderSeq,
_ty: &TypeId,
) -> Result<(), SocketSendError> {
let that: NonNull<T> = that.cast();
let that: &T = unsafe { that.as_ref() };
let this: NonNull<Self> = this.cast();
let this: &Self = unsafe { this.as_ref() };
let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
let qref = qbox.q.bbq_ref();
let prod = qref.framed_producer();
let Ok(mut wgr) = prod.grant(this.mtu) else {
return Err(SocketSendError::NoSpace);
};
let ser = ser_flavors::Slice::new(&mut wgr);
let Ok(used) = wire_frames::encode_frame_ty(ser, &hdr, that) else {
return Err(SocketSendError::NoSpace);
};
let len = used.len() as u16;
wgr.commit(len);
if let Some(wake) = qbox.waker.take() {
wake.wake();
}
Ok(())
}
fn recv_bor(
this: NonNull<()>,
that: NonNull<()>,
hdr: HeaderSeq,
serfn: BorSerFn,
) -> Result<(), SocketSendError> {
let this: NonNull<Self> = this.cast();
let this: &Self = unsafe { this.as_ref() };
let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
let qref = qbox.q.bbq_ref();
let prod = qref.framed_producer();
let Ok(mut wgr) = prod.grant(this.mtu) else {
return Err(SocketSendError::NoSpace);
};
let used = serfn(that, hdr, &mut wgr)?;
let len = used as u16;
wgr.commit(len);
if let Some(wake) = qbox.waker.take() {
wake.wake();
}
Ok(())
}
fn recv_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
let this: NonNull<Self> = this.cast();
let this: &Self = unsafe { this.as_ref() };
let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
let qref = qbox.q.bbq_ref();
let prod = qref.framed_producer();
let mut buf = [0u8; MAX_HDR_ENCODED_SIZE];
let mut ser = Serializer {
output: Slice::new(&mut buf),
};
let Ok(()) = encode_frame_hdr(&mut ser, &hdr) else {
log::error!("Encoding of HeaderSeq should never fail. This is a bug.");
return Err(SocketSendError::WhatTheHell);
};
let Ok(hdr_used) = ser.output.finalize() else {
unreachable!("Slice finalization should never error");
};
let Ok(needed) = u16::try_from(that.len() + hdr_used.len()) else {
return Err(SocketSendError::NoSpace);
};
let Ok(mut wgr) = prod.grant(needed) else {
return Err(SocketSendError::NoSpace);
};
let (hdr, body) = wgr.split_at_mut(hdr_used.len());
hdr.copy_from_slice(hdr_used);
body.copy_from_slice(that);
wgr.commit(needed);
if let Some(wake) = qbox.waker.take() {
wake.wake();
}
Ok(())
}
}
impl<'a, Q, T, N> SocketHdl<'a, Q, T, N>
where
Q: BbqHandle,
T: Serialize,
N: NetStackHandle,
{
pub fn port(&self) -> u8 {
self.port
}
pub fn stack(&self) -> N::Target {
unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
}
pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, Q, T, N> {
Recv { hdl: self }
}
}
impl<Q, T, N> Drop for Socket<Q, T, N>
where
Q: BbqHandle,
T: Serialize,
N: NetStackHandle,
{
fn drop(&mut self) {
unsafe {
let this = NonNull::from(&self.hdr);
self.net.detach_socket(this);
}
}
}
// SAFETY: NOTE(review): a `SocketHdl` is only a raw pointer to a pinned socket
// plus a port. The handle's own queue access (`Recv::poll`) runs inside
// `net.with_lock`, and `stack()` only clones `net`. No `Send`/`Sync` bounds are
// placed on `Q` or `N::Target`, so soundness presumably relies on those handle
// types being thread-safe. Confirm.
unsafe impl<Q: BbqHandle, T: Serialize, N: NetStackHandle> Send for SocketHdl<'_, Q, T, N> {}
// SAFETY: NOTE(review): through `&SocketHdl`, only `port()` (plain read) and
// `stack()` (clones `net` via a raw field pointer) are reachable here; queue
// access needs `&mut` via `recv`. This presumes `N::Target` is safe to clone
// from several threads at once. Confirm.
unsafe impl<Q: BbqHandle, T: Serialize, N: NetStackHandle> Sync for SocketHdl<'_, Q, T, N> {}
impl<'a, Q, T, N> Future for Recv<'a, '_, Q, T, N>
where
    Q: BbqHandle,
    T: Serialize,
    N: NetStackHandle,
{
    type Output = ResponseGrant<Q, T>;

    /// Polls the socket's queue for the next decodable frame, under the net
    /// stack's lock (`with_lock`).
    ///
    /// * Data frame: resolves to a `ResponseGrant` that keeps the read grant
    ///   and records where the body starts inside it.
    /// * Error frame: releases the grant and resolves to the carried error.
    /// * Frame `de_frame` cannot decode: released and skipped, and reading
    ///   continues. Previously such a frame was neither delivered nor released,
    ///   so the poll returned `Pending` while the queue was still non-empty.
    /// * Empty queue: stores this task's waker, first waking a different
    ///   previously stored waker so it is not silently lost, and returns
    ///   `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let net: N::Target = self.hdl.stack();
        let f = || -> Option<ResponseGrant<Q, T>> {
            // SAFETY: the handle's pointer refers to the pinned socket it borrows.
            let this_ref: &Socket<Q, T, N> = unsafe { self.hdl.ptr.as_ref() };
            // SAFETY: we run inside `net.with_lock`; NOTE(review): presumably the
            // vtable writers are serialized by the same lock.
            let qbox: &mut QueueBox<Q> = unsafe { &mut *this_ref.inner.get() };
            let cons: FramedConsumer<Q, u16> = qbox.q.framed_consumer();

            while let Ok(resp) = cons.read() {
                // Decode into values that do not borrow `resp`, so the grant can be
                // moved into the result or released below.
                let decoded = match de_frame(resp.deref()) {
                    Some(BorrowedFrame { hdr, body: Ok(body) }) => {
                        let offset = (body.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
                        Some((hdr, Ok(offset)))
                    }
                    Some(BorrowedFrame { hdr, body: Err(err) }) => Some((hdr, Err(err))),
                    None => None,
                };
                let Some((hdr, body)) = decoded else {
                    // Undecodable frame: release it explicitly so it cannot block the
                    // frames queued behind it, then try the next one.
                    resp.release();
                    continue;
                };
                let inner = match body {
                    Ok(offset) => ResponseGrantInner::Ok {
                        grant: resp,
                        offset,
                        deser_erased: PhantomData,
                    },
                    Err(err) => {
                        resp.release();
                        ResponseGrantInner::Err(err)
                    }
                };
                return Some(ResponseGrant { hdr, inner });
            }

            let new_wake = cx.waker();
            if let Some(w) = qbox.waker.take()
                && !w.will_wake(new_wake)
            {
                w.wake();
            }
            qbox.waker = Some(new_wake.clone());
            None
        };
        // SAFETY: NOTE(review): `with_lock` is unsafe in the stack's API; its
        // contract is not visible here. Confirm the closure upholds it.
        let res = unsafe { net.with_lock(f) };
        match res {
            Some(t) => Poll::Ready(t),
            None => Poll::Pending,
        }
    }
}
// SAFETY: NOTE(review): `Recv` only wraps `&mut SocketHdl`, and its single
// operation (`poll`) needs `Pin<&mut Self>`, so a shared `&Recv` exposes
// nothing. Confirm no other shared-access API exists elsewhere.
unsafe impl<Q: BbqHandle, T: Serialize, N: NetStackHandle> Sync for Recv<'_, '_, Q, T, N> {}
impl<Q: BbqHandle, T> ResponseGrant<Q, T> {
    /// Views the received frame as a `Response<T>`, borrowing from the grant.
    ///
    /// * Error frame: returns `Response::Err` with the stored `ProtocolError`.
    /// * Data frame: deserializes the body (from the stored offset to the end
    ///   of the grant) with `postcard`. Returns `None` if the offset is out of
    ///   range or deserialization fails.
    ///
    /// The header is cloned into the result; the grant stays held until `self`
    /// is dropped.
    pub fn try_access<'de, 'me: 'de>(&'me self) -> Option<Response<T>>
    where
        T: Deserialize<'de>,
    {
        let response = match &self.inner {
            ResponseGrantInner::Err(protocol_error) => {
                let hdr = self.hdr.clone();
                Response::Err(HeaderMessage {
                    hdr,
                    t: *protocol_error,
                })
            }
            ResponseGrantInner::Ok { grant, offset, .. } => {
                let body = grant.get(*offset..)?;
                let t: T = postcard::from_bytes(body).ok()?;
                let hdr = self.hdr.clone();
                Response::Ok(HeaderMessage { hdr, t })
            }
        };
        Some(response)
    }
}
impl<Q: BbqHandle, T> Drop for ResponseGrant<Q, T> {
    /// Calls `release()` on the held read grant, if this is a data frame.
    ///
    /// The inner value is swapped for an inert placeholder error so that the
    /// grant can be moved out of `&mut self` and released by value.
    fn drop(&mut self) {
        let placeholder = ResponseGrantInner::Err(ProtocolError(u16::MAX));
        let previous = core::mem::replace(&mut self.inner, placeholder);
        if let ResponseGrantInner::Ok { grant, .. } = previous {
            grant.release();
        }
    }
}