ergot_base/socket/
borrow.rs

1//! "Borrow" sockets
2//!
3//! Borrow sockets use a `bbq2` queue to store the serialized form of messages.
4//!
5//! This allows for sending and receiving borrowed types like `&str` or `&[u8]`,
6//! or messages that contain borrowed types. This is achieved by serializing
7//! messages into the bbq2 ring buffer when inserting into the socket, and
8//! deserializing when removing from the socket.
9//!
10//! Although you can use borrowed sockets for types that are fully owned, e.g.
11//! `T: 'static`, you should prefer the [`owned`](crate::socket::owned) socket
12//! variants when possible, as they store messages more efficiently and may be
13//! able to fully skip a ser/de round trip when sending messages locally.
14
15use core::{
16    any::TypeId,
17    cell::UnsafeCell,
18    marker::PhantomData,
19    ops::Deref,
20    pin::Pin,
21    ptr::{NonNull, addr_of},
22    task::{Context, Poll, Waker},
23};
24
25use bbq2::{
26    prod_cons::framed::{FramedConsumer, FramedGrantR},
27    traits::bbqhdl::BbqHandle,
28};
29use cordyceps::list::Links;
30use mutex::ScopedRawMutex;
31use postcard::ser_flavors;
32use serde::{Deserialize, Serialize};
33
34use crate::{
35    HeaderSeq, Key, NetStack, ProtocolError,
36    interface_manager::{
37        BorrowedFrame, InterfaceManager,
38        wire_frames::{self, CommonHeader, de_frame},
39    },
40    nash::NameHash,
41};
42
43use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
44
45struct QueueBox<Q: BbqHandle> {
46    q: Q,
47    waker: Option<Waker>,
48}
49
50#[repr(C)]
51pub struct Socket<Q, T, R, M>
52where
53    Q: BbqHandle,
54    T: Serialize + Clone,
55    R: ScopedRawMutex + 'static,
56    M: InterfaceManager + 'static,
57{
58    // LOAD BEARING: must be first
59    hdr: SocketHeader,
60    pub(crate) net: &'static NetStack<R, M>,
61    inner: UnsafeCell<QueueBox<Q>>,
62    mtu: u16,
63    _pd: PhantomData<fn() -> T>,
64}
65
66pub struct SocketHdl<'a, Q, T, R, M>
67where
68    Q: BbqHandle,
69    T: Serialize + Clone,
70    R: ScopedRawMutex + 'static,
71    M: InterfaceManager + 'static,
72{
73    pub(crate) ptr: NonNull<Socket<Q, T, R, M>>,
74    _lt: PhantomData<Pin<&'a mut Socket<Q, T, R, M>>>,
75    port: u8,
76}
77
78pub struct Recv<'a, 'b, Q, T, R, M>
79where
80    Q: BbqHandle,
81    T: Serialize + Clone,
82    R: ScopedRawMutex + 'static,
83    M: InterfaceManager + 'static,
84{
85    hdl: &'a mut SocketHdl<'b, Q, T, R, M>,
86}
87
88// ---- impls ----
89
90// impl Socket
91
92impl<Q, T, R, M> Socket<Q, T, R, M>
93where
94    Q: BbqHandle,
95    T: Serialize + Clone,
96    R: ScopedRawMutex + 'static,
97    M: InterfaceManager + 'static,
98{
99    pub const fn new(
100        net: &'static NetStack<R, M>,
101        key: Key,
102        attrs: Attributes,
103        sto: Q,
104        mtu: u16,
105        name: Option<&str>,
106    ) -> Self {
107        Self {
108            hdr: SocketHeader {
109                links: Links::new(),
110                vtable: const { &Self::vtable() },
111                port: 0,
112                attrs,
113                key,
114                nash: if let Some(n) = name {
115                    Some(NameHash::new(n))
116                } else {
117                    None
118                },
119            },
120            inner: UnsafeCell::new(QueueBox {
121                q: sto,
122                waker: None,
123            }),
124            net,
125            _pd: PhantomData,
126            mtu,
127        }
128    }
129
130    pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, R, M> {
131        let stack = self.net;
132        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
133        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
134        let port = unsafe { stack.attach_socket(ptr_erase) };
135        SocketHdl {
136            ptr: ptr_self,
137            _lt: PhantomData,
138            port,
139        }
140    }
141
142    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, R, M> {
143        let stack = self.net;
144        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
145        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
146        unsafe { stack.attach_broadcast_socket(ptr_erase) };
147        SocketHdl {
148            ptr: ptr_self,
149            _lt: PhantomData,
150            port: 255,
151        }
152    }
153
154    const fn vtable() -> SocketVTable {
155        SocketVTable {
156            recv_owned: Some(Self::recv_owned),
157            recv_bor: Some(Self::recv_bor),
158            recv_raw: Self::recv_raw,
159            recv_err: Some(Self::recv_err),
160        }
161    }
162
163    pub fn stack(&self) -> &'static NetStack<R, M> {
164        self.net
165    }
166
167    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
168        let this: NonNull<Self> = this.cast();
169        let this: &Self = unsafe { this.as_ref() };
170        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
171        let qref = qbox.q.bbq_ref();
172        let prod = qref.framed_producer();
173
174        // TODO: we could probably use a smaller grant here than the MTU,
175        // allowing more grants to succeed.
176        let Ok(mut wgr) = prod.grant(this.mtu) else {
177            return;
178        };
179
180        let ser = ser_flavors::Slice::new(&mut wgr);
181
182        let chdr = CommonHeader {
183            src: hdr.src.as_u32(),
184            dst: hdr.dst.as_u32(),
185            seq_no: hdr.seq_no,
186            kind: hdr.kind.0,
187            ttl: hdr.ttl,
188        };
189
190        if let Ok(used) = wire_frames::encode_frame_err(ser, &chdr, err) {
191            let len = used.len() as u16;
192            wgr.commit(len);
193            if let Some(wake) = qbox.waker.take() {
194                wake.wake();
195            }
196        }
197    }
198
199    fn recv_owned(
200        this: NonNull<()>,
201        that: NonNull<()>,
202        hdr: HeaderSeq,
203        // We can't use TypeId here because mismatched lifetimes have different
204        // type ids!
205        _ty: &TypeId,
206    ) -> Result<(), SocketSendError> {
207        let that: NonNull<T> = that.cast();
208        let that: &T = unsafe { that.as_ref() };
209        let this: NonNull<Self> = this.cast();
210        let this: &Self = unsafe { this.as_ref() };
211        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
212        let qref = qbox.q.bbq_ref();
213        let prod = qref.framed_producer();
214
215        let Ok(mut wgr) = prod.grant(this.mtu) else {
216            return Err(SocketSendError::NoSpace);
217        };
218        let ser = ser_flavors::Slice::new(&mut wgr);
219
220        let chdr = CommonHeader {
221            src: hdr.src.as_u32(),
222            dst: hdr.dst.as_u32(),
223            seq_no: hdr.seq_no,
224            kind: hdr.kind.0,
225            ttl: hdr.ttl,
226        };
227
228        let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
229            return Err(SocketSendError::NoSpace);
230        };
231
232        let len = used.len() as u16;
233        wgr.commit(len);
234
235        if let Some(wake) = qbox.waker.take() {
236            wake.wake();
237        }
238
239        Ok(())
240    }
241
242    fn recv_bor(
243        this: NonNull<()>,
244        that: NonNull<()>,
245        hdr: HeaderSeq,
246    ) -> Result<(), SocketSendError> {
247        let this: NonNull<Self> = this.cast();
248        let this: &Self = unsafe { this.as_ref() };
249        let that: NonNull<T> = that.cast();
250        let that: &T = unsafe { that.as_ref() };
251        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
252        let qref = qbox.q.bbq_ref();
253        let prod = qref.framed_producer();
254
255        let Ok(mut wgr) = prod.grant(this.mtu) else {
256            return Err(SocketSendError::NoSpace);
257        };
258        let ser = ser_flavors::Slice::new(&mut wgr);
259
260        let chdr = CommonHeader {
261            src: hdr.src.as_u32(),
262            dst: hdr.dst.as_u32(),
263            seq_no: hdr.seq_no,
264            kind: hdr.kind.0,
265            ttl: hdr.ttl,
266        };
267
268        let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
269            return Err(SocketSendError::NoSpace);
270        };
271
272        let len = used.len() as u16;
273        wgr.commit(len);
274
275        if let Some(wake) = qbox.waker.take() {
276            wake.wake();
277        }
278
279        Ok(())
280    }
281
282    fn recv_raw(
283        this: NonNull<()>,
284        that: &[u8],
285        _hdr: HeaderSeq,
286        hdr_raw: &[u8],
287    ) -> Result<(), SocketSendError> {
288        let this: NonNull<Self> = this.cast();
289        let this: &Self = unsafe { this.as_ref() };
290        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
291        let qref = qbox.q.bbq_ref();
292        let prod = qref.framed_producer();
293
294        let Ok(needed) = u16::try_from(that.len() + hdr_raw.len()) else {
295            return Err(SocketSendError::NoSpace);
296        };
297
298        let Ok(mut wgr) = prod.grant(needed) else {
299            return Err(SocketSendError::NoSpace);
300        };
301        let (hdr, body) = wgr.split_at_mut(hdr_raw.len());
302        hdr.copy_from_slice(hdr_raw);
303        body.copy_from_slice(that);
304        wgr.commit(needed);
305
306        if let Some(wake) = qbox.waker.take() {
307            wake.wake();
308        }
309
310        Ok(())
311    }
312}
313
314// impl SocketHdl
315
316// TODO: impl drop, remove waker, remove socket
317impl<'a, Q, T, R, M> SocketHdl<'a, Q, T, R, M>
318where
319    Q: BbqHandle,
320    T: Serialize + Clone,
321    R: ScopedRawMutex + 'static,
322    M: InterfaceManager + 'static,
323{
324    pub fn port(&self) -> u8 {
325        self.port
326    }
327
328    pub fn stack(&self) -> &'static NetStack<R, M> {
329        unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
330    }
331
332    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, Q, T, R, M> {
333        Recv { hdl: self }
334    }
335}
336
337impl<Q, T, R, M> Drop for Socket<Q, T, R, M>
338where
339    Q: BbqHandle,
340    T: Serialize + Clone,
341    R: ScopedRawMutex + 'static,
342    M: InterfaceManager + 'static,
343{
344    fn drop(&mut self) {
345        unsafe {
346            let this = NonNull::from(&self.hdr);
347            self.net.detach_socket(this);
348        }
349    }
350}
351
352unsafe impl<Q, T, R, M> Send for SocketHdl<'_, Q, T, R, M>
353where
354    Q: BbqHandle,
355    T: Serialize + Clone,
356    R: ScopedRawMutex + 'static,
357    M: InterfaceManager + 'static,
358{
359}
360
361unsafe impl<Q, T, R, M> Sync for SocketHdl<'_, Q, T, R, M>
362where
363    Q: BbqHandle,
364    T: Serialize + Clone,
365    R: ScopedRawMutex + 'static,
366    M: InterfaceManager + 'static,
367{
368}
369
370// impl Recv
371
372enum ResponseGrantInner<Q: BbqHandle, T> {
373    Ok {
374        grant: FramedGrantR<Q, u16>,
375        offset: usize,
376        deser_erased: PhantomData<fn() -> T>,
377    },
378    Err(ProtocolError),
379}
380
381pub struct ResponseGrant<Q: BbqHandle, T> {
382    pub hdr: HeaderSeq,
383    inner: ResponseGrantInner<Q, T>,
384}
385
386impl<Q: BbqHandle, T> Drop for ResponseGrant<Q, T> {
387    fn drop(&mut self) {
388        let old = core::mem::replace(
389            &mut self.inner,
390            ResponseGrantInner::Err(ProtocolError(u16::MAX)),
391        );
392        match old {
393            ResponseGrantInner::Ok { grant, .. } => {
394                grant.release();
395            }
396            ResponseGrantInner::Err(_) => {}
397        }
398    }
399}
400
401impl<Q: BbqHandle, T> ResponseGrant<Q, T> {
402    // TODO: I don't want this being failable, but right now I can't figure out
403    // how to make Recv::poll() do the checking without hitting awkward inner
404    // lifetimes for deserialization. If you know how to make this less awkward,
405    // please @ me somewhere about it.
406    pub fn try_access<'de, 'me: 'de>(&'me self) -> Option<Response<T>>
407    where
408        T: Deserialize<'de>,
409    {
410        Some(match &self.inner {
411            ResponseGrantInner::Ok {
412                grant,
413                deser_erased: _,
414                offset,
415            } => {
416                // TODO: We could use something like Yoke to skip repeating deser
417                let t = postcard::from_bytes::<T>(grant.get(*offset..)?).ok()?;
418                Response::Ok(HeaderMessage {
419                    hdr: self.hdr.clone(),
420                    t,
421                })
422            }
423            ResponseGrantInner::Err(protocol_error) => Response::Err(HeaderMessage {
424                hdr: self.hdr.clone(),
425                t: *protocol_error,
426            }),
427        })
428    }
429}
430
431impl<'a, Q, T, R, M> Future for Recv<'a, '_, Q, T, R, M>
432where
433    Q: BbqHandle,
434    T: Serialize + Clone,
435    R: ScopedRawMutex + 'static,
436    M: InterfaceManager + 'static,
437{
438    type Output = ResponseGrant<Q, T>;
439
440    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
441        let net: &'static NetStack<R, M> = self.hdl.stack();
442        let f = || -> Option<ResponseGrant<Q, T>> {
443            let this_ref: &Socket<Q, T, R, M> = unsafe { self.hdl.ptr.as_ref() };
444            let qbox: &mut QueueBox<Q> = unsafe { &mut *this_ref.inner.get() };
445            let cons: FramedConsumer<Q, u16> = qbox.q.framed_consumer();
446
447            if let Ok(resp) = cons.read() {
448                let sli: &[u8] = resp.deref();
449
450                if let Some(frame) = de_frame(sli) {
451                    let BorrowedFrame {
452                        hdr,
453                        body,
454                        hdr_raw: _,
455                    } = frame;
456                    match body {
457                        Ok(body) => {
458                            let sli: &[u8] = body;
459                            // I want to be able to do something like this:
460                            //
461                            // if let Ok(_msg) = postcard::from_bytes::<T>(sli) {
462                            //     let offset =
463                            //         (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
464                            //     return Some(ResponseGrant {
465                            //         hdr,
466                            //         inner: ResponseGrantInner::Ok {
467                            //             grant: resp,
468                            //             offset,
469                            //             deser_erased: PhantomData,
470                            //         },
471                            //         _plt: PhantomData,
472                            //     });
473                            // } else {
474                            //     resp.release();
475                            // }
476                            let offset = (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
477                            return Some(ResponseGrant {
478                                hdr,
479                                inner: ResponseGrantInner::Ok {
480                                    grant: resp,
481                                    offset,
482                                    deser_erased: PhantomData,
483                                },
484                            });
485                        }
486                        Err(err) => {
487                            resp.release();
488                            return Some(ResponseGrant {
489                                hdr,
490                                inner: ResponseGrantInner::Err(err),
491                            });
492                        }
493                    }
494                }
495            }
496
497            let new_wake = cx.waker();
498            if let Some(w) = qbox.waker.take() {
499                if !w.will_wake(new_wake) {
500                    w.wake();
501                }
502            }
503            // NOTE: Okay to register waker AFTER checking, because we
504            // have an exclusive lock
505            qbox.waker = Some(new_wake.clone());
506            None
507        };
508        let res = unsafe { net.with_lock(f) };
509        if let Some(t) = res {
510            Poll::Ready(t)
511        } else {
512            Poll::Pending
513        }
514    }
515}
516
517unsafe impl<Q, T, R, M> Sync for Recv<'_, '_, Q, T, R, M>
518where
519    Q: BbqHandle,
520    T: Serialize + Clone,
521    R: ScopedRawMutex + 'static,
522    M: InterfaceManager + 'static,
523{
524}