ergot_base/socket/
borrow.rs

//! "Borrow" sockets
//!
//! Borrow sockets use a `bbq2` queue to store the serialized form of messages.
//!
//! This allows for sending and receiving borrowed types like `&str` or `&[u8]`,
//! or messages that contain borrowed types. This is achieved by serializing
//! messages into the bbq2 ring buffer when inserting them into the socket, and
//! deserializing them when removing them from the socket.
//!
//! Although you can use borrow sockets for types that are fully owned (e.g.
//! `T: 'static`), you should prefer the [`owned`](crate::socket::owned) socket
//! variants when possible, as they store messages more efficiently and may be
//! able to skip the ser/de round trip entirely when sending messages locally.
14
15use core::{
16    any::TypeId,
17    cell::UnsafeCell,
18    marker::PhantomData,
19    ops::Deref,
20    pin::Pin,
21    ptr::{NonNull, addr_of},
22    task::{Context, Poll, Waker},
23};
24
25use bbq2::{
26    prod_cons::framed::{FramedConsumer, FramedGrantR},
27    traits::bbqhdl::BbqHandle,
28};
29use cordyceps::list::Links;
30use postcard::ser_flavors;
31use serde::{Deserialize, Serialize};
32
33use crate::{
34    HeaderSeq, Key, ProtocolError,
35    nash::NameHash,
36    net_stack::NetStackHandle,
37    wire_frames::{self, BorrowedFrame, CommonHeader, de_frame},
38};
39
40use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
41
42#[repr(C)]
43pub struct Socket<Q, T, N>
44where
45    Q: BbqHandle,
46    T: Serialize + Clone,
47    N: NetStackHandle,
48{
49    // LOAD BEARING: must be first
50    hdr: SocketHeader,
51    pub(crate) net: N::Target,
52    inner: UnsafeCell<QueueBox<Q>>,
53    mtu: u16,
54    _pd: PhantomData<fn() -> T>,
55}
56
57pub struct SocketHdl<'a, Q, T, N>
58where
59    Q: BbqHandle,
60    T: Serialize + Clone,
61    N: NetStackHandle,
62{
63    pub(crate) ptr: NonNull<Socket<Q, T, N>>,
64    _lt: PhantomData<Pin<&'a mut Socket<Q, T, N>>>,
65    port: u8,
66}
67
68pub struct Recv<'a, 'b, Q, T, N>
69where
70    Q: BbqHandle,
71    T: Serialize + Clone,
72    N: NetStackHandle,
73{
74    hdl: &'a mut SocketHdl<'b, Q, T, N>,
75}
76
77pub struct ResponseGrant<Q: BbqHandle, T> {
78    pub hdr: HeaderSeq,
79    inner: ResponseGrantInner<Q, T>,
80}
81
82struct QueueBox<Q: BbqHandle> {
83    q: Q,
84    waker: Option<Waker>,
85}
86
87enum ResponseGrantInner<Q: BbqHandle, T> {
88    Ok {
89        grant: FramedGrantR<Q, u16>,
90        offset: usize,
91        deser_erased: PhantomData<fn() -> T>,
92    },
93    Err(ProtocolError),
94}
95
96// ---- impls ----
97
98// impl Socket
99
100impl<Q, T, N> Socket<Q, T, N>
101where
102    Q: BbqHandle,
103    T: Serialize + Clone,
104    N: NetStackHandle,
105{
106    pub const fn new(
107        net: N::Target,
108        key: Key,
109        attrs: Attributes,
110        sto: Q,
111        mtu: u16,
112        name: Option<&str>,
113    ) -> Self {
114        Self {
115            hdr: SocketHeader {
116                links: Links::new(),
117                vtable: const { &Self::vtable() },
118                port: 0,
119                attrs,
120                key,
121                nash: if let Some(n) = name {
122                    Some(NameHash::new(n))
123                } else {
124                    None
125                },
126            },
127            inner: UnsafeCell::new(QueueBox {
128                q: sto,
129                waker: None,
130            }),
131            net,
132            _pd: PhantomData,
133            mtu,
134        }
135    }
136
137    pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
138        let stack = self.net.clone();
139        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
140        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
141        let port = unsafe { stack.attach_socket(ptr_erase) };
142        SocketHdl {
143            ptr: ptr_self,
144            _lt: PhantomData,
145            port,
146        }
147    }
148
149    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
150        let stack = self.net.clone();
151        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
152        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
153        unsafe { stack.attach_broadcast_socket(ptr_erase) };
154        SocketHdl {
155            ptr: ptr_self,
156            _lt: PhantomData,
157            port: 255,
158        }
159    }
160
161    const fn vtable() -> SocketVTable {
162        SocketVTable {
163            recv_owned: Some(Self::recv_owned),
164            recv_bor: Some(Self::recv_bor),
165            recv_raw: Self::recv_raw,
166            recv_err: Some(Self::recv_err),
167        }
168    }
169
170    pub fn stack(&self) -> N::Target {
171        self.net.clone()
172    }
173
174    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
175        let this: NonNull<Self> = this.cast();
176        let this: &Self = unsafe { this.as_ref() };
177        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
178        let qref = qbox.q.bbq_ref();
179        let prod = qref.framed_producer();
180
181        // TODO: we could probably use a smaller grant here than the MTU,
182        // allowing more grants to succeed.
183        let Ok(mut wgr) = prod.grant(this.mtu) else {
184            return;
185        };
186
187        let ser = ser_flavors::Slice::new(&mut wgr);
188
189        let chdr = CommonHeader {
190            src: hdr.src,
191            dst: hdr.dst,
192            seq_no: hdr.seq_no,
193            kind: hdr.kind,
194            ttl: hdr.ttl,
195        };
196
197        if let Ok(used) = wire_frames::encode_frame_err(ser, &chdr, err) {
198            let len = used.len() as u16;
199            wgr.commit(len);
200            if let Some(wake) = qbox.waker.take() {
201                wake.wake();
202            }
203        }
204    }
205
206    fn recv_owned(
207        this: NonNull<()>,
208        that: NonNull<()>,
209        hdr: HeaderSeq,
210        // We can't use TypeId here because mismatched lifetimes have different
211        // type ids!
212        _ty: &TypeId,
213    ) -> Result<(), SocketSendError> {
214        let that: NonNull<T> = that.cast();
215        let that: &T = unsafe { that.as_ref() };
216        let this: NonNull<Self> = this.cast();
217        let this: &Self = unsafe { this.as_ref() };
218        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
219        let qref = qbox.q.bbq_ref();
220        let prod = qref.framed_producer();
221
222        let Ok(mut wgr) = prod.grant(this.mtu) else {
223            return Err(SocketSendError::NoSpace);
224        };
225        let ser = ser_flavors::Slice::new(&mut wgr);
226
227        let chdr = CommonHeader {
228            src: hdr.src,
229            dst: hdr.dst,
230            seq_no: hdr.seq_no,
231            kind: hdr.kind,
232            ttl: hdr.ttl,
233        };
234
235        let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
236            return Err(SocketSendError::NoSpace);
237        };
238
239        let len = used.len() as u16;
240        wgr.commit(len);
241
242        if let Some(wake) = qbox.waker.take() {
243            wake.wake();
244        }
245
246        Ok(())
247    }
248
249    fn recv_bor(
250        this: NonNull<()>,
251        that: NonNull<()>,
252        hdr: HeaderSeq,
253    ) -> Result<(), SocketSendError> {
254        let this: NonNull<Self> = this.cast();
255        let this: &Self = unsafe { this.as_ref() };
256        let that: NonNull<T> = that.cast();
257        let that: &T = unsafe { that.as_ref() };
258        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
259        let qref = qbox.q.bbq_ref();
260        let prod = qref.framed_producer();
261
262        let Ok(mut wgr) = prod.grant(this.mtu) else {
263            return Err(SocketSendError::NoSpace);
264        };
265        let ser = ser_flavors::Slice::new(&mut wgr);
266
267        let chdr = CommonHeader {
268            src: hdr.src,
269            dst: hdr.dst,
270            seq_no: hdr.seq_no,
271            kind: hdr.kind,
272            ttl: hdr.ttl,
273        };
274
275        let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
276            return Err(SocketSendError::NoSpace);
277        };
278
279        let len = used.len() as u16;
280        wgr.commit(len);
281
282        if let Some(wake) = qbox.waker.take() {
283            wake.wake();
284        }
285
286        Ok(())
287    }
288
289    fn recv_raw(
290        this: NonNull<()>,
291        that: &[u8],
292        _hdr: HeaderSeq,
293        hdr_raw: &[u8],
294    ) -> Result<(), SocketSendError> {
295        let this: NonNull<Self> = this.cast();
296        let this: &Self = unsafe { this.as_ref() };
297        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
298        let qref = qbox.q.bbq_ref();
299        let prod = qref.framed_producer();
300
301        let Ok(needed) = u16::try_from(that.len() + hdr_raw.len()) else {
302            return Err(SocketSendError::NoSpace);
303        };
304
305        let Ok(mut wgr) = prod.grant(needed) else {
306            return Err(SocketSendError::NoSpace);
307        };
308        let (hdr, body) = wgr.split_at_mut(hdr_raw.len());
309        hdr.copy_from_slice(hdr_raw);
310        body.copy_from_slice(that);
311        wgr.commit(needed);
312
313        if let Some(wake) = qbox.waker.take() {
314            wake.wake();
315        }
316
317        Ok(())
318    }
319}
320
321// impl SocketHdl
322
323impl<'a, Q, T, N> SocketHdl<'a, Q, T, N>
324where
325    Q: BbqHandle,
326    T: Serialize + Clone,
327    N: NetStackHandle,
328{
329    pub fn port(&self) -> u8 {
330        self.port
331    }
332
333    pub fn stack(&self) -> N::Target {
334        unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
335    }
336
337    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, Q, T, N> {
338        Recv { hdl: self }
339    }
340}
341
342impl<Q, T, N> Drop for Socket<Q, T, N>
343where
344    Q: BbqHandle,
345    T: Serialize + Clone,
346    N: NetStackHandle,
347{
348    fn drop(&mut self) {
349        unsafe {
350            let this = NonNull::from(&self.hdr);
351            self.net.detach_socket(this);
352        }
353    }
354}
355
356unsafe impl<Q, T, N> Send for SocketHdl<'_, Q, T, N>
357where
358    Q: BbqHandle,
359    T: Serialize + Clone,
360    N: NetStackHandle,
361{
362}
363
364unsafe impl<Q, T, N> Sync for SocketHdl<'_, Q, T, N>
365where
366    Q: BbqHandle,
367    T: Serialize + Clone,
368    N: NetStackHandle,
369{
370}
371
372// impl Recv
373
374impl<'a, Q, T, N> Future for Recv<'a, '_, Q, T, N>
375where
376    Q: BbqHandle,
377    T: Serialize + Clone,
378    N: NetStackHandle,
379{
380    type Output = ResponseGrant<Q, T>;
381
382    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
383        let net: N::Target = self.hdl.stack();
384        let f = || -> Option<ResponseGrant<Q, T>> {
385            let this_ref: &Socket<Q, T, N> = unsafe { self.hdl.ptr.as_ref() };
386            let qbox: &mut QueueBox<Q> = unsafe { &mut *this_ref.inner.get() };
387            let cons: FramedConsumer<Q, u16> = qbox.q.framed_consumer();
388
389            if let Ok(resp) = cons.read() {
390                let sli: &[u8] = resp.deref();
391
392                if let Some(frame) = de_frame(sli) {
393                    let BorrowedFrame {
394                        hdr,
395                        body,
396                        hdr_raw: _,
397                    } = frame;
398                    match body {
399                        Ok(body) => {
400                            let sli: &[u8] = body;
401                            // I want to be able to do something like this:
402                            //
403                            // if let Ok(_msg) = postcard::from_bytes::<T>(sli) {
404                            //     let offset =
405                            //         (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
406                            //     return Some(ResponseGrant {
407                            //         hdr,
408                            //         inner: ResponseGrantInner::Ok {
409                            //             grant: resp,
410                            //             offset,
411                            //             deser_erased: PhantomData,
412                            //         },
413                            //         _plt: PhantomData,
414                            //     });
415                            // } else {
416                            //     resp.release();
417                            // }
418                            let offset = (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
419                            return Some(ResponseGrant {
420                                hdr,
421                                inner: ResponseGrantInner::Ok {
422                                    grant: resp,
423                                    offset,
424                                    deser_erased: PhantomData,
425                                },
426                            });
427                        }
428                        Err(err) => {
429                            resp.release();
430                            return Some(ResponseGrant {
431                                hdr,
432                                inner: ResponseGrantInner::Err(err),
433                            });
434                        }
435                    }
436                }
437            }
438
439            let new_wake = cx.waker();
440            if let Some(w) = qbox.waker.take() {
441                if !w.will_wake(new_wake) {
442                    w.wake();
443                }
444            }
445            // NOTE: Okay to register waker AFTER checking, because we
446            // have an exclusive lock
447            qbox.waker = Some(new_wake.clone());
448            None
449        };
450        let res = unsafe { net.with_lock(f) };
451        if let Some(t) = res {
452            Poll::Ready(t)
453        } else {
454            Poll::Pending
455        }
456    }
457}
458
459unsafe impl<Q, T, N> Sync for Recv<'_, '_, Q, T, N>
460where
461    Q: BbqHandle,
462    T: Serialize + Clone,
463    N: NetStackHandle,
464{
465}
466
467// impl ResponseGrant
468
469impl<Q: BbqHandle, T> ResponseGrant<Q, T> {
470    // TODO: I don't want this being failable, but right now I can't figure out
471    // how to make Recv::poll() do the checking without hitting awkward inner
472    // lifetimes for deserialization. If you know how to make this less awkward,
473    // please @ me somewhere about it.
474    pub fn try_access<'de, 'me: 'de>(&'me self) -> Option<Response<T>>
475    where
476        T: Deserialize<'de>,
477    {
478        Some(match &self.inner {
479            ResponseGrantInner::Ok {
480                grant,
481                deser_erased: _,
482                offset,
483            } => {
484                // TODO: We could use something like Yoke to skip repeating deser
485                let t = postcard::from_bytes::<T>(grant.get(*offset..)?).ok()?;
486                Response::Ok(HeaderMessage {
487                    hdr: self.hdr.clone(),
488                    t,
489                })
490            }
491            ResponseGrantInner::Err(protocol_error) => Response::Err(HeaderMessage {
492                hdr: self.hdr.clone(),
493                t: *protocol_error,
494            }),
495        })
496    }
497}
498
499impl<Q: BbqHandle, T> Drop for ResponseGrant<Q, T> {
500    fn drop(&mut self) {
501        let old = core::mem::replace(
502            &mut self.inner,
503            ResponseGrantInner::Err(ProtocolError(u16::MAX)),
504        );
505        match old {
506            ResponseGrantInner::Ok { grant, .. } => {
507                grant.release();
508            }
509            ResponseGrantInner::Err(_) => {}
510        }
511    }
512}