1use core::{
12 any::TypeId,
13 cell::UnsafeCell,
14 marker::PhantomData,
15 pin::Pin,
16 ptr::{NonNull, addr_of},
17 task::{Context, Poll, Waker},
18};
19
20use cordyceps::list::Links;
21use mutex::ScopedRawMutex;
22use serde::de::DeserializeOwned;
23
24use crate::{
25 HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager, nash::NameHash,
26};
27
28use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
29
/// Error returned by [`Storage::push`] when the storage cannot accept another item.
///
/// It is a zero-sized marker, so it is also `Clone`/`Copy`/`Eq` for cheap comparison
/// and propagation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StorageFull;
32
33pub trait Storage<T: 'static>: 'static {
34 fn is_full(&self) -> bool;
35 fn is_empty(&self) -> bool;
36 fn push(&mut self, t: T) -> Result<(), StorageFull>;
37 fn try_pop(&mut self) -> Option<T>;
38}
39
40#[repr(C)]
42pub struct Socket<S, T, R, M>
43where
44 S: Storage<Response<T>>,
45 T: Clone + DeserializeOwned + 'static,
46 R: ScopedRawMutex + 'static,
47 M: InterfaceManager + 'static,
48{
49 hdr: SocketHeader,
51 pub(crate) net: &'static NetStack<R, M>,
52 inner: UnsafeCell<StoreBox<S, Response<T>>>,
55}
56
57pub struct SocketHdl<'a, S, T, R, M>
58where
59 S: Storage<Response<T>>,
60 T: Clone + DeserializeOwned + 'static,
61 R: ScopedRawMutex + 'static,
62 M: InterfaceManager + 'static,
63{
64 pub(crate) ptr: NonNull<Socket<S, T, R, M>>,
65 _lt: PhantomData<Pin<&'a mut Socket<S, T, R, M>>>,
66 port: u8,
67}
68
69pub struct Recv<'a, 'b, S, T, R, M>
70where
71 S: Storage<Response<T>>,
72 T: Clone + DeserializeOwned + 'static,
73 R: ScopedRawMutex + 'static,
74 M: InterfaceManager + 'static,
75{
76 hdl: &'a mut SocketHdl<'b, S, T, R, M>,
77}
78
79struct StoreBox<S: Storage<T>, T: 'static> {
80 wait: Option<Waker>,
81 sto: S,
82 _pd: PhantomData<fn() -> T>,
83}
84
85impl<S, T, R, M> Socket<S, T, R, M>
90where
91 S: Storage<Response<T>>,
92 T: Clone + DeserializeOwned + 'static,
93 R: ScopedRawMutex + 'static,
94 M: InterfaceManager + 'static,
95{
96 pub const fn new(
97 net: &'static NetStack<R, M>,
98 key: Key,
99 attrs: Attributes,
100 sto: S,
101 name: Option<&str>,
102 ) -> Self {
103 Self {
104 hdr: SocketHeader {
105 links: Links::new(),
106 vtable: const { &Self::vtable() },
107 port: 0,
108 attrs,
109 key,
110 nash: if let Some(n) = name {
111 Some(NameHash::new(n))
112 } else {
113 None
114 },
115 },
116 inner: UnsafeCell::new(StoreBox::new(sto)),
117 net,
118 }
119 }
120
121 pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
122 let stack = self.net;
123 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
124 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
125 let port = unsafe { stack.attach_socket(ptr_erase) };
126 SocketHdl {
127 ptr: ptr_self,
128 _lt: PhantomData,
129 port,
130 }
131 }
132
133 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, R, M> {
134 let stack = self.net;
135 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
136 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
137 unsafe { stack.attach_broadcast_socket(ptr_erase) };
138 SocketHdl {
139 ptr: ptr_self,
140 _lt: PhantomData,
141 port: 255,
142 }
143 }
144
145 const fn vtable() -> SocketVTable {
146 SocketVTable {
147 recv_owned: Some(Self::recv_owned),
148 recv_bor: None,
152 recv_raw: Self::recv_raw,
153 recv_err: Some(Self::recv_err),
154 }
155 }
156
157 pub fn stack(&self) -> &'static NetStack<R, M> {
158 self.net
159 }
160
161 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
162 let this: NonNull<Self> = this.cast();
163 let this: &Self = unsafe { this.as_ref() };
164 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
165
166 let msg = Err(HeaderMessage { hdr, t: err });
167 if mutitem.sto.push(msg).is_ok() {
168 if let Some(w) = mutitem.wait.take() {
169 w.wake();
170 }
171 }
172 }
173
174 fn recv_owned(
175 this: NonNull<()>,
176 that: NonNull<()>,
177 hdr: HeaderSeq,
178 ty: &TypeId,
179 ) -> Result<(), SocketSendError> {
180 if &TypeId::of::<T>() != ty {
181 debug_assert!(false, "Type Mismatch!");
182 return Err(SocketSendError::TypeMismatch);
183 }
184 let that: NonNull<T> = that.cast();
185 let that: &T = unsafe { that.as_ref() };
186 let this: NonNull<Self> = this.cast();
187 let this: &Self = unsafe { this.as_ref() };
188 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
189
190 let msg = Ok(HeaderMessage {
191 hdr,
192 t: that.clone(),
193 });
194
195 match mutitem.sto.push(msg) {
196 Ok(()) => {
197 if let Some(w) = mutitem.wait.take() {
198 w.wake();
199 }
200 Ok(())
201 }
202 Err(StorageFull) => Err(SocketSendError::NoSpace),
203 }
204 }
205
206 fn recv_raw(
217 this: NonNull<()>,
218 that: &[u8],
219 hdr: HeaderSeq,
220 _hdr_raw: &[u8],
221 ) -> Result<(), SocketSendError> {
222 let this: NonNull<Self> = this.cast();
223 let this: &Self = unsafe { this.as_ref() };
224 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
225
226 if mutitem.sto.is_full() {
227 return Err(SocketSendError::NoSpace);
228 }
229
230 if let Ok(t) = postcard::from_bytes::<T>(that) {
231 let msg = Ok(HeaderMessage { hdr, t });
232 let _ = mutitem.sto.push(msg);
233 if let Some(w) = mutitem.wait.take() {
234 w.wake();
235 }
236 Ok(())
237 } else {
238 Err(SocketSendError::DeserFailed)
239 }
240 }
241}
242
243impl<'a, S, T, R, M> SocketHdl<'a, S, T, R, M>
247where
248 S: Storage<Response<T>>,
249 T: Clone + DeserializeOwned + 'static,
250 R: ScopedRawMutex + 'static,
251 M: InterfaceManager + 'static,
252{
253 pub fn port(&self) -> u8 {
254 self.port
255 }
256
257 pub fn stack(&self) -> &'static NetStack<R, M> {
258 unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
259 }
260
261 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, R, M> {
262 Recv { hdl: self }
263 }
264}
265
266impl<S, T, R, M> Drop for Socket<S, T, R, M>
267where
268 S: Storage<Response<T>>,
269 T: Clone + DeserializeOwned + 'static,
270 R: ScopedRawMutex + 'static,
271 M: InterfaceManager + 'static,
272{
273 fn drop(&mut self) {
274 unsafe {
275 let this = NonNull::from(&self.hdr);
276 self.net.detach_socket(this);
277 }
278 }
279}
280
281unsafe impl<S, T, R, M> Send for SocketHdl<'_, S, T, R, M>
282where
283 S: Storage<Response<T>>,
284 T: Send,
285 T: Clone + DeserializeOwned + 'static,
286 R: ScopedRawMutex + 'static,
287 M: InterfaceManager + 'static,
288{
289}
290
291unsafe impl<S, T, R, M> Sync for SocketHdl<'_, S, T, R, M>
292where
293 S: Storage<Response<T>>,
294 T: Send,
295 T: Clone + DeserializeOwned + 'static,
296 R: ScopedRawMutex + 'static,
297 M: InterfaceManager + 'static,
298{
299}
300
301impl<S, T, R, M> Future for Recv<'_, '_, S, T, R, M>
304where
305 S: Storage<Response<T>>,
306 T: Clone + DeserializeOwned + 'static,
307 R: ScopedRawMutex + 'static,
308 M: InterfaceManager + 'static,
309{
310 type Output = Response<T>;
311
312 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313 let net: &'static NetStack<R, M> = self.hdl.stack();
314 let f = || {
315 let this_ref: &Socket<S, T, R, M> = unsafe { self.hdl.ptr.as_ref() };
316 let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
317
318 if let Some(resp) = box_ref.sto.try_pop() {
319 return Some(resp);
320 }
321
322 let new_wake = cx.waker();
323 if let Some(w) = box_ref.wait.take() {
324 if !w.will_wake(new_wake) {
325 w.wake();
326 }
327 }
328 box_ref.wait = Some(new_wake.clone());
331 None
332 };
333 let res = unsafe { net.with_lock(f) };
334 if let Some(t) = res {
335 Poll::Ready(t)
336 } else {
337 Poll::Pending
338 }
339 }
340}
341
342unsafe impl<S, T, R, M> Sync for Recv<'_, '_, S, T, R, M>
343where
344 S: Storage<Response<T>>,
345 T: Send,
346 T: Clone + DeserializeOwned + 'static,
347 R: ScopedRawMutex + 'static,
348 M: InterfaceManager + 'static,
349{
350}
351
352impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
355 const fn new(sto: S) -> Self {
356 Self {
357 wait: None,
358 sto,
359 _pd: PhantomData,
360 }
361 }
362}