ergot_base/socket/
std_bounded.rs

1use std::{
2    any::TypeId,
3    cell::UnsafeCell,
4    collections::VecDeque,
5    marker::PhantomData,
6    pin::Pin,
7    ptr::{NonNull, addr_of},
8    task::{Context, Poll, Waker},
9};
10
11use cordyceps::list::Links;
12use mutex::ScopedRawMutex;
13use serde::{Serialize, de::DeserializeOwned};
14
15use crate::{HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager};
16
17use super::{Attributes, OwnedMessage, Response, SocketHeader, SocketSendError, SocketVTable};
18
19// Owned Socket
20#[repr(C)]
21pub struct StdBoundedSocket<T, R, M>
22where
23    T: Serialize + Clone + DeserializeOwned + 'static,
24    R: ScopedRawMutex + 'static,
25    M: InterfaceManager + 'static,
26{
27    // LOAD BEARING: must be first
28    hdr: SocketHeader,
29    net: &'static NetStack<R, M>,
30    // TODO: just a single item, we probably want a more ring-buffery
31    // option for this.
32    inner: UnsafeCell<BoundedQueue<T>>,
33}
34
35pub struct StdBoundedSocketHdl<'a, T, R, M>
36where
37    T: Serialize + Clone + DeserializeOwned + 'static,
38    R: ScopedRawMutex + 'static,
39    M: InterfaceManager + 'static,
40{
41    pub(crate) ptr: NonNull<StdBoundedSocket<T, R, M>>,
42    _lt: PhantomData<Pin<&'a mut StdBoundedSocket<T, R, M>>>,
43    port: u8,
44}
45
46pub struct Recv<'a, 'b, T, R, M>
47where
48    T: Serialize + Clone + DeserializeOwned + 'static,
49    R: ScopedRawMutex + 'static,
50    M: InterfaceManager + 'static,
51{
52    hdl: &'a mut StdBoundedSocketHdl<'b, T, R, M>,
53}
54
55struct BoundedQueue<T: 'static> {
56    wait: Option<Waker>,
57    // TODO: We could probably do better than a VecDeque with a boxed slice
58    // and a ringbuffer, which could maybe also be shared with the std
59    // inline buffer, but for now this is fine.
60    queue: VecDeque<Response<T>>,
61    max_len: usize,
62}
63
64// ---- impls ----
65
66// impl OwnedMessage
67
68// ...
69
70// impl StdBoundedSocket
71
72impl<T, R, M> StdBoundedSocket<T, R, M>
73where
74    T: Serialize + Clone + DeserializeOwned + 'static,
75    R: ScopedRawMutex + 'static,
76    M: InterfaceManager + 'static,
77{
78    pub fn new(net: &'static NetStack<R, M>, key: Key, attrs: Attributes, bound: usize) -> Self {
79        Self {
80            hdr: SocketHeader {
81                links: Links::new(),
82                vtable: const { &Self::vtable() },
83                port: 0,
84                attrs,
85                key,
86            },
87            inner: UnsafeCell::new(BoundedQueue::new(bound)),
88            net,
89        }
90    }
91
92    pub fn stack(&self) -> &'static NetStack<R, M> {
93        self.net
94    }
95
96    pub fn attach<'a>(self: Pin<&'a mut Self>) -> StdBoundedSocketHdl<'a, T, R, M> {
97        let stack = self.net;
98        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
99        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
100        let port = unsafe { stack.attach_socket(ptr_erase) };
101        StdBoundedSocketHdl {
102            ptr: ptr_self,
103            _lt: PhantomData,
104            port,
105        }
106    }
107
108    pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> StdBoundedSocketHdl<'a, T, R, M> {
109        let stack = self.net;
110        let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
111        let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
112        unsafe { stack.attach_broadcast_socket(ptr_erase) };
113        StdBoundedSocketHdl {
114            ptr: ptr_self,
115            _lt: PhantomData,
116            port: 255,
117        }
118    }
119
120    const fn vtable() -> SocketVTable {
121        SocketVTable {
122            recv_owned: Some(Self::recv_owned),
123            // TODO: We probably COULD support this, but I'm pretty sure it
124            // would require serializing, copying to a buffer, then later
125            // deserializing. I really don't know if we WANT this.
126            recv_bor: None,
127            recv_raw: Self::recv_raw,
128            recv_err: Some(Self::recv_err),
129        }
130    }
131
132    fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
133        let this: NonNull<Self> = this.cast();
134        let this: &Self = unsafe { this.as_ref() };
135        let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
136
137        if mutitem.queue.len() >= mutitem.max_len {
138            return;
139        }
140
141        mutitem.queue.push_back(Err(OwnedMessage { hdr, t: err }));
142        if let Some(w) = mutitem.wait.take() {
143            w.wake();
144        }
145    }
146
147    fn recv_owned(
148        this: NonNull<()>,
149        that: NonNull<()>,
150        hdr: HeaderSeq,
151        ty: &TypeId,
152    ) -> Result<(), SocketSendError> {
153        if &TypeId::of::<T>() != ty {
154            debug_assert!(false, "Type Mismatch!");
155            return Err(SocketSendError::TypeMismatch);
156        }
157        let that: NonNull<T> = that.cast();
158        let that: &T = unsafe { that.as_ref() };
159        let this: NonNull<Self> = this.cast();
160        let this: &Self = unsafe { this.as_ref() };
161        let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
162
163        if mutitem.queue.len() >= mutitem.max_len {
164            return Err(SocketSendError::NoSpace);
165        }
166
167        mutitem.queue.push_back(Ok(OwnedMessage {
168            hdr,
169            t: that.clone(),
170        }));
171        if let Some(w) = mutitem.wait.take() {
172            w.wake();
173        }
174
175        Ok(())
176    }
177
178    // fn send_bor(
179    //     this: NonNull<()>,
180    //     that: NonNull<()>,
181    //     src: Address,
182    //     dst: Address,
183    // ) -> Result<(), ()> {
184    //     // I don't think we can support this?
185    //     Err(())
186    // }
187
188    fn recv_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
189        let this: NonNull<Self> = this.cast();
190        let this: &Self = unsafe { this.as_ref() };
191        let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
192
193        if mutitem.queue.len() >= mutitem.max_len {
194            return Err(SocketSendError::NoSpace);
195        }
196
197        if let Ok(t) = postcard::from_bytes::<T>(that) {
198            mutitem.queue.push_back(Ok(OwnedMessage { hdr, t }));
199            if let Some(w) = mutitem.wait.take() {
200                w.wake();
201            }
202            Ok(())
203        } else {
204            Err(SocketSendError::DeserFailed)
205        }
206    }
207}
208
209// impl StdBoundedSocketHdl
210
211// TODO: impl drop, remove waker, remove socket
212impl<'a, T, R, M> StdBoundedSocketHdl<'a, T, R, M>
213where
214    T: Serialize + Clone + DeserializeOwned + 'static,
215    R: ScopedRawMutex + 'static,
216    M: InterfaceManager + 'static,
217{
218    pub fn port(&self) -> u8 {
219        self.port
220    }
221
222    pub fn stack(&self) -> &'static NetStack<R, M> {
223        unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
224    }
225
226    // TODO: This future is !Send? I don't fully understand why, but rustc complains
227    // that since `NonNull<StdBoundedSocket<E>>` is !Sync, then this future can't be Send,
228    // BUT impl'ing Sync unsafely on StdBoundedSocketHdl + StdBoundedSocket doesn't seem to help.
229    pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
230        Recv { hdl: self }
231    }
232}
233
234impl<T, R, M> Drop for StdBoundedSocket<T, R, M>
235where
236    T: Serialize + Clone + DeserializeOwned + 'static,
237    R: ScopedRawMutex + 'static,
238    M: InterfaceManager + 'static,
239{
240    fn drop(&mut self) {
241        println!("Dropping StdBoundedSocket!");
242        unsafe {
243            let this = NonNull::from(&self.hdr);
244            self.net.detach_socket(this);
245        }
246    }
247}
248
249unsafe impl<T, R, M> Send for StdBoundedSocketHdl<'_, T, R, M>
250where
251    T: Send,
252    T: Serialize + Clone + DeserializeOwned + 'static,
253    R: ScopedRawMutex + 'static,
254    M: InterfaceManager + 'static,
255{
256}
257
258unsafe impl<T, R, M> Sync for StdBoundedSocketHdl<'_, T, R, M>
259where
260    T: Send,
261    T: Serialize + Clone + DeserializeOwned + 'static,
262    R: ScopedRawMutex + 'static,
263    M: InterfaceManager + 'static,
264{
265}
266
267// impl Recv
268
269impl<T, R, M> Future for Recv<'_, '_, T, R, M>
270where
271    T: Serialize + Clone + DeserializeOwned + 'static,
272    R: ScopedRawMutex + 'static,
273    M: InterfaceManager + 'static,
274{
275    type Output = Response<T>;
276
277    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
278        let net = self.hdl.stack();
279        let f = || {
280            let this_ref: &StdBoundedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
281            let box_ref: &mut BoundedQueue<T> = unsafe { &mut *this_ref.inner.get() };
282            if let Some(t) = box_ref.queue.pop_front() {
283                Some(t)
284            } else {
285                let new_wake = cx.waker();
286                if let Some(w) = box_ref.wait.take() {
287                    if !w.will_wake(new_wake) {
288                        w.wake();
289                    }
290                }
291                // NOTE: Okay to register waker AFTER checking, because we
292                // have an exclusive lock
293                box_ref.wait = Some(new_wake.clone());
294                None
295            }
296        };
297        let res = unsafe { net.with_lock(f) };
298        if let Some(t) = res {
299            Poll::Ready(t)
300        } else {
301            Poll::Pending
302        }
303    }
304}
305
306unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
307where
308    T: Send,
309    T: Serialize + Clone + DeserializeOwned + 'static,
310    R: ScopedRawMutex + 'static,
311    M: InterfaceManager + 'static,
312{
313}
314
315// impl BoundedQueue
316
317impl<T: 'static> BoundedQueue<T> {
318    fn new(bound: usize) -> Self {
319        Self {
320            wait: None,
321            queue: VecDeque::new(),
322            max_len: bound,
323        }
324    }
325}