ergot_base/socket/
raw.rs

1use core::{
2    any::TypeId,
3    cell::UnsafeCell,
4    marker::PhantomData,
5    pin::Pin,
6    ptr::{NonNull, addr_of},
7    task::{Context, Poll, Waker},
8};
9
10use cordyceps::list::Links;
11use mutex::ScopedRawMutex;
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager};
15
16use super::{Attributes, OwnedMessage, Response, SocketHeader, SocketSendError, SocketVTable};
17
/// Error returned by [`Storage::push`] when the store cannot accept another item.
///
/// A plain marker value: it is `Copy` and `Eq` so callers can freely compare,
/// copy, and pattern-match on it (e.g. `Err(StorageFull) => ...`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StorageFull;
20
21pub trait Storage<T: 'static>: 'static {
22    fn is_full(&self) -> bool;
23    fn is_empty(&self) -> bool;
24    fn push(&mut self, t: T) -> Result<(), StorageFull>;
25    fn try_pop(&mut self) -> Option<T>;
26}
27
28// Owned Socket
29#[repr(C)]
30pub struct Socket<S, T, R, M>
31where
32    S: Storage<Response<T>>,
33    T: Serialize + Clone + DeserializeOwned + 'static,
34    R: ScopedRawMutex + 'static,
35    M: InterfaceManager + 'static,
36{
37    // LOAD BEARING: must be first
38    hdr: SocketHeader,
39    pub(crate) net: &'static NetStack<R, M>,
40    // TODO: just a single item, we probably want a more ring-buffery
41    // option for this.
42    inner: UnsafeCell<StoreBox<S, Response<T>>>,
43}
44
45pub struct SocketHdl<'a, S, T, R, M>
46where
47    S: Storage<Response<T>>,
48    T: Serialize + Clone + DeserializeOwned + 'static,
49    R: ScopedRawMutex + 'static,
50    M: InterfaceManager + 'static,
51{
52    pub(crate) ptr: NonNull<Socket<S, T, R, M>>,
53    _lt: PhantomData<Pin<&'a mut Socket<S, T, R, M>>>,
54    port: u8,
55}
56
57pub struct Recv<'a, 'b, S, T, R, M>
58where
59    S: Storage<Response<T>>,
60    T: Serialize + Clone + DeserializeOwned + 'static,
61    R: ScopedRawMutex + 'static,
62    M: InterfaceManager + 'static,
63{
64    hdl: &'a mut SocketHdl<'b, S, T, R, M>,
65}
66
67struct StoreBox<S: Storage<T>, T: 'static> {
68    wait: Option<Waker>,
69    sto: S,
70    _pd: PhantomData<fn() -> T>,
71}
72
73// ---- impls ----
74
75// impl OwnedMessage
76
77// ...
78
// impl Socket
80
81impl<S, T, R, M> Socket<S, T, R, M>
82where
83    S: Storage<Response<T>>,
84    T: Serialize + Clone + DeserializeOwned + 'static,
85    R: ScopedRawMutex + 'static,
86    M: InterfaceManager + 'static,
87{
88    pub const fn new(net: &'static NetStack<R, M>, key: Key, attrs: Attributes, sto: S) -> Self {
89        Self {
90            hdr: SocketHeader {
91                links: Links::new(),
92                vtable: const { &Self::vtable() },
93                port: 0,
94                attrs,
95                key,
96            },
97            inner: UnsafeCell::new(StoreBox::new(sto)),
98            net,
99        }
100    }
101
102    pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
103        let stack = self.net;
104        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
105        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
106        let port = unsafe { stack.attach_socket(ptr_erase) };
107        SocketHdl {
108            ptr: ptr_self,
109            _lt: PhantomData,
110            port,
111        }
112    }
113
114    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
115        let stack = self.net;
116        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
117        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
118        unsafe { stack.attach_broadcast_socket(ptr_erase) };
119        SocketHdl {
120            ptr: ptr_self,
121            _lt: PhantomData,
122            port: 255,
123        }
124    }
125
126    const fn vtable() -> SocketVTable {
127        SocketVTable {
128            recv_owned: Some(Self::recv_owned),
129            // TODO: We probably COULD support this, but I'm pretty sure it
130            // would require serializing, copying to a buffer, then later
131            // deserializing. I really don't know if we WANT this.
132            recv_bor: None,
133            recv_raw: Self::recv_raw,
134            recv_err: Some(Self::recv_err),
135        }
136    }
137
138    pub fn stack(&self) -> &'static NetStack<R, M> {
139        self.net
140    }
141
142    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
143        let this: NonNull<Self> = this.cast();
144        let this: &Self = unsafe { this.as_ref() };
145        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
146
147        let msg = Err(OwnedMessage { hdr, t: err });
148        if mutitem.sto.push(msg).is_ok() {
149            if let Some(w) = mutitem.wait.take() {
150                w.wake();
151            }
152        }
153    }
154
155    fn recv_owned(
156        this: NonNull<()>,
157        that: NonNull<()>,
158        hdr: HeaderSeq,
159        ty: &TypeId,
160    ) -> Result<(), SocketSendError> {
161        if &TypeId::of::<T>() != ty {
162            debug_assert!(false, "Type Mismatch!");
163            return Err(SocketSendError::TypeMismatch);
164        }
165        let that: NonNull<T> = that.cast();
166        let that: &T = unsafe { that.as_ref() };
167        let this: NonNull<Self> = this.cast();
168        let this: &Self = unsafe { this.as_ref() };
169        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
170
171        let msg = Ok(OwnedMessage {
172            hdr,
173            t: that.clone(),
174        });
175
176        match mutitem.sto.push(msg) {
177            Ok(()) => {
178                if let Some(w) = mutitem.wait.take() {
179                    w.wake();
180                }
181                Ok(())
182            }
183            Err(StorageFull) => Err(SocketSendError::NoSpace),
184        }
185    }
186
187    // fn send_bor(
188    //     this: NonNull<()>,
189    //     that: NonNull<()>,
190    //     src: Address,
191    //     dst: Address,
192    // ) -> Result<(), ()> {
193    //     // I don't think we can support this?
194    //     Err(())
195    // }
196
197    fn recv_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
198        let this: NonNull<Self> = this.cast();
199        let this: &Self = unsafe { this.as_ref() };
200        let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
201
202        if mutitem.sto.is_full() {
203            return Err(SocketSendError::NoSpace);
204        }
205
206        if let Ok(t) = postcard::from_bytes::<T>(that) {
207            let msg = Ok(OwnedMessage { hdr, t });
208            let _ = mutitem.sto.push(msg);
209            if let Some(w) = mutitem.wait.take() {
210                w.wake();
211            }
212            Ok(())
213        } else {
214            Err(SocketSendError::DeserFailed)
215        }
216    }
217}
218
// impl SocketHdl
220
// TODO: impl Drop for SocketHdl: clear the registered waker and detach the socket.
222impl<'a, S, T, R, M> SocketHdl<'a, S, T, R, M>
223where
224    S: Storage<Response<T>>,
225    T: Serialize + Clone + DeserializeOwned + 'static,
226    R: ScopedRawMutex + 'static,
227    M: InterfaceManager + 'static,
228{
229    pub fn port(&self) -> u8 {
230        self.port
231    }
232
233    pub fn stack(&self) -> &'static NetStack<R, M> {
234        unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
235    }
236
237    // TODO: This future is !Send? I don't fully understand why, but rustc complains
238    // that since `NonNull<OwnedSocket<E>>` is !Sync, then this future can't be Send,
239    // BUT impl'ing Sync unsafely on OwnedSocketHdl + OwnedSocket doesn't seem to help.
240    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, R, M> {
241        Recv { hdl: self }
242    }
243}
244
245impl<S, T, R, M> Drop for Socket<S, T, R, M>
246where
247    S: Storage<Response<T>>,
248    T: Serialize + Clone + DeserializeOwned + 'static,
249    R: ScopedRawMutex + 'static,
250    M: InterfaceManager + 'static,
251{
252    fn drop(&mut self) {
253        unsafe {
254            let this = NonNull::from(&self.hdr);
255            self.net.detach_socket(this);
256        }
257    }
258}
259
260unsafe impl<S, T, R, M> Send for SocketHdl<'_, S, T, R, M>
261where
262    S: Storage<Response<T>>,
263    T: Send,
264    T: Serialize + Clone + DeserializeOwned + 'static,
265    R: ScopedRawMutex + 'static,
266    M: InterfaceManager + 'static,
267{
268}
269
270unsafe impl<S, T, R, M> Sync for SocketHdl<'_, S, T, R, M>
271where
272    S: Storage<Response<T>>,
273    T: Send,
274    T: Serialize + Clone + DeserializeOwned + 'static,
275    R: ScopedRawMutex + 'static,
276    M: InterfaceManager + 'static,
277{
278}
279
280// impl Recv
281
282impl<S, T, R, M> Future for Recv<'_, '_, S, T, R, M>
283where
284    S: Storage<Response<T>>,
285    T: Serialize + Clone + DeserializeOwned + 'static,
286    R: ScopedRawMutex + 'static,
287    M: InterfaceManager + 'static,
288{
289    type Output = Response<T>;
290
291    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
292        let net: &'static NetStack<R, M> = self.hdl.stack();
293        let f = || {
294            let this_ref: &Socket<S, T, R, M> = unsafe { self.hdl.ptr.as_ref() };
295            let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
296
297            if let Some(resp) = box_ref.sto.try_pop() {
298                return Some(resp);
299            }
300
301            let new_wake = cx.waker();
302            if let Some(w) = box_ref.wait.take() {
303                if !w.will_wake(new_wake) {
304                    w.wake();
305                }
306            }
307            // NOTE: Okay to register waker AFTER checking, because we
308            // have an exclusive lock
309            box_ref.wait = Some(new_wake.clone());
310            None
311        };
312        let res = unsafe { net.with_lock(f) };
313        if let Some(t) = res {
314            Poll::Ready(t)
315        } else {
316            Poll::Pending
317        }
318    }
319}
320
321unsafe impl<S, T, R, M> Sync for Recv<'_, '_, S, T, R, M>
322where
323    S: Storage<Response<T>>,
324    T: Send,
325    T: Serialize + Clone + DeserializeOwned + 'static,
326    R: ScopedRawMutex + 'static,
327    M: InterfaceManager + 'static,
328{
329}
330
// impl StoreBox
332
333impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
334    const fn new(sto: S) -> Self {
335        Self {
336            wait: None,
337            sto,
338            _pd: PhantomData,
339        }
340    }
341}