// ergot_base/socket/raw_owned.rs

//! "Raw Owned" sockets
//!
//! "Owned" sockets require `T: 'static`, and store messages in their deserialized `T` form,
//! rather than as serialized bytes.
//!
//! "Raw Owned" sockets are generic over the [`Storage`] trait, which describes a basic
//! ring buffer. The [`owned`](crate::socket::owned) module contains variants of this
//! raw socket that use a specific kind of ring buffer impl, e.g. using std or stackful
//! storage.
10
11use core::{
12    any::TypeId,
13    cell::UnsafeCell,
14    marker::PhantomData,
15    pin::Pin,
16    ptr::{NonNull, addr_of},
17    task::{Context, Poll, Waker},
18};
19
20use cordyceps::list::Links;
21use serde::de::DeserializeOwned;
22
23use crate::{HeaderSeq, Key, ProtocolError, nash::NameHash, net_stack::NetStackHandle};
24
25use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
26
/// Error returned by [`Storage::push`] when the storage has no room for
/// another item.
///
/// It carries no data, so it is `Copy` and can be compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StorageFull;
29
30pub trait Storage<T: 'static>: 'static {
31    fn is_full(&self) -> bool;
32    fn is_empty(&self) -> bool;
33    fn push(&mut self, t: T) -> Result<(), StorageFull>;
34    fn try_pop(&mut self) -> Option<T>;
35}
36
37// Socket
/// A socket that queues received messages as deserialized `T` values in a
/// [`Storage`] ring buffer.
///
/// Create it with [`Socket::new`], then pin it and attach it to obtain a
/// [`SocketHdl`] for receiving.
// `#[repr(C)]` with `hdr` as the first field means a `NonNull<Socket>` can be
// cast to a `NonNull<SocketHeader>` (done in `attach`/`attach_broadcast`).
// The vtable callbacks presumably rely on the reverse cast of the pointer the
// net stack hands back. Do not reorder these fields.
#[repr(C)]
pub struct Socket<S, T, N>
where
    S: Storage<Response<T>>,
    T: Clone + DeserializeOwned + 'static,
    N: NetStackHandle,
{
    // LOAD BEARING: must be first
    hdr: SocketHeader,
    /// Handle to the net stack this socket belongs to.
    pub(crate) net: N::Target,
    /// Queued responses plus the waker of a pending `Recv`. It is mutated
    /// through the `UnsafeCell` both by the vtable callbacks and by `Recv::poll`.
    inner: UnsafeCell<StoreBox<S, Response<T>>>,
}
50
/// Handle to an attached [`Socket`], returned by [`Socket::attach`] and
/// [`Socket::attach_broadcast`].
///
/// The handle borrows the pinned socket for `'a`, so the socket cannot move
/// or be dropped while the handle exists.
pub struct SocketHdl<'a, S, T, N>
where
    S: Storage<Response<T>>,
    T: Clone + DeserializeOwned + 'static,
    N: NetStackHandle,
{
    /// Pointer to the attached socket.
    pub(crate) ptr: NonNull<Socket<S, T, N>>,
    /// Ties this handle to the exclusive, pinned borrow of the socket.
    _lt: PhantomData<Pin<&'a mut Socket<S, T, N>>>,
    /// Port assigned on attach (`255` for broadcast sockets).
    port: u8,
}
61
/// Future returned by [`SocketHdl::recv`]. It resolves to the next
/// [`Response`] popped from the socket's storage.
pub struct Recv<'a, 'b, S, T, N>
where
    S: Storage<Response<T>>,
    T: Clone + DeserializeOwned + 'static,
    N: NetStackHandle,
{
    /// Exclusive borrow of the handle, so at most one `Recv` exists per handle
    /// at any time.
    hdl: &'a mut SocketHdl<'b, S, T, N>,
}
70
/// A socket's message storage together with the waker of the receiver that is
/// currently waiting on it.
struct StoreBox<S: Storage<T>, T: 'static> {
    /// Waker registered by the last pending `Recv::poll`. It is taken and
    /// woken after a successful push.
    wait: Option<Waker>,
    /// The underlying message queue.
    sto: S,
    // `fn() -> T` mentions `T` without the struct owning a `T`. The marker
    // itself is `Send`/`Sync` regardless of `T`.
    _pd: PhantomData<fn() -> T>,
}
76
77// ---- impls ----
78
79// impl Socket
80
81impl<S, T, N> Socket<S, T, N>
82where
83    S: Storage<Response<T>>,
84    T: Clone + DeserializeOwned + 'static,
85    N: NetStackHandle,
86{
87    pub const fn new(
88        net: N::Target,
89        key: Key,
90        attrs: Attributes,
91        sto: S,
92        name: Option<&str>,
93    ) -> Self {
94        Self {
95            hdr: SocketHeader {
96                links: Links::new(),
97                vtable: const { &Self::vtable() },
98                port: 0,
99                attrs,
100                key,
101                nash: if let Some(n) = name {
102                    Some(NameHash::new(n))
103                } else {
104                    None
105                },
106            },
107            inner: UnsafeCell::new(StoreBox::new(sto)),
108            net,
109        }
110    }
111
112    pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
113        let stack = self.net.clone();
114        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
115        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
116        let port = unsafe { stack.attach_socket(ptr_erase) };
117        SocketHdl {
118            ptr: ptr_self,
119            _lt: PhantomData,
120            port,
121        }
122    }
123
124    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
125        let stack = self.net.clone();
126        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
127        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
128        unsafe { stack.attach_broadcast_socket(ptr_erase) };
129        SocketHdl {
130            ptr: ptr_self,
131            _lt: PhantomData,
132            port: 255,
133        }
134    }
135
136    const fn vtable() -> SocketVTable {
137        SocketVTable {
138            recv_owned: Some(Self::recv_owned),
139            // TODO: We probably COULD support this, but I'm pretty sure it
140            // would require serializing, copying to a buffer, then later
141            // deserializing. I really don't know if we WANT this.
142            recv_bor: None,
143            recv_raw: Self::recv_raw,
144            recv_err: Some(Self::recv_err),
145        }
146    }
147
148    pub fn stack(&self) -> N::Target {
149        self.net.clone()
150    }
151
152    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
153        let this: NonNull<Self> = this.cast();
154        let this: &Self = unsafe { this.as_ref() };
155        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
156
157        let msg = Err(HeaderMessage { hdr, t: err });
158        if mutitem.sto.push(msg).is_ok() {
159            if let Some(w) = mutitem.wait.take() {
160                w.wake();
161            }
162        }
163    }
164
165    fn recv_owned(
166        this: NonNull<()>,
167        that: NonNull<()>,
168        hdr: HeaderSeq,
169        ty: &TypeId,
170    ) -> Result<(), SocketSendError> {
171        if &TypeId::of::<T>() != ty {
172            debug_assert!(false, "Type Mismatch!");
173            return Err(SocketSendError::TypeMismatch);
174        }
175        let that: NonNull<T> = that.cast();
176        let that: &T = unsafe { that.as_ref() };
177        let this: NonNull<Self> = this.cast();
178        let this: &Self = unsafe { this.as_ref() };
179        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
180
181        let msg = Ok(HeaderMessage {
182            hdr,
183            t: that.clone(),
184        });
185
186        match mutitem.sto.push(msg) {
187            Ok(()) => {
188                if let Some(w) = mutitem.wait.take() {
189                    w.wake();
190                }
191                Ok(())
192            }
193            Err(StorageFull) => Err(SocketSendError::NoSpace),
194        }
195    }
196
197    fn recv_raw(
198        this: NonNull<()>,
199        that: &[u8],
200        hdr: HeaderSeq,
201        _hdr_raw: &[u8],
202    ) -> Result<(), SocketSendError> {
203        let this: NonNull<Self> = this.cast();
204        let this: &Self = unsafe { this.as_ref() };
205        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
206
207        if mutitem.sto.is_full() {
208            return Err(SocketSendError::NoSpace);
209        }
210
211        if let Ok(t) = postcard::from_bytes::<T>(that) {
212            let msg = Ok(HeaderMessage { hdr, t });
213            let _ = mutitem.sto.push(msg);
214            if let Some(w) = mutitem.wait.take() {
215                w.wake();
216            }
217            Ok(())
218        } else {
219            Err(SocketSendError::DeserFailed)
220        }
221    }
222}
223
224// impl SocketHdl
225
226impl<'a, S, T, N> SocketHdl<'a, S, T, N>
227where
228    S: Storage<Response<T>>,
229    T: Clone + DeserializeOwned + 'static,
230    N: NetStackHandle,
231{
232    pub fn port(&self) -> u8 {
233        self.port
234    }
235
236    pub fn stack(&self) -> N::Target {
237        unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
238    }
239
240    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, N> {
241        Recv { hdl: self }
242    }
243}
244
245impl<S, T, N> Drop for Socket<S, T, N>
246where
247    S: Storage<Response<T>>,
248    T: Clone + DeserializeOwned + 'static,
249    N: NetStackHandle,
250{
251    fn drop(&mut self) {
252        unsafe {
253            let this = NonNull::from(&self.hdr);
254            self.net.detach_socket(this);
255        }
256    }
257}
258
// SAFETY (review note): the handle is a raw pointer to the pinned socket plus a
// port. Sending it to another thread lets that thread pop `T` values out of the
// storage via `recv`, hence `T: Send`. On the handle side, the storage is only
// touched inside `net.with_lock` in `Recv::poll`.
// NOTE(review): there is no `S: Send` or `N::Target: Send` bound here, even
// though the handle gives another thread access to both. Confirm this is sound
// for every `Storage` and net stack implementation.
unsafe impl<S, T, N> Send for SocketHdl<'_, S, T, N>
where
    S: Storage<Response<T>>,
    T: Send,
    T: Clone + DeserializeOwned + 'static,
    N: NetStackHandle,
{
}
267
// SAFETY (review note): in this file, `&SocketHdl` only exposes `port()` (a
// copied `u8`) and `stack()` (clones `net`). Receiving requires `&mut SocketHdl`.
// NOTE(review): `stack()` clones `N::Target` through a shared reference, so this
// presumably also relies on that clone being safe from several threads — confirm.
unsafe impl<S, T, N> Sync for SocketHdl<'_, S, T, N>
where
    S: Storage<Response<T>>,
    T: Send,
    T: Clone + DeserializeOwned + 'static,
    N: NetStackHandle,
{
}
276
277// impl Recv
278
279impl<S, T, N> Future for Recv<'_, '_, S, T, N>
280where
281    S: Storage<Response<T>>,
282    T: Clone + DeserializeOwned + 'static,
283    N: NetStackHandle,
284{
285    type Output = Response<T>;
286
287    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288        let net: N::Target = self.hdl.stack();
289        let f = || {
290            let this_ref: &Socket<S, T, N> = unsafe { self.hdl.ptr.as_ref() };
291            let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
292
293            if let Some(resp) = box_ref.sto.try_pop() {
294                return Some(resp);
295            }
296
297            let new_wake = cx.waker();
298            if let Some(w) = box_ref.wait.take() {
299                if !w.will_wake(new_wake) {
300                    w.wake();
301                }
302            }
303            // NOTE: Okay to register waker AFTER checking, because we
304            // have an exclusive lock
305            box_ref.wait = Some(new_wake.clone());
306            None
307        };
308        let res = unsafe { net.with_lock(f) };
309        if let Some(t) = res {
310            Poll::Ready(t)
311        } else {
312            Poll::Pending
313        }
314    }
315}
316
// SAFETY (review note): in this file, `Recv` exposes nothing through `&Recv`.
// Polling requires `Pin<&mut Recv>`, so sharing `&Recv` across threads grants
// no access to the socket.
unsafe impl<S, T, N> Sync for Recv<'_, '_, S, T, N>
where
    S: Storage<Response<T>>,
    T: Send,
    T: Clone + DeserializeOwned + 'static,
    N: NetStackHandle,
{
}
325
326// impl StoreBox
327
328impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
329    const fn new(sto: S) -> Self {
330        Self {
331            wait: None,
332            sto,
333            _pd: PhantomData,
334        }
335    }
336}