ergot_base/socket/
raw_owned.rs

//! "Raw Owned" sockets
//!
//! "Owned" sockets require `T: 'static`, and store messages in their deserialized `T` form,
//! rather than as serialized bytes.
//!
//! "Raw Owned" sockets are generic over the [`Storage`] trait, which describes a basic
//! ring buffer. The [`owned`](crate::socket::owned) module contains variants of this
//! raw socket that use a specific kind of ring buffer impl, e.g. using std or stackful
//! storage.
10
11use core::{
12    any::TypeId,
13    cell::UnsafeCell,
14    marker::PhantomData,
15    pin::Pin,
16    ptr::{NonNull, addr_of},
17    task::{Context, Poll, Waker},
18};
19
20use cordyceps::list::Links;
21use mutex::ScopedRawMutex;
22use serde::de::DeserializeOwned;
23
24use crate::{
25    HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager, nash::NameHash,
26};
27
28use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
29
/// Error returned by [`Storage::push`] when there is no room for another item.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StorageFull;

/// Queue of received items backing a raw owned socket.
///
/// Implementations choose the backing memory and capacity (e.g. heap-backed or
/// fixed-size); presumably FIFO order is expected — not enforced here.
pub trait Storage<T: 'static>: 'static {
    /// Should return `true` exactly when the next [`push`](Self::push) would fail.
    fn is_full(&self) -> bool;
    /// Should return `true` exactly when [`try_pop`](Self::try_pop) would return `None`.
    fn is_empty(&self) -> bool;
    /// Stores `t`, or returns [`StorageFull`] if there is no room for it.
    fn push(&mut self, t: T) -> Result<(), StorageFull>;
    /// Removes and returns the next queued item, if any.
    fn try_pop(&mut self) -> Option<T>;
}
39
40// Owned Socket
41#[repr(C)]
42pub struct Socket<S, T, R, M>
43where
44    S: Storage<Response<T>>,
45    T: Clone + DeserializeOwned + 'static,
46    R: ScopedRawMutex + 'static,
47    M: InterfaceManager + 'static,
48{
49    // LOAD BEARING: must be first
50    hdr: SocketHeader,
51    pub(crate) net: &'static NetStack<R, M>,
52    // TODO: just a single item, we probably want a more ring-buffery
53    // option for this.
54    inner: UnsafeCell<StoreBox<S, Response<T>>>,
55}
56
57pub struct SocketHdl<'a, S, T, R, M>
58where
59    S: Storage<Response<T>>,
60    T: Clone + DeserializeOwned + 'static,
61    R: ScopedRawMutex + 'static,
62    M: InterfaceManager + 'static,
63{
64    pub(crate) ptr: NonNull<Socket<S, T, R, M>>,
65    _lt: PhantomData<Pin<&'a mut Socket<S, T, R, M>>>,
66    port: u8,
67}
68
69pub struct Recv<'a, 'b, S, T, R, M>
70where
71    S: Storage<Response<T>>,
72    T: Clone + DeserializeOwned + 'static,
73    R: ScopedRawMutex + 'static,
74    M: InterfaceManager + 'static,
75{
76    hdl: &'a mut SocketHdl<'b, S, T, R, M>,
77}
78
79struct StoreBox<S: Storage<T>, T: 'static> {
80    wait: Option<Waker>,
81    sto: S,
82    _pd: PhantomData<fn() -> T>,
83}
84
85// ---- impls ----
86
87// impl OwnedSocket
88
89impl<S, T, R, M> Socket<S, T, R, M>
90where
91    S: Storage<Response<T>>,
92    T: Clone + DeserializeOwned + 'static,
93    R: ScopedRawMutex + 'static,
94    M: InterfaceManager + 'static,
95{
96    pub const fn new(
97        net: &'static NetStack<R, M>,
98        key: Key,
99        attrs: Attributes,
100        sto: S,
101        name: Option<&str>,
102    ) -> Self {
103        Self {
104            hdr: SocketHeader {
105                links: Links::new(),
106                vtable: const { &Self::vtable() },
107                port: 0,
108                attrs,
109                key,
110                nash: if let Some(n) = name {
111                    Some(NameHash::new(n))
112                } else {
113                    None
114                },
115            },
116            inner: UnsafeCell::new(StoreBox::new(sto)),
117            net,
118        }
119    }
120
121    pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
122        let stack = self.net;
123        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
124        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
125        let port = unsafe { stack.attach_socket(ptr_erase) };
126        SocketHdl {
127            ptr: ptr_self,
128            _lt: PhantomData,
129            port,
130        }
131    }
132
133    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
134        let stack = self.net;
135        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
136        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
137        unsafe { stack.attach_broadcast_socket(ptr_erase) };
138        SocketHdl {
139            ptr: ptr_self,
140            _lt: PhantomData,
141            port: 255,
142        }
143    }
144
145    const fn vtable() -> SocketVTable {
146        SocketVTable {
147            recv_owned: Some(Self::recv_owned),
148            // TODO: We probably COULD support this, but I'm pretty sure it
149            // would require serializing, copying to a buffer, then later
150            // deserializing. I really don't know if we WANT this.
151            recv_bor: None,
152            recv_raw: Self::recv_raw,
153            recv_err: Some(Self::recv_err),
154        }
155    }
156
157    pub fn stack(&self) -> &'static NetStack<R, M> {
158        self.net
159    }
160
161    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
162        let this: NonNull<Self> = this.cast();
163        let this: &Self = unsafe { this.as_ref() };
164        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
165
166        let msg = Err(HeaderMessage { hdr, t: err });
167        if mutitem.sto.push(msg).is_ok() {
168            if let Some(w) = mutitem.wait.take() {
169                w.wake();
170            }
171        }
172    }
173
174    fn recv_owned(
175        this: NonNull<()>,
176        that: NonNull<()>,
177        hdr: HeaderSeq,
178        ty: &TypeId,
179    ) -> Result<(), SocketSendError> {
180        if &TypeId::of::<T>() != ty {
181            debug_assert!(false, "Type Mismatch!");
182            return Err(SocketSendError::TypeMismatch);
183        }
184        let that: NonNull<T> = that.cast();
185        let that: &T = unsafe { that.as_ref() };
186        let this: NonNull<Self> = this.cast();
187        let this: &Self = unsafe { this.as_ref() };
188        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
189
190        let msg = Ok(HeaderMessage {
191            hdr,
192            t: that.clone(),
193        });
194
195        match mutitem.sto.push(msg) {
196            Ok(()) => {
197                if let Some(w) = mutitem.wait.take() {
198                    w.wake();
199                }
200                Ok(())
201            }
202            Err(StorageFull) => Err(SocketSendError::NoSpace),
203        }
204    }
205
206    // fn send_bor(
207    //     this: NonNull<()>,
208    //     that: NonNull<()>,
209    //     src: Address,
210    //     dst: Address,
211    // ) -> Result<(), ()> {
212    //     // I don't think we can support this?
213    //     Err(())
214    // }
215
216    fn recv_raw(
217        this: NonNull<()>,
218        that: &[u8],
219        hdr: HeaderSeq,
220        _hdr_raw: &[u8],
221    ) -> Result<(), SocketSendError> {
222        let this: NonNull<Self> = this.cast();
223        let this: &Self = unsafe { this.as_ref() };
224        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
225
226        if mutitem.sto.is_full() {
227            return Err(SocketSendError::NoSpace);
228        }
229
230        if let Ok(t) = postcard::from_bytes::<T>(that) {
231            let msg = Ok(HeaderMessage { hdr, t });
232            let _ = mutitem.sto.push(msg);
233            if let Some(w) = mutitem.wait.take() {
234                w.wake();
235            }
236            Ok(())
237        } else {
238            Err(SocketSendError::DeserFailed)
239        }
240    }
241}
242
243// impl OwnedSocketHdl
244
245// TODO: impl drop, remove waker, remove socket
246impl<'a, S, T, R, M> SocketHdl<'a, S, T, R, M>
247where
248    S: Storage<Response<T>>,
249    T: Clone + DeserializeOwned + 'static,
250    R: ScopedRawMutex + 'static,
251    M: InterfaceManager + 'static,
252{
253    pub fn port(&self) -> u8 {
254        self.port
255    }
256
257    pub fn stack(&self) -> &'static NetStack<R, M> {
258        unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
259    }
260
261    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, R, M> {
262        Recv { hdl: self }
263    }
264}
265
266impl<S, T, R, M> Drop for Socket<S, T, R, M>
267where
268    S: Storage<Response<T>>,
269    T: Clone + DeserializeOwned + 'static,
270    R: ScopedRawMutex + 'static,
271    M: InterfaceManager + 'static,
272{
273    fn drop(&mut self) {
274        unsafe {
275            let this = NonNull::from(&self.hdr);
276            self.net.detach_socket(this);
277        }
278    }
279}
280
281unsafe impl<S, T, R, M> Send for SocketHdl<'_, S, T, R, M>
282where
283    S: Storage<Response<T>>,
284    T: Send,
285    T: Clone + DeserializeOwned + 'static,
286    R: ScopedRawMutex + 'static,
287    M: InterfaceManager + 'static,
288{
289}
290
291unsafe impl<S, T, R, M> Sync for SocketHdl<'_, S, T, R, M>
292where
293    S: Storage<Response<T>>,
294    T: Send,
295    T: Clone + DeserializeOwned + 'static,
296    R: ScopedRawMutex + 'static,
297    M: InterfaceManager + 'static,
298{
299}
300
301// impl Recv
302
303impl<S, T, R, M> Future for Recv<'_, '_, S, T, R, M>
304where
305    S: Storage<Response<T>>,
306    T: Clone + DeserializeOwned + 'static,
307    R: ScopedRawMutex + 'static,
308    M: InterfaceManager + 'static,
309{
310    type Output = Response<T>;
311
312    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313        let net: &'static NetStack<R, M> = self.hdl.stack();
314        let f = || {
315            let this_ref: &Socket<S, T, R, M> = unsafe { self.hdl.ptr.as_ref() };
316            let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
317
318            if let Some(resp) = box_ref.sto.try_pop() {
319                return Some(resp);
320            }
321
322            let new_wake = cx.waker();
323            if let Some(w) = box_ref.wait.take() {
324                if !w.will_wake(new_wake) {
325                    w.wake();
326                }
327            }
328            // NOTE: Okay to register waker AFTER checking, because we
329            // have an exclusive lock
330            box_ref.wait = Some(new_wake.clone());
331            None
332        };
333        let res = unsafe { net.with_lock(f) };
334        if let Some(t) = res {
335            Poll::Ready(t)
336        } else {
337            Poll::Pending
338        }
339    }
340}
341
342unsafe impl<S, T, R, M> Sync for Recv<'_, '_, S, T, R, M>
343where
344    S: Storage<Response<T>>,
345    T: Send,
346    T: Clone + DeserializeOwned + 'static,
347    R: ScopedRawMutex + 'static,
348    M: InterfaceManager + 'static,
349{
350}
351
// impl StoreBox
353
354impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
355    const fn new(sto: S) -> Self {
356        Self {
357            wait: None,
358            sto,
359            _pd: PhantomData,
360        }
361    }
362}