1use std::{
2 any::TypeId,
3 cell::UnsafeCell,
4 collections::VecDeque,
5 marker::PhantomData,
6 pin::Pin,
7 ptr::{NonNull, addr_of},
8 task::{Context, Poll, Waker},
9};
10
11use cordyceps::list::Links;
12use mutex::ScopedRawMutex;
13use serde::{Serialize, de::DeserializeOwned};
14
15use crate::{FrameKind, HeaderSeq, Key, NetStack, interface_manager::InterfaceManager};
16
17use super::{OwnedMessage, SocketHeader, SocketSendError, SocketVTable};
18
19#[repr(C)]
21pub struct StdBoundedSocket<T, R, M>
22where
23 T: Serialize + DeserializeOwned + 'static,
24 R: ScopedRawMutex + 'static,
25 M: InterfaceManager + 'static,
26{
27 hdr: SocketHeader,
29 net: &'static NetStack<R, M>,
30 inner: UnsafeCell<BoundedQueue<T>>,
33}
34
35pub struct StdBoundedSocketHdl<'a, T, R, M>
36where
37 T: Serialize + DeserializeOwned + 'static,
38 R: ScopedRawMutex + 'static,
39 M: InterfaceManager + 'static,
40{
41 pub(crate) ptr: NonNull<StdBoundedSocket<T, R, M>>,
42 _lt: PhantomData<Pin<&'a mut StdBoundedSocket<T, R, M>>>,
43 port: u8,
44}
45
46pub struct Recv<'a, 'b, T, R, M>
47where
48 T: Serialize + DeserializeOwned + 'static,
49 R: ScopedRawMutex + 'static,
50 M: InterfaceManager + 'static,
51{
52 hdl: &'a mut StdBoundedSocketHdl<'b, T, R, M>,
53}
54
55struct BoundedQueue<T: 'static> {
56 wait: Option<Waker>,
57 queue: VecDeque<OwnedMessage<T>>,
61 max_len: usize,
62}
63
64impl<T, R, M> StdBoundedSocket<T, R, M>
73where
74 T: Serialize + DeserializeOwned + 'static,
75 R: ScopedRawMutex + 'static,
76 M: InterfaceManager + 'static,
77{
78 pub fn new(net: &'static NetStack<R, M>, key: Key, kind: FrameKind, bound: usize) -> Self {
79 Self {
80 hdr: SocketHeader {
81 links: Links::new(),
82 vtable: const { &Self::vtable() },
83 port: 0,
84 kind,
85 key,
86 },
87 inner: UnsafeCell::new(BoundedQueue::new(bound)),
88 net,
89 }
90 }
91
92 pub fn stack(&self) -> &'static NetStack<R, M> {
93 self.net
94 }
95
96 pub fn attach<'a>(self: Pin<&'a mut Self>) -> StdBoundedSocketHdl<'a, T, R, M> {
97 let stack = self.net;
98 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
99 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
100 let port = unsafe { stack.attach_socket(ptr_erase) };
101 StdBoundedSocketHdl {
102 ptr: ptr_self,
103 _lt: PhantomData,
104 port,
105 }
106 }
108
109 const fn vtable() -> SocketVTable {
110 SocketVTable {
111 send_owned: Some(Self::send_owned),
112 send_bor: None,
116 send_raw: Self::send_raw,
117 }
118 }
119
120 fn send_owned(
121 this: NonNull<()>,
122 that: NonNull<()>,
123 hdr: HeaderSeq,
124 ty: &TypeId,
125 ) -> Result<(), SocketSendError> {
126 if &TypeId::of::<T>() != ty {
127 debug_assert!(false, "Type Mismatch!");
128 return Err(SocketSendError::TypeMismatch);
129 }
130 let that: NonNull<T> = that.cast();
131 let this: NonNull<Self> = this.cast();
132 let this: &Self = unsafe { this.as_ref() };
133 let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
134
135 if mutitem.queue.len() >= mutitem.max_len {
136 return Err(SocketSendError::NoSpace);
137 }
138
139 mutitem.queue.push_back(OwnedMessage {
140 hdr,
141 t: unsafe { that.read() },
142 });
143 if let Some(w) = mutitem.wait.take() {
144 w.wake();
145 }
146
147 Ok(())
148 }
149
150 fn send_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
161 let this: NonNull<Self> = this.cast();
162 let this: &Self = unsafe { this.as_ref() };
163 let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
164
165 if mutitem.queue.len() >= mutitem.max_len {
166 return Err(SocketSendError::NoSpace);
167 }
168
169 if let Ok(t) = postcard::from_bytes::<T>(that) {
170 mutitem.queue.push_back(OwnedMessage { hdr, t });
171 if let Some(w) = mutitem.wait.take() {
172 w.wake();
173 }
174 Ok(())
175 } else {
176 Err(SocketSendError::DeserFailed)
177 }
178 }
179}
180
181impl<'a, T, R, M> StdBoundedSocketHdl<'a, T, R, M>
185where
186 T: Serialize + DeserializeOwned + 'static,
187 R: ScopedRawMutex + 'static,
188 M: InterfaceManager + 'static,
189{
190 pub fn port(&self) -> u8 {
191 self.port
192 }
193
194 pub fn stack(&self) -> &'static NetStack<R, M> {
195 unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
196 }
197
198 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
202 Recv { hdl: self }
203 }
204}
205
206impl<T, R, M> Drop for StdBoundedSocket<T, R, M>
207where
208 T: Serialize + DeserializeOwned + 'static,
209 R: ScopedRawMutex + 'static,
210 M: InterfaceManager + 'static,
211{
212 fn drop(&mut self) {
213 println!("Dropping StdBoundedSocket!");
214 unsafe {
215 let this = NonNull::from(&self.hdr);
216 self.net.detach_socket(this);
217 }
218 }
219}
220
221unsafe impl<T, R, M> Send for StdBoundedSocketHdl<'_, T, R, M>
222where
223 T: Send,
224 T: Serialize + DeserializeOwned + 'static,
225 R: ScopedRawMutex + 'static,
226 M: InterfaceManager + 'static,
227{
228}
229
230unsafe impl<T, R, M> Sync for StdBoundedSocketHdl<'_, T, R, M>
231where
232 T: Send,
233 T: Serialize + DeserializeOwned + 'static,
234 R: ScopedRawMutex + 'static,
235 M: InterfaceManager + 'static,
236{
237}
238
239impl<T, R, M> Future for Recv<'_, '_, T, R, M>
242where
243 T: Serialize + DeserializeOwned + 'static,
244 R: ScopedRawMutex + 'static,
245 M: InterfaceManager + 'static,
246{
247 type Output = OwnedMessage<T>;
248
249 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250 let net = self.hdl.stack();
251 let f = || {
252 let this_ref: &StdBoundedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
253 let box_ref: &mut BoundedQueue<T> = unsafe { &mut *this_ref.inner.get() };
254 if let Some(t) = box_ref.queue.pop_front() {
255 Some(t)
256 } else {
257 let new_wake = cx.waker();
258 if let Some(w) = box_ref.wait.take() {
259 if !w.will_wake(new_wake) {
260 w.wake();
261 }
262 }
263 box_ref.wait = Some(new_wake.clone());
266 None
267 }
268 };
269 let res = unsafe { net.with_lock(f) };
270 if let Some(t) = res {
271 Poll::Ready(t)
272 } else {
273 Poll::Pending
274 }
275 }
276}
277
278unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
279where
280 T: Send,
281 T: Serialize + DeserializeOwned + 'static,
282 R: ScopedRawMutex + 'static,
283 M: InterfaceManager + 'static,
284{
285}
286
287impl<T: 'static> BoundedQueue<T> {
290 fn new(bound: usize) -> Self {
291 Self {
292 wait: None,
293 queue: VecDeque::new(),
294 max_len: bound,
295 }
296 }
297}