1use core::{
2 any::TypeId,
3 cell::UnsafeCell,
4 marker::PhantomData,
5 pin::Pin,
6 ptr::{NonNull, addr_of},
7 task::{Context, Poll, Waker},
8};
9
10use cordyceps::list::Links;
11use mutex::ScopedRawMutex;
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager};
15
16use super::{Attributes, OwnedMessage, Response, SocketHeader, SocketSendError, SocketVTable};
17
/// Error returned by [`Storage::push`] when the storage has no room for another item.
///
/// This is a zero-sized marker, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StorageFull;
20
21pub trait Storage<T: 'static>: 'static {
22 fn is_full(&self) -> bool;
23 fn is_empty(&self) -> bool;
24 fn push(&mut self, t: T) -> Result<(), StorageFull>;
25 fn try_pop(&mut self) -> Option<T>;
26}
27
/// A typed socket that buffers received `T` messages (and protocol errors) in `S`.
///
/// `#[repr(C)]` with `hdr` as the first field is load-bearing: `attach` and
/// `attach_broadcast` cast a `NonNull<Socket>` to `NonNull<SocketHeader>`, and the
/// vtable entry points cast the type-erased pointer back to `Socket`.
#[repr(C)]
pub struct Socket<S, T, R, M>
where
    S: Storage<Response<T>>,
    T: Serialize + Clone + DeserializeOwned + 'static,
    R: ScopedRawMutex + 'static,
    M: InterfaceManager + 'static,
{
    // Must remain the first field (see the type-level docs).
    hdr: SocketHeader,
    /// The net stack this socket attaches to, and detaches from on drop.
    pub(crate) net: &'static NetStack<R, M>,
    // Receive queue plus the waker of a pending `Recv`. It is mutated through a
    // shared reference, so all access must be externally synchronized.
    // NOTE(review): presumably by the net stack's lock (`Recv::poll` accesses it
    // inside `NetStack::with_lock`); confirm for the vtable entry points.
    inner: UnsafeCell<StoreBox<S, Response<T>>>,
}
44
45pub struct SocketHdl<'a, S, T, R, M>
46where
47 S: Storage<Response<T>>,
48 T: Serialize + Clone + DeserializeOwned + 'static,
49 R: ScopedRawMutex + 'static,
50 M: InterfaceManager + 'static,
51{
52 pub(crate) ptr: NonNull<Socket<S, T, R, M>>,
53 _lt: PhantomData<Pin<&'a mut Socket<S, T, R, M>>>,
54 port: u8,
55}
56
57pub struct Recv<'a, 'b, S, T, R, M>
58where
59 S: Storage<Response<T>>,
60 T: Serialize + Clone + DeserializeOwned + 'static,
61 R: ScopedRawMutex + 'static,
62 M: InterfaceManager + 'static,
63{
64 hdl: &'a mut SocketHdl<'b, S, T, R, M>,
65}
66
67struct StoreBox<S: Storage<T>, T: 'static> {
68 wait: Option<Waker>,
69 sto: S,
70 _pd: PhantomData<fn() -> T>,
71}
72
73impl<S, T, R, M> Socket<S, T, R, M>
82where
83 S: Storage<Response<T>>,
84 T: Serialize + Clone + DeserializeOwned + 'static,
85 R: ScopedRawMutex + 'static,
86 M: InterfaceManager + 'static,
87{
88 pub const fn new(net: &'static NetStack<R, M>, key: Key, attrs: Attributes, sto: S) -> Self {
89 Self {
90 hdr: SocketHeader {
91 links: Links::new(),
92 vtable: const { &Self::vtable() },
93 port: 0,
94 attrs,
95 key,
96 },
97 inner: UnsafeCell::new(StoreBox::new(sto)),
98 net,
99 }
100 }
101
102 pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
103 let stack = self.net;
104 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
105 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
106 let port = unsafe { stack.attach_socket(ptr_erase) };
107 SocketHdl {
108 ptr: ptr_self,
109 _lt: PhantomData,
110 port,
111 }
112 }
113
114 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
115 let stack = self.net;
116 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
117 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
118 unsafe { stack.attach_broadcast_socket(ptr_erase) };
119 SocketHdl {
120 ptr: ptr_self,
121 _lt: PhantomData,
122 port: 255,
123 }
124 }
125
126 const fn vtable() -> SocketVTable {
127 SocketVTable {
128 recv_owned: Some(Self::recv_owned),
129 recv_bor: None,
133 recv_raw: Self::recv_raw,
134 recv_err: Some(Self::recv_err),
135 }
136 }
137
138 pub fn stack(&self) -> &'static NetStack<R, M> {
139 self.net
140 }
141
142 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
143 let this: NonNull<Self> = this.cast();
144 let this: &Self = unsafe { this.as_ref() };
145 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
146
147 let msg = Err(OwnedMessage { hdr, t: err });
148 if mutitem.sto.push(msg).is_ok() {
149 if let Some(w) = mutitem.wait.take() {
150 w.wake();
151 }
152 }
153 }
154
155 fn recv_owned(
156 this: NonNull<()>,
157 that: NonNull<()>,
158 hdr: HeaderSeq,
159 ty: &TypeId,
160 ) -> Result<(), SocketSendError> {
161 if &TypeId::of::<T>() != ty {
162 debug_assert!(false, "Type Mismatch!");
163 return Err(SocketSendError::TypeMismatch);
164 }
165 let that: NonNull<T> = that.cast();
166 let that: &T = unsafe { that.as_ref() };
167 let this: NonNull<Self> = this.cast();
168 let this: &Self = unsafe { this.as_ref() };
169 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
170
171 let msg = Ok(OwnedMessage {
172 hdr,
173 t: that.clone(),
174 });
175
176 match mutitem.sto.push(msg) {
177 Ok(()) => {
178 if let Some(w) = mutitem.wait.take() {
179 w.wake();
180 }
181 Ok(())
182 }
183 Err(StorageFull) => Err(SocketSendError::NoSpace),
184 }
185 }
186
187 fn recv_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
198 let this: NonNull<Self> = this.cast();
199 let this: &Self = unsafe { this.as_ref() };
200 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
201
202 if mutitem.sto.is_full() {
203 return Err(SocketSendError::NoSpace);
204 }
205
206 if let Ok(t) = postcard::from_bytes::<T>(that) {
207 let msg = Ok(OwnedMessage { hdr, t });
208 let _ = mutitem.sto.push(msg);
209 if let Some(w) = mutitem.wait.take() {
210 w.wake();
211 }
212 Ok(())
213 } else {
214 Err(SocketSendError::DeserFailed)
215 }
216 }
217}
218
219impl<'a, S, T, R, M> SocketHdl<'a, S, T, R, M>
223where
224 S: Storage<Response<T>>,
225 T: Serialize + Clone + DeserializeOwned + 'static,
226 R: ScopedRawMutex + 'static,
227 M: InterfaceManager + 'static,
228{
229 pub fn port(&self) -> u8 {
230 self.port
231 }
232
233 pub fn stack(&self) -> &'static NetStack<R, M> {
234 unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
235 }
236
237 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, R, M> {
241 Recv { hdl: self }
242 }
243}
244
245impl<S, T, R, M> Drop for Socket<S, T, R, M>
246where
247 S: Storage<Response<T>>,
248 T: Serialize + Clone + DeserializeOwned + 'static,
249 R: ScopedRawMutex + 'static,
250 M: InterfaceManager + 'static,
251{
252 fn drop(&mut self) {
253 unsafe {
254 let this = NonNull::from(&self.hdr);
255 self.net.detach_socket(this);
256 }
257 }
258}
259
260unsafe impl<S, T, R, M> Send for SocketHdl<'_, S, T, R, M>
261where
262 S: Storage<Response<T>>,
263 T: Send,
264 T: Serialize + Clone + DeserializeOwned + 'static,
265 R: ScopedRawMutex + 'static,
266 M: InterfaceManager + 'static,
267{
268}
269
270unsafe impl<S, T, R, M> Sync for SocketHdl<'_, S, T, R, M>
271where
272 S: Storage<Response<T>>,
273 T: Send,
274 T: Serialize + Clone + DeserializeOwned + 'static,
275 R: ScopedRawMutex + 'static,
276 M: InterfaceManager + 'static,
277{
278}
279
280impl<S, T, R, M> Future for Recv<'_, '_, S, T, R, M>
283where
284 S: Storage<Response<T>>,
285 T: Serialize + Clone + DeserializeOwned + 'static,
286 R: ScopedRawMutex + 'static,
287 M: InterfaceManager + 'static,
288{
289 type Output = Response<T>;
290
291 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
292 let net: &'static NetStack<R, M> = self.hdl.stack();
293 let f = || {
294 let this_ref: &Socket<S, T, R, M> = unsafe { self.hdl.ptr.as_ref() };
295 let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
296
297 if let Some(resp) = box_ref.sto.try_pop() {
298 return Some(resp);
299 }
300
301 let new_wake = cx.waker();
302 if let Some(w) = box_ref.wait.take() {
303 if !w.will_wake(new_wake) {
304 w.wake();
305 }
306 }
307 box_ref.wait = Some(new_wake.clone());
310 None
311 };
312 let res = unsafe { net.with_lock(f) };
313 if let Some(t) = res {
314 Poll::Ready(t)
315 } else {
316 Poll::Pending
317 }
318 }
319}
320
321unsafe impl<S, T, R, M> Sync for Recv<'_, '_, S, T, R, M>
322where
323 S: Storage<Response<T>>,
324 T: Send,
325 T: Serialize + Clone + DeserializeOwned + 'static,
326 R: ScopedRawMutex + 'static,
327 M: InterfaceManager + 'static,
328{
329}
330
331impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
334 const fn new(sto: S) -> Self {
335 Self {
336 wait: None,
337 sto,
338 _pd: PhantomData,
339 }
340 }
341}