ergot_base/socket/
raw_owned.rs

//! "Raw Owned" sockets
//!
//! "Owned" sockets require `T: 'static`, and store messages in their deserialized `T` form,
//! rather than as serialized bytes.
//!
//! "Raw Owned" sockets are generic over the [`Storage`] trait, which describes a basic
//! ring buffer. The [`owned`](crate::socket::owned) module contains variants of this
//! raw socket that use a specific kind of ring buffer impl, e.g. using std or stackful
//! storage.
10
11use core::{
12    any::TypeId,
13    cell::UnsafeCell,
14    marker::PhantomData,
15    pin::Pin,
16    ptr::{NonNull, addr_of},
17    task::{Context, Poll, Waker},
18};
19
20use cordyceps::list::Links;
21use serde::de::DeserializeOwned;
22
23use crate::{HeaderSeq, Key, ProtocolError, nash::NameHash, net_stack::NetStackHandle};
24
25use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
26
/// Error returned by [`Storage::push`] when there is no room for another value.
///
/// This is a zero-sized marker: the value that could not be stored is not
/// handed back alongside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StorageFull;
29
30pub trait Storage<T: 'static>: 'static {
31    fn is_full(&self) -> bool;
32    fn is_empty(&self) -> bool;
33    fn push(&mut self, t: T) -> Result<(), StorageFull>;
34    fn try_pop(&mut self) -> Option<T>;
35}
36
37// Socket
38#[repr(C)]
39pub struct Socket<S, T, N>
40where
41    S: Storage<Response<T>>,
42    T: Clone + DeserializeOwned + 'static,
43    N: NetStackHandle,
44{
45    // LOAD BEARING: must be first
46    hdr: SocketHeader,
47    pub(crate) net: N::Target,
48    inner: UnsafeCell<StoreBox<S, Response<T>>>,
49}
50
51pub struct SocketHdl<'a, S, T, N>
52where
53    S: Storage<Response<T>>,
54    T: Clone + DeserializeOwned + 'static,
55    N: NetStackHandle,
56{
57    pub(crate) ptr: NonNull<Socket<S, T, N>>,
58    _lt: PhantomData<Pin<&'a mut Socket<S, T, N>>>,
59    port: u8,
60}
61
62pub struct Recv<'a, 'b, S, T, N>
63where
64    S: Storage<Response<T>>,
65    T: Clone + DeserializeOwned + 'static,
66    N: NetStackHandle,
67{
68    hdl: &'a mut SocketHdl<'b, S, T, N>,
69}
70
71struct StoreBox<S: Storage<T>, T: 'static> {
72    wait: Option<Waker>,
73    sto: S,
74    _pd: PhantomData<fn() -> T>,
75}
76
77// ---- impls ----
78
79// impl Socket
80
81impl<S, T, N> Socket<S, T, N>
82where
83    S: Storage<Response<T>>,
84    T: Clone + DeserializeOwned + 'static,
85    N: NetStackHandle,
86{
87    pub const fn new(
88        net: N::Target,
89        key: Key,
90        attrs: Attributes,
91        sto: S,
92        name: Option<&str>,
93    ) -> Self {
94        Self {
95            hdr: SocketHeader {
96                links: Links::new(),
97                vtable: const { &Self::vtable() },
98                port: 0,
99                attrs,
100                key,
101                nash: if let Some(n) = name {
102                    Some(NameHash::new(n))
103                } else {
104                    None
105                },
106            },
107            inner: UnsafeCell::new(StoreBox::new(sto)),
108            net,
109        }
110    }
111
112    pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
113        let stack = self.net.clone();
114        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
115        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
116        let port = unsafe { stack.attach_socket(ptr_erase) };
117        SocketHdl {
118            ptr: ptr_self,
119            _lt: PhantomData,
120            port,
121        }
122    }
123
124    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
125        let stack = self.net.clone();
126        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
127        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
128        unsafe { stack.attach_broadcast_socket(ptr_erase) };
129        SocketHdl {
130            ptr: ptr_self,
131            _lt: PhantomData,
132            port: 255,
133        }
134    }
135
136    const fn vtable() -> SocketVTable {
137        SocketVTable {
138            recv_owned: Some(Self::recv_owned),
139            // TODO: We probably COULD support this, but I'm pretty sure it
140            // would require serializing, copying to a buffer, then later
141            // deserializing. I really don't know if we WANT this.
142            //
143            // TODO: EXTRA danger: if the item is borrowed we can't use TypeId,
144            // which makes it VERY DIFFICULT to do this soundly: The sender and receiver
145            // sockets might not ACTUALLY be the same type, for example if the two
146            // types pun to each other, e.g. `&str` and `String`. THERE BE EVEN MORE
147            // DRAGONS HERE
148            // Update: This is now somewhat better because send_bor passes a serializing fn
149            recv_bor: None,
150            recv_raw: Self::recv_raw,
151            recv_err: Some(Self::recv_err),
152        }
153    }
154
155    pub fn stack(&self) -> N::Target {
156        self.net.clone()
157    }
158
159    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
160        let this: NonNull<Self> = this.cast();
161        let this: &Self = unsafe { this.as_ref() };
162        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
163
164        let msg = Err(HeaderMessage { hdr, t: err });
165        if mutitem.sto.push(msg).is_ok() {
166            if let Some(w) = mutitem.wait.take() {
167                w.wake();
168            }
169        }
170    }
171
172    fn recv_owned(
173        this: NonNull<()>,
174        that: NonNull<()>,
175        hdr: HeaderSeq,
176        ty: &TypeId,
177    ) -> Result<(), SocketSendError> {
178        if &TypeId::of::<T>() != ty {
179            debug_assert!(false, "Type Mismatch!");
180            return Err(SocketSendError::TypeMismatch);
181        }
182        let that: NonNull<T> = that.cast();
183        let that: &T = unsafe { that.as_ref() };
184        let this: NonNull<Self> = this.cast();
185        let this: &Self = unsafe { this.as_ref() };
186        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
187
188        let msg = Ok(HeaderMessage {
189            hdr,
190            t: that.clone(),
191        });
192
193        match mutitem.sto.push(msg) {
194            Ok(()) => {
195                if let Some(w) = mutitem.wait.take() {
196                    w.wake();
197                }
198                Ok(())
199            }
200            Err(StorageFull) => Err(SocketSendError::NoSpace),
201        }
202    }
203
204    fn recv_raw(
205        this: NonNull<()>,
206        that: &[u8],
207        hdr: HeaderSeq,
208        _hdr_raw: &[u8],
209    ) -> Result<(), SocketSendError> {
210        let this: NonNull<Self> = this.cast();
211        let this: &Self = unsafe { this.as_ref() };
212        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
213
214        if mutitem.sto.is_full() {
215            return Err(SocketSendError::NoSpace);
216        }
217
218        if let Ok(t) = postcard::from_bytes::<T>(that) {
219            let msg = Ok(HeaderMessage { hdr, t });
220            let _ = mutitem.sto.push(msg);
221            if let Some(w) = mutitem.wait.take() {
222                w.wake();
223            }
224            Ok(())
225        } else {
226            Err(SocketSendError::DeserFailed)
227        }
228    }
229}
230
231// impl SocketHdl
232
233impl<'a, S, T, N> SocketHdl<'a, S, T, N>
234where
235    S: Storage<Response<T>>,
236    T: Clone + DeserializeOwned + 'static,
237    N: NetStackHandle,
238{
239    pub fn port(&self) -> u8 {
240        self.port
241    }
242
243    pub fn stack(&self) -> N::Target {
244        unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
245    }
246
247    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, N> {
248        Recv { hdl: self }
249    }
250}
251
252impl<S, T, N> Drop for Socket<S, T, N>
253where
254    S: Storage<Response<T>>,
255    T: Clone + DeserializeOwned + 'static,
256    N: NetStackHandle,
257{
258    fn drop(&mut self) {
259        unsafe {
260            let this = NonNull::from(&self.hdr);
261            self.net.detach_socket(this);
262        }
263    }
264}
265
266unsafe impl<S, T, N> Send for SocketHdl<'_, S, T, N>
267where
268    S: Storage<Response<T>>,
269    T: Send,
270    T: Clone + DeserializeOwned + 'static,
271    N: NetStackHandle,
272{
273}
274
275unsafe impl<S, T, N> Sync for SocketHdl<'_, S, T, N>
276where
277    S: Storage<Response<T>>,
278    T: Send,
279    T: Clone + DeserializeOwned + 'static,
280    N: NetStackHandle,
281{
282}
283
284// impl Recv
285
286impl<S, T, N> Future for Recv<'_, '_, S, T, N>
287where
288    S: Storage<Response<T>>,
289    T: Clone + DeserializeOwned + 'static,
290    N: NetStackHandle,
291{
292    type Output = Response<T>;
293
294    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
295        let net: N::Target = self.hdl.stack();
296        let f = || {
297            let this_ref: &Socket<S, T, N> = unsafe { self.hdl.ptr.as_ref() };
298            let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
299
300            if let Some(resp) = box_ref.sto.try_pop() {
301                return Some(resp);
302            }
303
304            let new_wake = cx.waker();
305            if let Some(w) = box_ref.wait.take() {
306                if !w.will_wake(new_wake) {
307                    w.wake();
308                }
309            }
310            // NOTE: Okay to register waker AFTER checking, because we
311            // have an exclusive lock
312            box_ref.wait = Some(new_wake.clone());
313            None
314        };
315        let res = unsafe { net.with_lock(f) };
316        if let Some(t) = res {
317            Poll::Ready(t)
318        } else {
319            Poll::Pending
320        }
321    }
322}
323
324unsafe impl<S, T, N> Sync for Recv<'_, '_, S, T, N>
325where
326    S: Storage<Response<T>>,
327    T: Send,
328    T: Clone + DeserializeOwned + 'static,
329    N: NetStackHandle,
330{
331}
332
333// impl StoreBox
334
335impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
336    const fn new(sto: S) -> Self {
337        Self {
338            wait: None,
339            sto,
340            _pd: PhantomData,
341        }
342    }
343}