1use std::{
2 any::TypeId,
3 cell::UnsafeCell,
4 collections::VecDeque,
5 marker::PhantomData,
6 pin::Pin,
7 ptr::{NonNull, addr_of},
8 task::{Context, Poll, Waker},
9};
10
11use cordyceps::list::Links;
12use mutex::ScopedRawMutex;
13use serde::{Serialize, de::DeserializeOwned};
14
15use crate::{HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager};
16
17use super::{Attributes, OwnedMessage, Response, SocketHeader, SocketSendError, SocketVTable};
18
/// A socket that buffers incoming messages of type `T` in a length-limited queue.
///
/// The struct is `#[repr(C)]` with `hdr` declared first: `attach` and
/// `attach_broadcast` cast a pointer to the whole socket into a pointer to
/// [`SocketHeader`], so the header must remain the first field.
#[repr(C)]
pub struct StdBoundedSocket<T, R, M>
where
    T: Serialize + Clone + DeserializeOwned + 'static,
    R: ScopedRawMutex + 'static,
    M: InterfaceManager + 'static,
{
    // Type-erased header handed to the net stack on attach and on drop.
    // Must stay first; see the struct-level note.
    hdr: SocketHeader,
    // Stack this socket attaches to, and detaches from when dropped.
    net: &'static NetStack<R, M>,
    // Pending messages plus the receiver's waker. Mutated through a shared
    // reference by the vtable callbacks and by `Recv::poll`; the latter does so
    // inside `net.with_lock`.
    // NOTE(review): presumably the callbacks also run under that lock — confirm
    // against `NetStack`.
    inner: UnsafeCell<BoundedQueue<T>>,
}
34
/// Handle to a [`StdBoundedSocket`] that has been attached to its net stack.
///
/// Produced by `StdBoundedSocket::attach` / `attach_broadcast`; the lifetime `'a`
/// is that of the pinned mutable borrow the attach call consumed.
pub struct StdBoundedSocketHdl<'a, T, R, M>
where
    T: Serialize + Clone + DeserializeOwned + 'static,
    R: ScopedRawMutex + 'static,
    M: InterfaceManager + 'static,
{
    // Raw pointer to the attached socket (the same address that was handed to
    // the net stack, before erasure to `SocketHeader`).
    pub(crate) ptr: NonNull<StdBoundedSocket<T, R, M>>,
    // Zero-sized marker tying this handle to the `Pin<&'a mut _>` borrow.
    _lt: PhantomData<Pin<&'a mut StdBoundedSocket<T, R, M>>>,
    // Port returned by `attach_socket`, or 255 when created by `attach_broadcast`.
    port: u8,
}
45
/// Future returned by `StdBoundedSocketHdl::recv`.
///
/// Resolves to the next queued `Response<T>`; see its `Future` impl for the
/// polling logic.
pub struct Recv<'a, 'b, T, R, M>
where
    T: Serialize + Clone + DeserializeOwned + 'static,
    R: ScopedRawMutex + 'static,
    M: InterfaceManager + 'static,
{
    // Exclusive borrow of the handle for as long as this future exists.
    hdl: &'a mut StdBoundedSocketHdl<'b, T, R, M>,
}
54
/// Queue state shared between the socket's receive callbacks and the `Recv` future.
struct BoundedQueue<T: 'static> {
    // Waker stored by `Recv::poll` when it finds the queue empty; taken and
    // woken by the receive callbacks after they push an item.
    wait: Option<Waker>,
    // Responses waiting to be handed to the receiver, oldest at the front.
    queue: VecDeque<Response<T>>,
    // Capacity limit: the receive callbacks refuse to push once
    // `queue.len() >= max_len`.
    max_len: usize,
}
63
64impl<T, R, M> StdBoundedSocket<T, R, M>
73where
74 T: Serialize + Clone + DeserializeOwned + 'static,
75 R: ScopedRawMutex + 'static,
76 M: InterfaceManager + 'static,
77{
78 pub fn new(net: &'static NetStack<R, M>, key: Key, attrs: Attributes, bound: usize) -> Self {
79 Self {
80 hdr: SocketHeader {
81 links: Links::new(),
82 vtable: const { &Self::vtable() },
83 port: 0,
84 attrs,
85 key,
86 },
87 inner: UnsafeCell::new(BoundedQueue::new(bound)),
88 net,
89 }
90 }
91
92 pub fn stack(&self) -> &'static NetStack<R, M> {
93 self.net
94 }
95
96 pub fn attach<'a>(self: Pin<&'a mut Self>) -> StdBoundedSocketHdl<'a, T, R, M> {
97 let stack = self.net;
98 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
99 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
100 let port = unsafe { stack.attach_socket(ptr_erase) };
101 StdBoundedSocketHdl {
102 ptr: ptr_self,
103 _lt: PhantomData,
104 port,
105 }
106 }
107
108 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> StdBoundedSocketHdl<'a, T, R, M> {
109 let stack = self.net;
110 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
111 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
112 unsafe { stack.attach_broadcast_socket(ptr_erase) };
113 StdBoundedSocketHdl {
114 ptr: ptr_self,
115 _lt: PhantomData,
116 port: 255,
117 }
118 }
119
120 const fn vtable() -> SocketVTable {
121 SocketVTable {
122 recv_owned: Some(Self::recv_owned),
123 recv_bor: None,
127 recv_raw: Self::recv_raw,
128 recv_err: Some(Self::recv_err),
129 }
130 }
131
132 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
133 let this: NonNull<Self> = this.cast();
134 let this: &Self = unsafe { this.as_ref() };
135 let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
136
137 if mutitem.queue.len() >= mutitem.max_len {
138 return;
139 }
140
141 mutitem.queue.push_back(Err(OwnedMessage { hdr, t: err }));
142 if let Some(w) = mutitem.wait.take() {
143 w.wake();
144 }
145 }
146
147 fn recv_owned(
148 this: NonNull<()>,
149 that: NonNull<()>,
150 hdr: HeaderSeq,
151 ty: &TypeId,
152 ) -> Result<(), SocketSendError> {
153 if &TypeId::of::<T>() != ty {
154 debug_assert!(false, "Type Mismatch!");
155 return Err(SocketSendError::TypeMismatch);
156 }
157 let that: NonNull<T> = that.cast();
158 let that: &T = unsafe { that.as_ref() };
159 let this: NonNull<Self> = this.cast();
160 let this: &Self = unsafe { this.as_ref() };
161 let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
162
163 if mutitem.queue.len() >= mutitem.max_len {
164 return Err(SocketSendError::NoSpace);
165 }
166
167 mutitem.queue.push_back(Ok(OwnedMessage {
168 hdr,
169 t: that.clone(),
170 }));
171 if let Some(w) = mutitem.wait.take() {
172 w.wake();
173 }
174
175 Ok(())
176 }
177
178 fn recv_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
189 let this: NonNull<Self> = this.cast();
190 let this: &Self = unsafe { this.as_ref() };
191 let mutitem: &mut BoundedQueue<T> = unsafe { &mut *this.inner.get() };
192
193 if mutitem.queue.len() >= mutitem.max_len {
194 return Err(SocketSendError::NoSpace);
195 }
196
197 if let Ok(t) = postcard::from_bytes::<T>(that) {
198 mutitem.queue.push_back(Ok(OwnedMessage { hdr, t }));
199 if let Some(w) = mutitem.wait.take() {
200 w.wake();
201 }
202 Ok(())
203 } else {
204 Err(SocketSendError::DeserFailed)
205 }
206 }
207}
208
209impl<'a, T, R, M> StdBoundedSocketHdl<'a, T, R, M>
213where
214 T: Serialize + Clone + DeserializeOwned + 'static,
215 R: ScopedRawMutex + 'static,
216 M: InterfaceManager + 'static,
217{
218 pub fn port(&self) -> u8 {
219 self.port
220 }
221
222 pub fn stack(&self) -> &'static NetStack<R, M> {
223 unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
224 }
225
226 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
230 Recv { hdl: self }
231 }
232}
233
234impl<T, R, M> Drop for StdBoundedSocket<T, R, M>
235where
236 T: Serialize + Clone + DeserializeOwned + 'static,
237 R: ScopedRawMutex + 'static,
238 M: InterfaceManager + 'static,
239{
240 fn drop(&mut self) {
241 println!("Dropping StdBoundedSocket!");
242 unsafe {
243 let this = NonNull::from(&self.hdr);
244 self.net.detach_socket(this);
245 }
246 }
247}
248
249unsafe impl<T, R, M> Send for StdBoundedSocketHdl<'_, T, R, M>
250where
251 T: Send,
252 T: Serialize + Clone + DeserializeOwned + 'static,
253 R: ScopedRawMutex + 'static,
254 M: InterfaceManager + 'static,
255{
256}
257
258unsafe impl<T, R, M> Sync for StdBoundedSocketHdl<'_, T, R, M>
259where
260 T: Send,
261 T: Serialize + Clone + DeserializeOwned + 'static,
262 R: ScopedRawMutex + 'static,
263 M: InterfaceManager + 'static,
264{
265}
266
267impl<T, R, M> Future for Recv<'_, '_, T, R, M>
270where
271 T: Serialize + Clone + DeserializeOwned + 'static,
272 R: ScopedRawMutex + 'static,
273 M: InterfaceManager + 'static,
274{
275 type Output = Response<T>;
276
277 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
278 let net = self.hdl.stack();
279 let f = || {
280 let this_ref: &StdBoundedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
281 let box_ref: &mut BoundedQueue<T> = unsafe { &mut *this_ref.inner.get() };
282 if let Some(t) = box_ref.queue.pop_front() {
283 Some(t)
284 } else {
285 let new_wake = cx.waker();
286 if let Some(w) = box_ref.wait.take() {
287 if !w.will_wake(new_wake) {
288 w.wake();
289 }
290 }
291 box_ref.wait = Some(new_wake.clone());
294 None
295 }
296 };
297 let res = unsafe { net.with_lock(f) };
298 if let Some(t) = res {
299 Poll::Ready(t)
300 } else {
301 Poll::Pending
302 }
303 }
304}
305
306unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
307where
308 T: Send,
309 T: Serialize + Clone + DeserializeOwned + 'static,
310 R: ScopedRawMutex + 'static,
311 M: InterfaceManager + 'static,
312{
313}
314
315impl<T: 'static> BoundedQueue<T> {
318 fn new(bound: usize) -> Self {
319 Self {
320 wait: None,
321 queue: VecDeque::new(),
322 max_len: bound,
323 }
324 }
325}