ergot_base/socket/owned.rs

1use core::{
2    any::TypeId,
3    cell::UnsafeCell,
4    marker::PhantomData,
5    pin::Pin,
6    ptr::{NonNull, addr_of},
7    task::{Context, Poll, Waker},
8};
9
10use cordyceps::list::Links;
11use mutex::ScopedRawMutex;
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager};
15
16use super::{
17    Attributes, Contents, OwnedMessage, Response, SocketHeader, SocketSendError, SocketVTable,
18};
19
20// Owned Socket
21#[repr(C)]
22pub struct OwnedSocket<T, R, M>
23where
24    T: Serialize + Clone + DeserializeOwned + 'static,
25    R: ScopedRawMutex + 'static,
26    M: InterfaceManager + 'static,
27{
28    // LOAD BEARING: must be first
29    hdr: SocketHeader,
30    pub(crate) net: &'static NetStack<R, M>,
31    // TODO: just a single item, we probably want a more ring-buffery
32    // option for this.
33    inner: UnsafeCell<OneBox<T>>,
34}
35
36pub struct OwnedSocketHdl<'a, T, R, M>
37where
38    T: Serialize + Clone + DeserializeOwned + 'static,
39    R: ScopedRawMutex + 'static,
40    M: InterfaceManager + 'static,
41{
42    pub(crate) ptr: NonNull<OwnedSocket<T, R, M>>,
43    _lt: PhantomData<Pin<&'a mut OwnedSocket<T, R, M>>>,
44    port: u8,
45}
46
47pub struct Recv<'a, 'b, T, R, M>
48where
49    T: Serialize + Clone + DeserializeOwned + 'static,
50    R: ScopedRawMutex + 'static,
51    M: InterfaceManager + 'static,
52{
53    hdl: &'a mut OwnedSocketHdl<'b, T, R, M>,
54}
55
56struct OneBox<T: 'static> {
57    wait: Option<Waker>,
58    t: Contents<T>,
59}
60
61// ---- impls ----
62
63// impl OwnedMessage
64
65// ...
66
67// impl OwnedSocket
68
69impl<T, R, M> OwnedSocket<T, R, M>
70where
71    T: Serialize + Clone + DeserializeOwned + 'static,
72    R: ScopedRawMutex + 'static,
73    M: InterfaceManager + 'static,
74{
75    pub const fn new(net: &'static NetStack<R, M>, key: Key, attrs: Attributes) -> Self {
76        Self {
77            hdr: SocketHeader {
78                links: Links::new(),
79                vtable: const { &Self::vtable() },
80                port: 0,
81                attrs,
82                key,
83            },
84            inner: UnsafeCell::new(OneBox::new()),
85            net,
86        }
87    }
88
89    pub fn attach<'a>(self: Pin<&'a mut Self>) -> OwnedSocketHdl<'a, T, R, M> {
90        let stack = self.net;
91        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
92        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
93        let port = unsafe { stack.attach_socket(ptr_erase) };
94        OwnedSocketHdl {
95            ptr: ptr_self,
96            _lt: PhantomData,
97            port,
98        }
99    }
100
101    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> OwnedSocketHdl<'a, T, R, M> {
102        let stack = self.net;
103        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
104        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
105        unsafe { stack.attach_broadcast_socket(ptr_erase) };
106        OwnedSocketHdl {
107            ptr: ptr_self,
108            _lt: PhantomData,
109            port: 255,
110        }
111    }
112
113    const fn vtable() -> SocketVTable {
114        SocketVTable {
115            recv_owned: Some(Self::recv_owned),
116            // TODO: We probably COULD support this, but I'm pretty sure it
117            // would require serializing, copying to a buffer, then later
118            // deserializing. I really don't know if we WANT this.
119            recv_bor: None,
120            recv_raw: Self::recv_raw,
121            recv_err: Some(Self::recv_err),
122        }
123    }
124
125    pub fn stack(&self) -> &'static NetStack<R, M> {
126        self.net
127    }
128
129    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
130        let this: NonNull<Self> = this.cast();
131        let this: &Self = unsafe { this.as_ref() };
132        let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
133
134        if !matches!(mutitem.t, Contents::None) {
135            return;
136        }
137
138        mutitem.t = Contents::Err(OwnedMessage { hdr, t: err });
139        if let Some(w) = mutitem.wait.take() {
140            w.wake();
141        }
142    }
143
144    fn recv_owned(
145        this: NonNull<()>,
146        that: NonNull<()>,
147        hdr: HeaderSeq,
148        ty: &TypeId,
149    ) -> Result<(), SocketSendError> {
150        if &TypeId::of::<T>() != ty {
151            debug_assert!(false, "Type Mismatch!");
152            return Err(SocketSendError::TypeMismatch);
153        }
154        let that: NonNull<T> = that.cast();
155        let that: &T = unsafe { that.as_ref() };
156        let this: NonNull<Self> = this.cast();
157        let this: &Self = unsafe { this.as_ref() };
158        let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
159
160        if !matches!(mutitem.t, Contents::None) {
161            return Err(SocketSendError::NoSpace);
162        }
163
164        mutitem.t = Contents::Mesg(OwnedMessage {
165            hdr,
166            t: that.clone(),
167        });
168        if let Some(w) = mutitem.wait.take() {
169            w.wake();
170        }
171
172        Ok(())
173    }
174
175    // fn send_bor(
176    //     this: NonNull<()>,
177    //     that: NonNull<()>,
178    //     src: Address,
179    //     dst: Address,
180    // ) -> Result<(), ()> {
181    //     // I don't think we can support this?
182    //     Err(())
183    // }
184
185    fn recv_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
186        let this: NonNull<Self> = this.cast();
187        let this: &Self = unsafe { this.as_ref() };
188        let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
189
190        if !matches!(mutitem.t, Contents::None) {
191            return Err(SocketSendError::NoSpace);
192        }
193
194        if let Ok(t) = postcard::from_bytes::<T>(that) {
195            mutitem.t = Contents::Mesg(OwnedMessage { hdr, t });
196            if let Some(w) = mutitem.wait.take() {
197                w.wake();
198            }
199            Ok(())
200        } else {
201            Err(SocketSendError::DeserFailed)
202        }
203    }
204}
205
206// impl OwnedSocketHdl
207
208// TODO: impl drop, remove waker, remove socket
209impl<'a, T, R, M> OwnedSocketHdl<'a, T, R, M>
210where
211    T: Serialize + Clone + DeserializeOwned + 'static,
212    R: ScopedRawMutex + 'static,
213    M: InterfaceManager + 'static,
214{
215    pub fn port(&self) -> u8 {
216        self.port
217    }
218
219    pub fn stack(&self) -> &'static NetStack<R, M> {
220        unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
221    }
222
223    // TODO: This future is !Send? I don't fully understand why, but rustc complains
224    // that since `NonNull<OwnedSocket<E>>` is !Sync, then this future can't be Send,
225    // BUT impl'ing Sync unsafely on OwnedSocketHdl + OwnedSocket doesn't seem to help.
226    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
227        Recv { hdl: self }
228    }
229}
230
231impl<T, R, M> Drop for OwnedSocket<T, R, M>
232where
233    T: Serialize + Clone + DeserializeOwned + 'static,
234    R: ScopedRawMutex + 'static,
235    M: InterfaceManager + 'static,
236{
237    fn drop(&mut self) {
238        println!("Dropping OwnedSocket!");
239        unsafe {
240            let this = NonNull::from(&self.hdr);
241            self.net.detach_socket(this);
242        }
243    }
244}
245
246unsafe impl<T, R, M> Send for OwnedSocketHdl<'_, T, R, M>
247where
248    T: Send,
249    T: Serialize + Clone + DeserializeOwned + 'static,
250    R: ScopedRawMutex + 'static,
251    M: InterfaceManager + 'static,
252{
253}
254
255unsafe impl<T, R, M> Sync for OwnedSocketHdl<'_, T, R, M>
256where
257    T: Send,
258    T: Serialize + Clone + DeserializeOwned + 'static,
259    R: ScopedRawMutex + 'static,
260    M: InterfaceManager + 'static,
261{
262}
263
264// impl Recv
265
266impl<T, R, M> Future for Recv<'_, '_, T, R, M>
267where
268    T: Serialize + Clone + DeserializeOwned + 'static,
269    R: ScopedRawMutex + 'static,
270    M: InterfaceManager + 'static,
271{
272    type Output = Response<T>;
273
274    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
275        let net: &'static NetStack<R, M> = self.hdl.stack();
276        let f = || {
277            let this_ref: &OwnedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
278            let box_ref: &mut OneBox<T> = unsafe { &mut *this_ref.inner.get() };
279            match core::mem::replace(&mut box_ref.t, Contents::None) {
280                Contents::Mesg(owned_message) => return Some(Ok(owned_message)),
281                Contents::Err(owned_message) => return Some(Err(owned_message)),
282                Contents::None => {}
283            }
284
285            let new_wake = cx.waker();
286            if let Some(w) = box_ref.wait.take() {
287                if !w.will_wake(new_wake) {
288                    w.wake();
289                }
290            }
291            // NOTE: Okay to register waker AFTER checking, because we
292            // have an exclusive lock
293            box_ref.wait = Some(new_wake.clone());
294            None
295        };
296        let res = unsafe { net.with_lock(f) };
297        if let Some(t) = res {
298            Poll::Ready(t)
299        } else {
300            Poll::Pending
301        }
302    }
303}
304
305unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
306where
307    T: Send,
308    T: Serialize + Clone + DeserializeOwned + 'static,
309    R: ScopedRawMutex + 'static,
310    M: InterfaceManager + 'static,
311{
312}
313
314// impl OneBox
315
316impl<T: 'static> OneBox<T> {
317    const fn new() -> Self {
318        Self {
319            wait: None,
320            t: Contents::None,
321        }
322    }
323}
324
325impl<T: 'static> Default for OneBox<T> {
326    fn default() -> Self {
327        Self::new()
328    }
329}