ergot_base/socket/
owned.rs

1use core::{
2    any::TypeId,
3    cell::UnsafeCell,
4    marker::PhantomData,
5    pin::Pin,
6    ptr::{NonNull, addr_of},
7    task::{Context, Poll, Waker},
8};
9
10use cordyceps::list::Links;
11use mutex::ScopedRawMutex;
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{FrameKind, HeaderSeq, Key, NetStack, interface_manager::InterfaceManager};
15
16use super::{OwnedMessage, SocketHeader, SocketSendError, SocketVTable};
17
18// Owned Socket
19#[repr(C)]
20pub struct OwnedSocket<T, R, M>
21where
22    T: Serialize + DeserializeOwned + 'static,
23    R: ScopedRawMutex + 'static,
24    M: InterfaceManager + 'static,
25{
26    // LOAD BEARING: must be first
27    hdr: SocketHeader,
28    pub(crate) net: &'static NetStack<R, M>,
29    // TODO: just a single item, we probably want a more ring-buffery
30    // option for this.
31    inner: UnsafeCell<OneBox<T>>,
32}
33
34pub struct OwnedSocketHdl<'a, T, R, M>
35where
36    T: Serialize + DeserializeOwned + 'static,
37    R: ScopedRawMutex + 'static,
38    M: InterfaceManager + 'static,
39{
40    pub(crate) ptr: NonNull<OwnedSocket<T, R, M>>,
41    _lt: PhantomData<Pin<&'a mut OwnedSocket<T, R, M>>>,
42    port: u8,
43}
44
45pub struct Recv<'a, 'b, T, R, M>
46where
47    T: Serialize + DeserializeOwned + 'static,
48    R: ScopedRawMutex + 'static,
49    M: InterfaceManager + 'static,
50{
51    hdl: &'a mut OwnedSocketHdl<'b, T, R, M>,
52}
53
54struct OneBox<T: 'static> {
55    wait: Option<Waker>,
56    t: Option<OwnedMessage<T>>,
57}
58
59// ---- impls ----
60
61// impl OwnedMessage
62
63// ...
64
65// impl OwnedSocket
66
67impl<T, R, M> OwnedSocket<T, R, M>
68where
69    T: Serialize + DeserializeOwned + 'static,
70    R: ScopedRawMutex + 'static,
71    M: InterfaceManager + 'static,
72{
73    pub const fn new(net: &'static NetStack<R, M>, key: Key, kind: FrameKind) -> Self {
74        Self {
75            hdr: SocketHeader {
76                links: Links::new(),
77                vtable: const { &Self::vtable() },
78                port: 0,
79                kind,
80                key,
81            },
82            inner: UnsafeCell::new(OneBox::new()),
83            net,
84        }
85    }
86
87    pub fn attach<'a>(self: Pin<&'a mut Self>) -> OwnedSocketHdl<'a, T, R, M> {
88        let stack = self.net;
89        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
90        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
91        let port = unsafe { stack.attach_socket(ptr_erase) };
92        OwnedSocketHdl {
93            ptr: ptr_self,
94            _lt: PhantomData,
95            port,
96        }
97        // TODO: once-check?
98    }
99
100    const fn vtable() -> SocketVTable {
101        SocketVTable {
102            send_owned: Some(Self::send_owned),
103            // TODO: We probably COULD support this, but I'm pretty sure it
104            // would require serializing, copying to a buffer, then later
105            // deserializing. I really don't know if we WANT this.
106            send_bor: None,
107            send_raw: Self::send_raw,
108        }
109    }
110
111    pub fn stack(&self) -> &'static NetStack<R, M> {
112        self.net
113    }
114
115    fn send_owned(
116        this: NonNull<()>,
117        that: NonNull<()>,
118        hdr: HeaderSeq,
119        ty: &TypeId,
120    ) -> Result<(), SocketSendError> {
121        if &TypeId::of::<T>() != ty {
122            debug_assert!(false, "Type Mismatch!");
123            return Err(SocketSendError::TypeMismatch);
124        }
125        let that: NonNull<T> = that.cast();
126        let this: NonNull<Self> = this.cast();
127        let this: &Self = unsafe { this.as_ref() };
128        let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
129
130        if mutitem.t.is_some() {
131            return Err(SocketSendError::NoSpace);
132        }
133
134        mutitem.t = Some(OwnedMessage {
135            hdr,
136            t: unsafe { that.read() },
137        });
138        if let Some(w) = mutitem.wait.take() {
139            w.wake();
140        }
141
142        Ok(())
143    }
144
145    // fn send_bor(
146    //     this: NonNull<()>,
147    //     that: NonNull<()>,
148    //     src: Address,
149    //     dst: Address,
150    // ) -> Result<(), ()> {
151    //     // I don't think we can support this?
152    //     Err(())
153    // }
154
155    fn send_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
156        let this: NonNull<Self> = this.cast();
157        let this: &Self = unsafe { this.as_ref() };
158        let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
159
160        if mutitem.t.is_some() {
161            return Err(SocketSendError::NoSpace);
162        }
163
164        if let Ok(t) = postcard::from_bytes::<T>(that) {
165            mutitem.t = Some(OwnedMessage { hdr, t });
166            if let Some(w) = mutitem.wait.take() {
167                w.wake();
168            }
169            Ok(())
170        } else {
171            Err(SocketSendError::DeserFailed)
172        }
173    }
174}
175
176// impl OwnedSocketHdl
177
178// TODO: impl drop, remove waker, remove socket
179impl<'a, T, R, M> OwnedSocketHdl<'a, T, R, M>
180where
181    T: Serialize + DeserializeOwned + 'static,
182    R: ScopedRawMutex + 'static,
183    M: InterfaceManager + 'static,
184{
185    pub fn port(&self) -> u8 {
186        self.port
187    }
188
189    pub fn stack(&self) -> &'static NetStack<R, M> {
190        unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
191    }
192
193    // TODO: This future is !Send? I don't fully understand why, but rustc complains
194    // that since `NonNull<OwnedSocket<E>>` is !Sync, then this future can't be Send,
195    // BUT impl'ing Sync unsafely on OwnedSocketHdl + OwnedSocket doesn't seem to help.
196    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
197        Recv { hdl: self }
198    }
199}
200
201impl<T, R, M> Drop for OwnedSocket<T, R, M>
202where
203    T: Serialize + DeserializeOwned + 'static,
204    R: ScopedRawMutex + 'static,
205    M: InterfaceManager + 'static,
206{
207    fn drop(&mut self) {
208        println!("Dropping OwnedSocket!");
209        unsafe {
210            let this = NonNull::from(&self.hdr);
211            self.net.detach_socket(this);
212        }
213    }
214}
215
216unsafe impl<T, R, M> Send for OwnedSocketHdl<'_, T, R, M>
217where
218    T: Send,
219    T: Serialize + DeserializeOwned + 'static,
220    R: ScopedRawMutex + 'static,
221    M: InterfaceManager + 'static,
222{
223}
224
225unsafe impl<T, R, M> Sync for OwnedSocketHdl<'_, T, R, M>
226where
227    T: Send,
228    T: Serialize + DeserializeOwned + 'static,
229    R: ScopedRawMutex + 'static,
230    M: InterfaceManager + 'static,
231{
232}
233
234// impl Recv
235
236impl<T, R, M> Future for Recv<'_, '_, T, R, M>
237where
238    T: Serialize + DeserializeOwned + 'static,
239    R: ScopedRawMutex + 'static,
240    M: InterfaceManager + 'static,
241{
242    type Output = OwnedMessage<T>;
243
244    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245        let net: &'static NetStack<R, M> = self.hdl.stack();
246        let f = || {
247            let this_ref: &OwnedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
248            let box_ref: &mut OneBox<T> = unsafe { &mut *this_ref.inner.get() };
249            if let Some(t) = box_ref.t.take() {
250                Some(t)
251            } else {
252                let new_wake = cx.waker();
253                if let Some(w) = box_ref.wait.take() {
254                    if !w.will_wake(new_wake) {
255                        w.wake();
256                    }
257                }
258                // NOTE: Okay to register waker AFTER checking, because we
259                // have an exclusive lock
260                box_ref.wait = Some(new_wake.clone());
261                None
262            }
263        };
264        let res = unsafe { net.with_lock(f) };
265        if let Some(t) = res {
266            Poll::Ready(t)
267        } else {
268            Poll::Pending
269        }
270    }
271}
272
273unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
274where
275    T: Send,
276    T: Serialize + DeserializeOwned + 'static,
277    R: ScopedRawMutex + 'static,
278    M: InterfaceManager + 'static,
279{
280}
281
282// impl OneBox
283
284impl<T: 'static> OneBox<T> {
285    const fn new() -> Self {
286        Self {
287            wait: None,
288            t: None,
289        }
290    }
291}
292
293impl<T: 'static> Default for OneBox<T> {
294    fn default() -> Self {
295        Self::new()
296    }
297}