ergot_base/socket/
borrow.rs

1//! "Borrow" sockets
2//!
3//! Borrow sockets use a `bbq2` queue to store the serialized form of messages.
4//!
5//! This allows for sending and receiving borrowed types like `&str` or `&[u8]`,
6//! or messages that contain borrowed types. This is achieved by serializing
7//! messages into the bbq2 ring buffer when inserting into the socket, and
8//! deserializing when removing from the socket.
9//!
10//! Although you can use borrowed sockets for types that are fully owned, e.g.
11//! `T: 'static`, you should prefer the [`owned`](crate::socket::owned) socket
12//! variants when possible, as they store messages more efficiently and may be
13//! able to fully skip a ser/de round trip when sending messages locally.
14
15use core::{
16    any::TypeId,
17    cell::UnsafeCell,
18    marker::PhantomData,
19    ops::Deref,
20    pin::Pin,
21    ptr::{NonNull, addr_of},
22    task::{Context, Poll, Waker},
23};
24
25use bbq2::{
26    prod_cons::framed::{FramedConsumer, FramedGrantR},
27    traits::bbqhdl::BbqHandle,
28};
29use cordyceps::list::Links;
30use postcard::ser_flavors;
31use serde::{Deserialize, Serialize};
32
33use crate::{
34    HeaderSeq, Key, ProtocolError,
35    nash::NameHash,
36    net_stack::NetStackHandle,
37    wire_frames::{self, BorrowedFrame, CommonHeader, de_frame},
38};
39
40use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
41
42#[repr(C)]
43pub struct Socket<Q, T, N>
44where
45    Q: BbqHandle,
46    T: Serialize,
47    N: NetStackHandle,
48{
49    // LOAD BEARING: must be first
50    hdr: SocketHeader,
51    pub(crate) net: N::Target,
52    inner: UnsafeCell<QueueBox<Q>>,
53    mtu: u16,
54    _pd: PhantomData<fn() -> T>,
55}
56
57pub struct SocketHdl<'a, Q, T, N>
58where
59    Q: BbqHandle,
60    T: Serialize,
61    N: NetStackHandle,
62{
63    pub(crate) ptr: NonNull<Socket<Q, T, N>>,
64    _lt: PhantomData<Pin<&'a mut Socket<Q, T, N>>>,
65    port: u8,
66}
67
68pub struct Recv<'a, 'b, Q, T, N>
69where
70    Q: BbqHandle,
71    T: Serialize,
72    N: NetStackHandle,
73{
74    hdl: &'a mut SocketHdl<'b, Q, T, N>,
75}
76
77pub struct ResponseGrant<Q: BbqHandle, T> {
78    pub hdr: HeaderSeq,
79    inner: ResponseGrantInner<Q, T>,
80}
81
82struct QueueBox<Q: BbqHandle> {
83    q: Q,
84    waker: Option<Waker>,
85}
86
87enum ResponseGrantInner<Q: BbqHandle, T> {
88    Ok {
89        grant: FramedGrantR<Q, u16>,
90        offset: usize,
91        deser_erased: PhantomData<fn() -> T>,
92    },
93    Err(ProtocolError),
94}
95
96// ---- impls ----
97
98// impl Socket
99
100impl<Q, T, N> Socket<Q, T, N>
101where
102    Q: BbqHandle,
103    T: Serialize,
104    N: NetStackHandle,
105{
106    pub const fn new(
107        net: N::Target,
108        key: Key,
109        attrs: Attributes,
110        sto: Q,
111        mtu: u16,
112        name: Option<&str>,
113    ) -> Self {
114        Self {
115            hdr: SocketHeader {
116                links: Links::new(),
117                vtable: const { &Self::vtable() },
118                port: 0,
119                attrs,
120                key,
121                nash: if let Some(n) = name {
122                    Some(NameHash::new(n))
123                } else {
124                    None
125                },
126            },
127            inner: UnsafeCell::new(QueueBox {
128                q: sto,
129                waker: None,
130            }),
131            net,
132            _pd: PhantomData,
133            mtu,
134        }
135    }
136
137    pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
138        let stack = self.net.clone();
139        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
140        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
141        let port = unsafe { stack.attach_socket(ptr_erase) };
142        SocketHdl {
143            ptr: ptr_self,
144            _lt: PhantomData,
145            port,
146        }
147    }
148
149    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, Q, T, N> {
150        let stack = self.net.clone();
151        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
152        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
153        unsafe { stack.attach_broadcast_socket(ptr_erase) };
154        SocketHdl {
155            ptr: ptr_self,
156            _lt: PhantomData,
157            port: 255,
158        }
159    }
160
161    const fn vtable() -> SocketVTable {
162        SocketVTable {
163            recv_owned: Some(Self::recv_owned),
164            recv_bor: Some(Self::recv_bor),
165            recv_raw: Self::recv_raw,
166            recv_err: Some(Self::recv_err),
167        }
168    }
169
170    pub fn stack(&self) -> N::Target {
171        self.net.clone()
172    }
173
174    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
175        let this: NonNull<Self> = this.cast();
176        let this: &Self = unsafe { this.as_ref() };
177        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
178        let qref = qbox.q.bbq_ref();
179        let prod = qref.framed_producer();
180
181        // TODO: we could probably use a smaller grant here than the MTU,
182        // allowing more grants to succeed.
183        let Ok(mut wgr) = prod.grant(this.mtu) else {
184            return;
185        };
186
187        let ser = ser_flavors::Slice::new(&mut wgr);
188
189        let chdr = CommonHeader {
190            src: hdr.src,
191            dst: hdr.dst,
192            seq_no: hdr.seq_no,
193            kind: hdr.kind,
194            ttl: hdr.ttl,
195        };
196
197        if let Ok(used) = wire_frames::encode_frame_err(ser, &chdr, err) {
198            let len = used.len() as u16;
199            wgr.commit(len);
200            if let Some(wake) = qbox.waker.take() {
201                wake.wake();
202            }
203        }
204    }
205
206    fn recv_owned(
207        this: NonNull<()>,
208        that: NonNull<()>,
209        hdr: HeaderSeq,
210        // We can't use TypeId here because mismatched lifetimes have different
211        // type ids!
212        _ty: &TypeId,
213    ) -> Result<(), SocketSendError> {
214        let that: NonNull<T> = that.cast();
215        let that: &T = unsafe { that.as_ref() };
216        let this: NonNull<Self> = this.cast();
217        let this: &Self = unsafe { this.as_ref() };
218        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
219        let qref = qbox.q.bbq_ref();
220        let prod = qref.framed_producer();
221
222        let Ok(mut wgr) = prod.grant(this.mtu) else {
223            return Err(SocketSendError::NoSpace);
224        };
225        let ser = ser_flavors::Slice::new(&mut wgr);
226
227        let chdr = CommonHeader {
228            src: hdr.src,
229            dst: hdr.dst,
230            seq_no: hdr.seq_no,
231            kind: hdr.kind,
232            ttl: hdr.ttl,
233        };
234
235        let Ok(used) = wire_frames::encode_frame_ty(ser, &chdr, hdr.any_all.as_ref(), that) else {
236            return Err(SocketSendError::NoSpace);
237        };
238
239        let len = used.len() as u16;
240        wgr.commit(len);
241
242        if let Some(wake) = qbox.waker.take() {
243            wake.wake();
244        }
245
246        Ok(())
247    }
248
249    fn recv_bor(
250        this: NonNull<()>,
251        that: NonNull<()>,
252        hdr: HeaderSeq,
253        serfn: fn(NonNull<()>, HeaderSeq, &mut [u8]) -> Result<usize, SocketSendError>,
254    ) -> Result<(), SocketSendError> {
255        let this: NonNull<Self> = this.cast();
256        let this: &Self = unsafe { this.as_ref() };
257        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
258        let qref = qbox.q.bbq_ref();
259        let prod = qref.framed_producer();
260
261        let Ok(mut wgr) = prod.grant(this.mtu) else {
262            return Err(SocketSendError::NoSpace);
263        };
264
265        let used = serfn(that, hdr, &mut wgr)?;
266        let len = used as u16;
267        wgr.commit(len);
268
269        if let Some(wake) = qbox.waker.take() {
270            wake.wake();
271        }
272
273        Ok(())
274    }
275
276    fn recv_raw(
277        this: NonNull<()>,
278        that: &[u8],
279        _hdr: HeaderSeq,
280        hdr_raw: &[u8],
281    ) -> Result<(), SocketSendError> {
282        let this: NonNull<Self> = this.cast();
283        let this: &Self = unsafe { this.as_ref() };
284        let qbox: &mut QueueBox<Q> = unsafe { &mut *this.inner.get() };
285        let qref = qbox.q.bbq_ref();
286        let prod = qref.framed_producer();
287
288        let Ok(needed) = u16::try_from(that.len() + hdr_raw.len()) else {
289            return Err(SocketSendError::NoSpace);
290        };
291
292        let Ok(mut wgr) = prod.grant(needed) else {
293            return Err(SocketSendError::NoSpace);
294        };
295        let (hdr, body) = wgr.split_at_mut(hdr_raw.len());
296        hdr.copy_from_slice(hdr_raw);
297        body.copy_from_slice(that);
298        wgr.commit(needed);
299
300        if let Some(wake) = qbox.waker.take() {
301            wake.wake();
302        }
303
304        Ok(())
305    }
306}
307
308// impl SocketHdl
309
310impl<'a, Q, T, N> SocketHdl<'a, Q, T, N>
311where
312    Q: BbqHandle,
313    T: Serialize,
314    N: NetStackHandle,
315{
316    pub fn port(&self) -> u8 {
317        self.port
318    }
319
320    pub fn stack(&self) -> N::Target {
321        unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
322    }
323
324    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, Q, T, N> {
325        Recv { hdl: self }
326    }
327}
328
329impl<Q, T, N> Drop for Socket<Q, T, N>
330where
331    Q: BbqHandle,
332    T: Serialize,
333    N: NetStackHandle,
334{
335    fn drop(&mut self) {
336        unsafe {
337            let this = NonNull::from(&self.hdr);
338            self.net.detach_socket(this);
339        }
340    }
341}
342
343unsafe impl<Q, T, N> Send for SocketHdl<'_, Q, T, N>
344where
345    Q: BbqHandle,
346    T: Serialize,
347    N: NetStackHandle,
348{
349}
350
351unsafe impl<Q, T, N> Sync for SocketHdl<'_, Q, T, N>
352where
353    Q: BbqHandle,
354    T: Serialize,
355    N: NetStackHandle,
356{
357}
358
359// impl Recv
360
361impl<'a, Q, T, N> Future for Recv<'a, '_, Q, T, N>
362where
363    Q: BbqHandle,
364    T: Serialize,
365    N: NetStackHandle,
366{
367    type Output = ResponseGrant<Q, T>;
368
369    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
370        let net: N::Target = self.hdl.stack();
371        let f = || -> Option<ResponseGrant<Q, T>> {
372            let this_ref: &Socket<Q, T, N> = unsafe { self.hdl.ptr.as_ref() };
373            let qbox: &mut QueueBox<Q> = unsafe { &mut *this_ref.inner.get() };
374            let cons: FramedConsumer<Q, u16> = qbox.q.framed_consumer();
375
376            if let Ok(resp) = cons.read() {
377                let sli: &[u8] = resp.deref();
378
379                if let Some(frame) = de_frame(sli) {
380                    let BorrowedFrame {
381                        hdr,
382                        body,
383                        hdr_raw: _,
384                    } = frame;
385                    match body {
386                        Ok(body) => {
387                            let sli: &[u8] = body;
388                            // I want to be able to do something like this:
389                            //
390                            // if let Ok(_msg) = postcard::from_bytes::<T>(sli) {
391                            //     let offset =
392                            //         (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
393                            //     return Some(ResponseGrant {
394                            //         hdr,
395                            //         inner: ResponseGrantInner::Ok {
396                            //             grant: resp,
397                            //             offset,
398                            //             deser_erased: PhantomData,
399                            //         },
400                            //         _plt: PhantomData,
401                            //     });
402                            // } else {
403                            //     resp.release();
404                            // }
405                            let offset = (sli.as_ptr() as usize) - (resp.deref().as_ptr() as usize);
406                            return Some(ResponseGrant {
407                                hdr,
408                                inner: ResponseGrantInner::Ok {
409                                    grant: resp,
410                                    offset,
411                                    deser_erased: PhantomData,
412                                },
413                            });
414                        }
415                        Err(err) => {
416                            resp.release();
417                            return Some(ResponseGrant {
418                                hdr,
419                                inner: ResponseGrantInner::Err(err),
420                            });
421                        }
422                    }
423                }
424            }
425
426            let new_wake = cx.waker();
427            if let Some(w) = qbox.waker.take() {
428                if !w.will_wake(new_wake) {
429                    w.wake();
430                }
431            }
432            // NOTE: Okay to register waker AFTER checking, because we
433            // have an exclusive lock
434            qbox.waker = Some(new_wake.clone());
435            None
436        };
437        let res = unsafe { net.with_lock(f) };
438        if let Some(t) = res {
439            Poll::Ready(t)
440        } else {
441            Poll::Pending
442        }
443    }
444}
445
446unsafe impl<Q, T, N> Sync for Recv<'_, '_, Q, T, N>
447where
448    Q: BbqHandle,
449    T: Serialize,
450    N: NetStackHandle,
451{
452}
453
454// impl ResponseGrant
455
456impl<Q: BbqHandle, T> ResponseGrant<Q, T> {
457    // TODO: I don't want this being failable, but right now I can't figure out
458    // how to make Recv::poll() do the checking without hitting awkward inner
459    // lifetimes for deserialization. If you know how to make this less awkward,
460    // please @ me somewhere about it.
461    pub fn try_access<'de, 'me: 'de>(&'me self) -> Option<Response<T>>
462    where
463        T: Deserialize<'de>,
464    {
465        Some(match &self.inner {
466            ResponseGrantInner::Ok {
467                grant,
468                deser_erased: _,
469                offset,
470            } => {
471                // TODO: We could use something like Yoke to skip repeating deser
472                let t = postcard::from_bytes::<T>(grant.get(*offset..)?).ok()?;
473                Response::Ok(HeaderMessage {
474                    hdr: self.hdr.clone(),
475                    t,
476                })
477            }
478            ResponseGrantInner::Err(protocol_error) => Response::Err(HeaderMessage {
479                hdr: self.hdr.clone(),
480                t: *protocol_error,
481            }),
482        })
483    }
484}
485
486impl<Q: BbqHandle, T> Drop for ResponseGrant<Q, T> {
487    fn drop(&mut self) {
488        let old = core::mem::replace(
489            &mut self.inner,
490            ResponseGrantInner::Err(ProtocolError(u16::MAX)),
491        );
492        match old {
493            ResponseGrantInner::Ok { grant, .. } => {
494                grant.release();
495            }
496            ResponseGrantInner::Err(_) => {}
497        }
498    }
499}