1use core::{
16 any::TypeId,
17 cell::UnsafeCell,
18 marker::PhantomData,
19 ops::Deref,
20 pin::Pin,
21 ptr::{NonNull, addr_of},
22 task::{Context, Poll, Waker},
23};
24
25use bbq2::{
26 prod_cons::framed::{FramedConsumer, FramedGrantR},
27 traits::bbqhdl::BbqHandle,
28};
29use cordyceps::list::Links;
30use postcard::ser_flavors;
31use serde::{Deserialize, Serialize};
32
33use crate::{
34 HeaderSeq, Key, ProtocolError,
35 nash::NameHash,
36 net_stack::NetStackHandle,
37 wire_frames::{self, BorrowedFrame, CommonHeader, de_frame},
38};
39
40use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
41
42#[repr(C)]
43pub struct Socket<Q, T, N>
44where
45 Q: BbqHandle,
46 T: Serialize,
47 N: NetStackHandle,
48{
49 hdr: SocketHeader,
51 pub(crate) net: N::Target,
52 inner: UnsafeCell<QueueBox<Q>>,
53 mtu: u16,
54 _pd: PhantomData<fn() -> T>,
55}
56
57pub struct SocketHdl<'a, Q, T, N>
58where
59 Q: BbqHandle,
60 T: Serialize,
61 N: NetStackHandle,
62{
63 pub(crate) ptr: NonNull<Socket<Q, T, N>>,
64 _lt: PhantomData<Pin<&'a mut Socket<Q, T, N>>>,
65 port: u8,
66}
67
68pub struct Recv<'a, 'b, Q, T, N>
69where
70 Q: BbqHandle,
71 T: Serialize,
72 N: NetStackHandle,
73{
74 hdl: &'a mut SocketHdl<'b, Q, T, N>,
75}
76
77pub struct ResponseGrant<Q: BbqHandle, T> {
78 pub hdr: HeaderSeq,
79 inner: ResponseGrantInner<Q, T>,
80}
81
82struct QueueBox<Q: BbqHandle> {
83 q: Q,
84 waker: Option<Waker>,
85}
86
87enum ResponseGrantInner<Q: BbqHandle, T> {
88 Ok {
89 grant: FramedGrantR<Q, u16>,
90 offset: usize,
91 deser_erased: PhantomData<fn() -> T>,
92 },
93 Err(ProtocolError),
94}
95
96impl<Q, T, N> Socket<Q, T, N>
101where
102 Q: BbqHandle,
103 T: Serialize,
104 N: NetStackHandle,
105{
106 pub const fn new(
107 net: N::Target,
108 key: Key,
109 attrs: Attributes,
110 sto: Q,
111 mtu: u16,
112 name: Option<&str>,
113 ) -> Self {
114 Self {
115 hdr: SocketHeader {
116 links: Links::new(),
117 vtable: const { &Self::vtable() },
118 port: 0,
119 attrs,
120 key,
121 nash: if let Some(n) = name {
122 Some(NameHash::new(n))
123 } else {
124 None
125 },
126 },
127 inner: UnsafeCell::new(QueueBox {
128 q: sto,
129 waker: None,
130 }),
131 net,
132 _pd: PhantomData,
133 mtu,
134 }
135 }
136
137 pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
138 let stack = self.net.clone();
139 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
140 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
141 let port = unsafe { stack.attach_socket(ptr_erase) };
142 SocketHdl {
143 ptr: ptr_self,
144 _lt: PhantomData,
145 port,
146 }
147 }
148
149 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
150 let stack = self.net.clone();
151 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
152 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
153 unsafe { stack.attach_broadcast_socket(ptr_erase) };
154 SocketHdl {
155 ptr: ptr_self,
156 _lt: PhantomData,
157 port: 255,
158 }
159 }
160
161 const fn vtable() -> SocketVTable {
162 SocketVTable {
163 recv_owned: Some(Self::recv_owned),
164 recv_bor: Some(Self::recv_bor),
165 recv_raw: Self::recv_raw,
166 recv_err: Some(Self::recv_err),
167 }
168 }
169
170 pub fn stack(&self) -> N::Target {
171 self.net.clone()
172 }
173
174 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
175 let this: NonNull<Self> = this.cast();
176 let this: &Self = unsafe { this.as_ref() };
177 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
178 let qref = qbox.q.bbq_ref();
179 let prod = qref.framed_producer();
180
181 let Ok(mut wgr) = prod.grant(this.mtu) else {
184 return;
185 };
186
187 let ser = ser_flavors::Slice::new(&mut wgr);
188
189 let chdr = CommonHeader {
190 src: hdr.src,
191 dst: hdr.dst,
192 seq_no: hdr.seq_no,
193 kind: hdr.kind,
194 ttl: hdr.ttl,
195 };
196
197 if let Ok(used) = wire_frames::encode_frame_err(ser, &chdr, err) {
198 let len = used.len() as u16;
199 wgr.commit(len);
200 if let Some(wake) = qbox.waker.take() {
201 wake.wake();
202 }
203 }
204 }
205
206 fn recv_owned(
207 this: NonNull<()>,
208 that: NonNull<()>,
209 hdr: HeaderSeq,
210 _ty: &TypeId,
213 ) -> Result<(), SocketSendError> {
214 let that: NonNull<T> = that.cast();
215 let that: &T = unsafe { that.as_ref() };
216 let this: NonNull<Self> = this.cast();
217 let this: &Self = unsafe { this.as_ref() };
218 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
219 let qref = qbox.q.bbq_ref();
220 let prod = qref.framed_producer();
221
222 let Ok(mut wgr) = prod.grant(this.mtu) else {
223 return Err(SocketSendError::NoSpace);
224 };
225 let ser = ser_flavors::Slice::new(&mut wgr);
226
227 let chdr = CommonHeader {
228 src: hdr.src,
229 dst: hdr.dst,
230 seq_no: hdr.seq_no,
231 kind: hdr.kind,
232 ttl: hdr.ttl,
233 };
234
235 let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
236 return Err(SocketSendError::NoSpace);
237 };
238
239 let len = used.len() as u16;
240 wgr.commit(len);
241
242 if let Some(wake) = qbox.waker.take() {
243 wake.wake();
244 }
245
246 Ok(())
247 }
248
249 fn recv_bor(
250 this: NonNull<()>,
251 that: NonNull<()>,
252 hdr: HeaderSeq,
253 serfn: fn(NonNull<()>, HeaderSeq, &mut [u8]) -> Result<usize, SocketSendError>,
254 ) -> Result<(), SocketSendError> {
255 let this: NonNull<Self> = this.cast();
256 let this: &Self = unsafe { this.as_ref() };
257 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
258 let qref = qbox.q.bbq_ref();
259 let prod = qref.framed_producer();
260
261 let Ok(mut wgr) = prod.grant(this.mtu) else {
262 return Err(SocketSendError::NoSpace);
263 };
264
265 let used = serfn(that, hdr, &mut wgr)?;
266 let len = used as u16;
267 wgr.commit(len);
268
269 if let Some(wake) = qbox.waker.take() {
270 wake.wake();
271 }
272
273 Ok(())
274 }
275
276 fn recv_raw(
277 this: NonNull<()>,
278 that: &[u8],
279 _hdr: HeaderSeq,
280 hdr_raw: &[u8],
281 ) -> Result<(), SocketSendError> {
282 let this: NonNull<Self> = this.cast();
283 let this: &Self = unsafe { this.as_ref() };
284 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
285 let qref = qbox.q.bbq_ref();
286 let prod = qref.framed_producer();
287
288 let Ok(needed) = u16::try_from(that.len() + hdr_raw.len()) else {
289 return Err(SocketSendError::NoSpace);
290 };
291
292 let Ok(mut wgr) = prod.grant(needed) else {
293 return Err(SocketSendError::NoSpace);
294 };
295 let (hdr, body) = wgr.split_at_mut(hdr_raw.len());
296 hdr.copy_from_slice(hdr_raw);
297 body.copy_from_slice(that);
298 wgr.commit(needed);
299
300 if let Some(wake) = qbox.waker.take() {
301 wake.wake();
302 }
303
304 Ok(())
305 }
306}
307
308impl<'a, Q, T, N> SocketHdl<'a, Q, T, N>
311where
312 Q: BbqHandle,
313 T: Serialize,
314 N: NetStackHandle,
315{
316 pub fn port(&self) -> u8 {
317 self.port
318 }
319
320 pub fn stack(&self) -> N::Target {
321 unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
322 }
323
324 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, Q, T, N> {
325 Recv { hdl: self }
326 }
327}
328
329impl<Q, T, N> Drop for Socket<Q, T, N>
330where
331 Q: BbqHandle,
332 T: Serialize,
333 N: NetStackHandle,
334{
335 fn drop(&mut self) {
336 unsafe {
337 let this = NonNull::from(&self.hdr);
338 self.net.detach_socket(this);
339 }
340 }
341}
342
343unsafe impl<Q, T, N> Send for SocketHdl<'_, Q, T, N>
344where
345 Q: BbqHandle,
346 T: Serialize,
347 N: NetStackHandle,
348{
349}
350
351unsafe impl<Q, T, N> Sync for SocketHdl<'_, Q, T, N>
352where
353 Q: BbqHandle,
354 T: Serialize,
355 N: NetStackHandle,
356{
357}
358
359impl<'a, Q, T, N> Future for Recv<'a, '_, Q, T, N>
362where
363 Q: BbqHandle,
364 T: Serialize,
365 N: NetStackHandle,
366{
367 type Output = ResponseGrant<Q, T>;
368
369 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
370 let net: N::Target = self.hdl.stack();
371 let f = || -> Option<ResponseGrant<Q, T>> {
372 let this_ref: &Socket<Q, T, N> = unsafe { self.hdl.ptr.as_ref() };
373 let qbox: &mut QueueBox<Q> = unsafe { &mut *this_ref.inner.get() };
374 let cons: FramedConsumer<Q, u16> = qbox.q.framed_consumer();
375
376 if let Ok(resp) = cons.read() {
377 let sli: &[u8] = resp.deref();
378
379 if let Some(frame) = de_frame(sli) {
380 let BorrowedFrame {
381 hdr,
382 body,
383 hdr_raw: _,
384 } = frame;
385 match body {
386 Ok(body) => {
387 let sli: &[u8] = body;
388 let offset = (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
406 return Some(ResponseGrant {
407 hdr,
408 inner: ResponseGrantInner::Ok {
409 grant: resp,
410 offset,
411 deser_erased: PhantomData,
412 },
413 });
414 }
415 Err(err) => {
416 resp.release();
417 return Some(ResponseGrant {
418 hdr,
419 inner: ResponseGrantInner::Err(err),
420 });
421 }
422 }
423 }
424 }
425
426 let new_wake = cx.waker();
427 if let Some(w) = qbox.waker.take() {
428 if !w.will_wake(new_wake) {
429 w.wake();
430 }
431 }
432 qbox.waker = Some(new_wake.clone());
435 None
436 };
437 let res = unsafe { net.with_lock(f) };
438 if let Some(t) = res {
439 Poll::Ready(t)
440 } else {
441 Poll::Pending
442 }
443 }
444}
445
446unsafe impl<Q, T, N> Sync for Recv<'_, '_, Q, T, N>
447where
448 Q: BbqHandle,
449 T: Serialize,
450 N: NetStackHandle,
451{
452}
453
454impl<Q: BbqHandle, T> ResponseGrant<Q, T> {
457 pub fn try_access<'de, 'me: 'de>(&'me self) -> Option<Response<T>>
462 where
463 T: Deserialize<'de>,
464 {
465 Some(match &self.inner {
466 ResponseGrantInner::Ok {
467 grant,
468 deser_erased: _,
469 offset,
470 } => {
471 let t = postcard::from_bytes::<T>(grant.get(*offset..)?).ok()?;
473 Response::Ok(HeaderMessage {
474 hdr: self.hdr.clone(),
475 t,
476 })
477 }
478 ResponseGrantInner::Err(protocol_error) => Response::Err(HeaderMessage {
479 hdr: self.hdr.clone(),
480 t: *protocol_error,
481 }),
482 })
483 }
484}
485
486impl<Q: BbqHandle, T> Drop for ResponseGrant<Q, T> {
487 fn drop(&mut self) {
488 let old = core::mem::replace(
489 &mut self.inner,
490 ResponseGrantInner::Err(ProtocolError(u16::MAX)),
491 );
492 match old {
493 ResponseGrantInner::Ok { grant, .. } => {
494 grant.release();
495 }
496 ResponseGrantInner::Err(_) => {}
497 }
498 }
499}