// ergot_base/socket/std_bounded.rs

use std::{
    any::TypeId,
    cell::UnsafeCell,
    collections::VecDeque,
    marker::PhantomData,
    pin::Pin,
    ptr::{NonNull, addr_of},
    task::{Context, Poll, Waker},
};

use cordyceps::list::Links;
use mutex::ScopedRawMutex;
use serde::{Serialize, de::DeserializeOwned};

use crate::{FrameKind, HeaderSeq, Key, NetStack, interface_manager::InterfaceManager};

use super::{OwnedMessage, SocketHeader, SocketSendError, SocketVTable};
// Owned Socket
20#[repr(C)]
21pub struct StdBoundedSocket<T, R, M>
22where
23    T: Serialize + DeserializeOwned + 'static,
24    R: ScopedRawMutex + 'static,
25    M: InterfaceManager + 'static,
26{
27    // LOAD BEARING: must be first
28    hdr: SocketHeader,
29    net: &'static NetStack<R, M>,
30    // TODO: just a single item, we probably want a more ring-buffery
31    // option for this.
32    inner: UnsafeCell<BoundedQueue<T>>,
33}
34
35pub struct StdBoundedSocketHdl<'a, T, R, M>
36where
37    T: Serialize + DeserializeOwned + 'static,
38    R: ScopedRawMutex + 'static,
39    M: InterfaceManager + 'static,
40{
41    pub(crate) ptr: NonNull<StdBoundedSocket<T, R, M>>,
42    _lt: PhantomData<Pin<&'a mut StdBoundedSocket<T, R, M>>>,
43    port: u8,
44}
45
46pub struct Recv<'a, 'b, T, R, M>
47where
48    T: Serialize + DeserializeOwned + 'static,
49    R: ScopedRawMutex + 'static,
50    M: InterfaceManager + 'static,
51{
52    hdl: &'a mut StdBoundedSocketHdl<'b, T, R, M>,
53}
54
55struct BoundedQueue<T: 'static> {
56    wait: Option<Waker>,
57    // TODO: We could probably do better than a VecDeque with a boxed slice
58    // and a ringbuffer, which could maybe also be shared with the std
59    // inline buffer, but for now this is fine.
60    queue: VecDeque<OwnedMessage<T>>,
61    max_len: usize,
62}
// ---- impls ----

// impl OwnedMessage

// ...

// impl StdBoundedSocket
72impl<T, R, M> StdBoundedSocket<T, R, M>
73where
74    T: Serialize + DeserializeOwned + 'static,
75    R: ScopedRawMutex + 'static,
76    M: InterfaceManager + 'static,
77{
78    pub fn new(net: &'static NetStack<R, M>, key: Key, kind: FrameKind, bound: usize) -> Self {
79        Self {
80            hdr: SocketHeader {
81                links: Links::new(),
82                vtable: const { &Self::vtable() },
83                port: 0,
84                kind,
85                key,
86            },
87            inner: UnsafeCell::new(BoundedQueue::new(bound)),
88            net,
89        }
90    }
91
92    pub fn stack(&self) -> &'static NetStack<R, M> {
93        self.net
94    }
95
96    pub fn attach<'a>(self: Pin<&'a mut Self>) -> StdBoundedSocketHdl<'a, T, R, M> {
97        let stack = self.net;
98        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
99        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
100        let port = unsafe { stack.attach_socket(ptr_erase) };
101        StdBoundedSocketHdl {
102            ptr: ptr_self,
103            _lt: PhantomData,
104            port,
105        }
106        // TODO: once-check?
107    }
108
109    const fn vtable() -> SocketVTable {
110        SocketVTable {
111            send_owned: Some(Self::send_owned),
112            // TODO: We probably COULD support this, but I'm pretty sure it
113            // would require serializing, copying to a buffer, then later
114            // deserializing. I really don't know if we WANT this.
115            send_bor: None,
116            send_raw: Self::send_raw,
117        }
118    }
119
120    fn send_owned(
121        this: NonNull<()>,
122        that: NonNull<()>,
123        hdr: HeaderSeq,
124        ty: &TypeId,
125    ) -> Result<(), SocketSendError> {
126        if &TypeId::of::<T>() != ty {
127            debug_assert!(false, "Type Mismatch!");
128            return Err(SocketSendError::TypeMismatch);
129        }
130        let that: NonNull<T> = that.cast();
131        let this: NonNull<Self> = this.cast();
132        let this: &Self = unsafe { this.as_ref() };
133        let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
134
135        if mutitem.queue.len() >= mutitem.max_len {
136            return Err(SocketSendError::NoSpace);
137        }
138
139        mutitem.queue.push_back(OwnedMessage {
140            hdr,
141            t: unsafe { that.read() },
142        });
143        if let Some(w) = mutitem.wait.take() {
144            w.wake();
145        }
146
147        Ok(())
148    }
149
150    // fn send_bor(
151    //     this: NonNull<()>,
152    //     that: NonNull<()>,
153    //     src: Address,
154    //     dst: Address,
155    // ) -> Result<(), ()> {
156    //     // I don't think we can support this?
157    //     Err(())
158    // }
159
160    fn send_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
161        let this: NonNull<Self> = this.cast();
162        let this: &Self = unsafe { this.as_ref() };
163        let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
164
165        if mutitem.queue.len() >= mutitem.max_len {
166            return Err(SocketSendError::NoSpace);
167        }
168
169        if let Ok(t) = postcard::from_bytes::<T>(that) {
170            mutitem.queue.push_back(OwnedMessage { hdr, t });
171            if let Some(w) = mutitem.wait.take() {
172                w.wake();
173            }
174            Ok(())
175        } else {
176            Err(SocketSendError::DeserFailed)
177        }
178    }
179}
// impl StdBoundedSocketHdl
183// TODO: impl drop, remove waker, remove socket
184impl<'a, T, R, M> StdBoundedSocketHdl<'a, T, R, M>
185where
186    T: Serialize + DeserializeOwned + 'static,
187    R: ScopedRawMutex + 'static,
188    M: InterfaceManager + 'static,
189{
190    pub fn port(&self) -> u8 {
191        self.port
192    }
193
194    pub fn stack(&self) -> &'static NetStack<R, M> {
195        unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
196    }
197
198    // TODO: This future is !Send? I don't fully understand why, but rustc complains
199    // that since `NonNull<StdBoundedSocket<E>>` is !Sync, then this future can't be Send,
200    // BUT impl'ing Sync unsafely on StdBoundedSocketHdl + StdBoundedSocket doesn't seem to help.
201    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
202        Recv { hdl: self }
203    }
204}
205
206impl<T, R, M> Drop for StdBoundedSocket<T, R, M>
207where
208    T: Serialize + DeserializeOwned + 'static,
209    R: ScopedRawMutex + 'static,
210    M: InterfaceManager + 'static,
211{
212    fn drop(&mut self) {
213        println!("Dropping StdBoundedSocket!");
214        unsafe {
215            let this = NonNull::from(&self.hdr);
216            self.net.detach_socket(this);
217        }
218    }
219}
220
221unsafe impl<T, R, M> Send for StdBoundedSocketHdl<'_, T, R, M>
222where
223    T: Send,
224    T: Serialize + DeserializeOwned + 'static,
225    R: ScopedRawMutex + 'static,
226    M: InterfaceManager + 'static,
227{
228}
229
230unsafe impl<T, R, M> Sync for StdBoundedSocketHdl<'_, T, R, M>
231where
232    T: Send,
233    T: Serialize + DeserializeOwned + 'static,
234    R: ScopedRawMutex + 'static,
235    M: InterfaceManager + 'static,
236{
237}
// impl Recv
241impl<T, R, M> Future for Recv<'_, '_, T, R, M>
242where
243    T: Serialize + DeserializeOwned + 'static,
244    R: ScopedRawMutex + 'static,
245    M: InterfaceManager + 'static,
246{
247    type Output = OwnedMessage<T>;
248
249    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250        let net = self.hdl.stack();
251        let f = || {
252            let this_ref: &StdBoundedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
253            let box_ref: &mut BoundedQueue<T> = unsafe { &mut *this_ref.inner.get() };
254            if let Some(t) = box_ref.queue.pop_front() {
255                Some(t)
256            } else {
257                let new_wake = cx.waker();
258                if let Some(w) = box_ref.wait.take() {
259                    if !w.will_wake(new_wake) {
260                        w.wake();
261                    }
262                }
263                // NOTE: Okay to register waker AFTER checking, because we
264                // have an exclusive lock
265                box_ref.wait = Some(new_wake.clone());
266                None
267            }
268        };
269        let res = unsafe { net.with_lock(f) };
270        if let Some(t) = res {
271            Poll::Ready(t)
272        } else {
273            Poll::Pending
274        }
275    }
276}
277
278unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
279where
280    T: Send,
281    T: Serialize + DeserializeOwned + 'static,
282    R: ScopedRawMutex + 'static,
283    M: InterfaceManager + 'static,
284{
285}
// impl BoundedQueue
289impl<T: 'static> BoundedQueue<T> {
290    fn new(bound: usize) -> Self {
291        Self {
292            wait: None,
293            queue: VecDeque::new(),
294            max_len: bound,
295        }
296    }
297}