1use core::{
16 any::TypeId,
17 cell::UnsafeCell,
18 marker::PhantomData,
19 ops::Deref,
20 pin::Pin,
21 ptr::{NonNull, addr_of},
22 task::{Context, Poll, Waker},
23};
24
25use bbq2::{
26 prod_cons::framed::{FramedConsumer, FramedGrantR},
27 traits::bbqhdl::BbqHandle,
28};
29use cordyceps::list::Links;
30use mutex::ScopedRawMutex;
31use postcard::ser_flavors;
32use serde::{Deserialize, Serialize};
33
34use crate::{
35 HeaderSeq, Key, NetStack, ProtocolError,
36 interface_manager::{
37 BorrowedFrame, InterfaceManager,
38 wire_frames::{self, CommonHeader, de_frame},
39 },
40 nash::NameHash,
41};
42
43use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
44
45struct QueueBox<Q: BbqHandle> {
46 q: Q,
47 waker: Option<Waker>,
48}
49
50#[repr(C)]
51pub struct Socket<Q, T, R, M>
52where
53 Q: BbqHandle,
54 T: Serialize + Clone,
55 R: ScopedRawMutex + 'static,
56 M: InterfaceManager + 'static,
57{
58 hdr: SocketHeader,
60 pub(crate) net: &'static NetStack<R, M>,
61 inner: UnsafeCell<QueueBox<Q>>,
62 mtu: u16,
63 _pd: PhantomData<fn() -> T>,
64}
65
66pub struct SocketHdl<'a, Q, T, R, M>
67where
68 Q: BbqHandle,
69 T: Serialize + Clone,
70 R: ScopedRawMutex + 'static,
71 M: InterfaceManager + 'static,
72{
73 pub(crate) ptr: NonNull<Socket<Q, T, R, M>>,
74 _lt: PhantomData<Pin<&'a mut Socket<Q, T, R, M>>>,
75 port: u8,
76}
77
78pub struct Recv<'a, 'b, Q, T, R, M>
79where
80 Q: BbqHandle,
81 T: Serialize + Clone,
82 R: ScopedRawMutex + 'static,
83 M: InterfaceManager + 'static,
84{
85 hdl: &'a mut SocketHdl<'b, Q, T, R, M>,
86}
87
88impl<Q, T, R, M> Socket<Q, T, R, M>
93where
94 Q: BbqHandle,
95 T: Serialize + Clone,
96 R: ScopedRawMutex + 'static,
97 M: InterfaceManager + 'static,
98{
99 pub const fn new(
100 net: &'static NetStack<R, M>,
101 key: Key,
102 attrs: Attributes,
103 sto: Q,
104 mtu: u16,
105 name: Option<&str>,
106 ) -> Self {
107 Self {
108 hdr: SocketHeader {
109 links: Links::new(),
110 vtable: const { &Self::vtable() },
111 port: 0,
112 attrs,
113 key,
114 nash: if let Some(n) = name {
115 Some(NameHash::new(n))
116 } else {
117 None
118 },
119 },
120 inner: UnsafeCell::new(QueueBox {
121 q: sto,
122 waker: None,
123 }),
124 net,
125 _pd: PhantomData,
126 mtu,
127 }
128 }
129
130 pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, R, M> {
131 let stack = self.net;
132 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
133 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
134 let port = unsafe { stack.attach_socket(ptr_erase) };
135 SocketHdl {
136 ptr: ptr_self,
137 _lt: PhantomData,
138 port,
139 }
140 }
141
142 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, R, M> {
143 let stack = self.net;
144 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
145 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
146 unsafe { stack.attach_broadcast_socket(ptr_erase) };
147 SocketHdl {
148 ptr: ptr_self,
149 _lt: PhantomData,
150 port: 255,
151 }
152 }
153
154 const fn vtable() -> SocketVTable {
155 SocketVTable {
156 recv_owned: Some(Self::recv_owned),
157 recv_bor: Some(Self::recv_bor),
158 recv_raw: Self::recv_raw,
159 recv_err: Some(Self::recv_err),
160 }
161 }
162
163 pub fn stack(&self) -> &'static NetStack<R, M> {
164 self.net
165 }
166
167 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
168 let this: NonNull<Self> = this.cast();
169 let this: &Self = unsafe { this.as_ref() };
170 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
171 let qref = qbox.q.bbq_ref();
172 let prod = qref.framed_producer();
173
174 let Ok(mut wgr) = prod.grant(this.mtu) else {
177 return;
178 };
179
180 let ser = ser_flavors::Slice::new(&mut wgr);
181
182 let chdr = CommonHeader {
183 src: hdr.src.as_u32(),
184 dst: hdr.dst.as_u32(),
185 seq_no: hdr.seq_no,
186 kind: hdr.kind.0,
187 ttl: hdr.ttl,
188 };
189
190 if let Ok(used) = wire_frames::encode_frame_err(ser, &chdr, err) {
191 let len = used.len() as u16;
192 wgr.commit(len);
193 if let Some(wake) = qbox.waker.take() {
194 wake.wake();
195 }
196 }
197 }
198
199 fn recv_owned(
200 this: NonNull<()>,
201 that: NonNull<()>,
202 hdr: HeaderSeq,
203 _ty: &TypeId,
206 ) -> Result<(), SocketSendError> {
207 let that: NonNull<T> = that.cast();
208 let that: &T = unsafe { that.as_ref() };
209 let this: NonNull<Self> = this.cast();
210 let this: &Self = unsafe { this.as_ref() };
211 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
212 let qref = qbox.q.bbq_ref();
213 let prod = qref.framed_producer();
214
215 let Ok(mut wgr) = prod.grant(this.mtu) else {
216 return Err(SocketSendError::NoSpace);
217 };
218 let ser = ser_flavors::Slice::new(&mut wgr);
219
220 let chdr = CommonHeader {
221 src: hdr.src.as_u32(),
222 dst: hdr.dst.as_u32(),
223 seq_no: hdr.seq_no,
224 kind: hdr.kind.0,
225 ttl: hdr.ttl,
226 };
227
228 let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
229 return Err(SocketSendError::NoSpace);
230 };
231
232 let len = used.len() as u16;
233 wgr.commit(len);
234
235 if let Some(wake) = qbox.waker.take() {
236 wake.wake();
237 }
238
239 Ok(())
240 }
241
242 fn recv_bor(
243 this: NonNull<()>,
244 that: NonNull<()>,
245 hdr: HeaderSeq,
246 ) -> Result<(), SocketSendError> {
247 let this: NonNull<Self> = this.cast();
248 let this: &Self = unsafe { this.as_ref() };
249 let that: NonNull<T> = that.cast();
250 let that: &T = unsafe { that.as_ref() };
251 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
252 let qref = qbox.q.bbq_ref();
253 let prod = qref.framed_producer();
254
255 let Ok(mut wgr) = prod.grant(this.mtu) else {
256 return Err(SocketSendError::NoSpace);
257 };
258 let ser = ser_flavors::Slice::new(&mut wgr);
259
260 let chdr = CommonHeader {
261 src: hdr.src.as_u32(),
262 dst: hdr.dst.as_u32(),
263 seq_no: hdr.seq_no,
264 kind: hdr.kind.0,
265 ttl: hdr.ttl,
266 };
267
268 let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
269 return Err(SocketSendError::NoSpace);
270 };
271
272 let len = used.len() as u16;
273 wgr.commit(len);
274
275 if let Some(wake) = qbox.waker.take() {
276 wake.wake();
277 }
278
279 Ok(())
280 }
281
282 fn recv_raw(
283 this: NonNull<()>,
284 that: &[u8],
285 _hdr: HeaderSeq,
286 hdr_raw: &[u8],
287 ) -> Result<(), SocketSendError> {
288 let this: NonNull<Self> = this.cast();
289 let this: &Self = unsafe { this.as_ref() };
290 let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
291 let qref = qbox.q.bbq_ref();
292 let prod = qref.framed_producer();
293
294 let Ok(needed) = u16::try_from(that.len() + hdr_raw.len()) else {
295 return Err(SocketSendError::NoSpace);
296 };
297
298 let Ok(mut wgr) = prod.grant(needed) else {
299 return Err(SocketSendError::NoSpace);
300 };
301 let (hdr, body) = wgr.split_at_mut(hdr_raw.len());
302 hdr.copy_from_slice(hdr_raw);
303 body.copy_from_slice(that);
304 wgr.commit(needed);
305
306 if let Some(wake) = qbox.waker.take() {
307 wake.wake();
308 }
309
310 Ok(())
311 }
312}
313
314impl<'a, Q, T, R, M> SocketHdl<'a, Q, T, R, M>
318where
319 Q: BbqHandle,
320 T: Serialize + Clone,
321 R: ScopedRawMutex + 'static,
322 M: InterfaceManager + 'static,
323{
324 pub fn port(&self) -> u8 {
325 self.port
326 }
327
328 pub fn stack(&self) -> &'static NetStack<R, M> {
329 unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
330 }
331
332 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, Q, T, R, M> {
333 Recv { hdl: self }
334 }
335}
336
337impl<Q, T, R, M> Drop for Socket<Q, T, R, M>
338where
339 Q: BbqHandle,
340 T: Serialize + Clone,
341 R: ScopedRawMutex + 'static,
342 M: InterfaceManager + 'static,
343{
344 fn drop(&mut self) {
345 unsafe {
346 let this = NonNull::from(&self.hdr);
347 self.net.detach_socket(this);
348 }
349 }
350}
351
352unsafe impl<Q, T, R, M> Send for SocketHdl<'_, Q, T, R, M>
353where
354 Q: BbqHandle,
355 T: Serialize + Clone,
356 R: ScopedRawMutex + 'static,
357 M: InterfaceManager + 'static,
358{
359}
360
361unsafe impl<Q, T, R, M> Sync for SocketHdl<'_, Q, T, R, M>
362where
363 Q: BbqHandle,
364 T: Serialize + Clone,
365 R: ScopedRawMutex + 'static,
366 M: InterfaceManager + 'static,
367{
368}
369
370enum ResponseGrantInner<Q: BbqHandle, T> {
373 Ok {
374 grant: FramedGrantR<Q, u16>,
375 offset: usize,
376 deser_erased: PhantomData<fn() -> T>,
377 },
378 Err(ProtocolError),
379}
380
381pub struct ResponseGrant<Q: BbqHandle, T> {
382 pub hdr: HeaderSeq,
383 inner: ResponseGrantInner<Q, T>,
384}
385
386impl<Q: BbqHandle, T> Drop for ResponseGrant<Q, T> {
387 fn drop(&mut self) {
388 let old = core::mem::replace(
389 &mut self.inner,
390 ResponseGrantInner::Err(ProtocolError(u16::MAX)),
391 );
392 match old {
393 ResponseGrantInner::Ok { grant, .. } => {
394 grant.release();
395 }
396 ResponseGrantInner::Err(_) => {}
397 }
398 }
399}
400
401impl<Q: BbqHandle, T> ResponseGrant<Q, T> {
402 pub fn try_access<'de, 'me: 'de>(&'me self) -> Option<Response<T>>
407 where
408 T: Deserialize<'de>,
409 {
410 Some(match &self.inner {
411 ResponseGrantInner::Ok {
412 grant,
413 deser_erased: _,
414 offset,
415 } => {
416 let t = postcard::from_bytes::<T>(grant.get(*offset..)?).ok()?;
418 Response::Ok(HeaderMessage {
419 hdr: self.hdr.clone(),
420 t,
421 })
422 }
423 ResponseGrantInner::Err(protocol_error) => Response::Err(HeaderMessage {
424 hdr: self.hdr.clone(),
425 t: *protocol_error,
426 }),
427 })
428 }
429}
430
431impl<'a, Q, T, R, M> Future for Recv<'a, '_, Q, T, R, M>
432where
433 Q: BbqHandle,
434 T: Serialize + Clone,
435 R: ScopedRawMutex + 'static,
436 M: InterfaceManager + 'static,
437{
438 type Output = ResponseGrant<Q, T>;
439
440 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
441 let net: &'static NetStack<R, M> = self.hdl.stack();
442 let f = || -> Option<ResponseGrant<Q, T>> {
443 let this_ref: &Socket<Q, T, R, M> = unsafe { self.hdl.ptr.as_ref() };
444 let qbox: &mut QueueBox<Q> = unsafe { &mut *this_ref.inner.get() };
445 let cons: FramedConsumer<Q, u16> = qbox.q.framed_consumer();
446
447 if let Ok(resp) = cons.read() {
448 let sli: &[u8] = resp.deref();
449
450 if let Some(frame) = de_frame(sli) {
451 let BorrowedFrame {
452 hdr,
453 body,
454 hdr_raw: _,
455 } = frame;
456 match body {
457 Ok(body) => {
458 let sli: &[u8] = body;
459 let offset = (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
477 return Some(ResponseGrant {
478 hdr,
479 inner: ResponseGrantInner::Ok {
480 grant: resp,
481 offset,
482 deser_erased: PhantomData,
483 },
484 });
485 }
486 Err(err) => {
487 resp.release();
488 return Some(ResponseGrant {
489 hdr,
490 inner: ResponseGrantInner::Err(err),
491 });
492 }
493 }
494 }
495 }
496
497 let new_wake = cx.waker();
498 if let Some(w) = qbox.waker.take() {
499 if !w.will_wake(new_wake) {
500 w.wake();
501 }
502 }
503 qbox.waker = Some(new_wake.clone());
506 None
507 };
508 let res = unsafe { net.with_lock(f) };
509 if let Some(t) = res {
510 Poll::Ready(t)
511 } else {
512 Poll::Pending
513 }
514 }
515}
516
517unsafe impl<Q, T, R, M> Sync for Recv<'_, '_, Q, T, R, M>
518where
519 Q: BbqHandle,
520 T: Serialize + Clone,
521 R: ScopedRawMutex + 'static,
522 M: InterfaceManager + 'static,
523{
524}