1use core::{
12 any::TypeId,
13 cell::UnsafeCell,
14 marker::PhantomData,
15 pin::Pin,
16 ptr::{NonNull, addr_of},
17 task::{Context, Poll, Waker},
18};
19
20use cordyceps::list::Links;
21use serde::de::DeserializeOwned;
22
23use crate::{HeaderSeq, Key, ProtocolError, nash::NameHash, net_stack::NetStackHandle};
24
25use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
26
/// Error returned by [`Storage::push`] when the storage has no room left for
/// another element.
#[derive(PartialEq, Debug)]
pub struct StorageFull;
29
30pub trait Storage<T: 'static>: 'static {
31 fn is_full(&self) -> bool;
32 fn is_empty(&self) -> bool;
33 fn push(&mut self, t: T) -> Result<(), StorageFull>;
34 fn try_pop(&mut self) -> Option<T>;
35}
36
37#[repr(C)]
39pub struct Socket<S, T, N>
40where
41 S: Storage<Response<T>>,
42 T: Clone + DeserializeOwned + 'static,
43 N: NetStackHandle,
44{
45 hdr: SocketHeader,
47 pub(crate) net: N::Target,
48 inner: UnsafeCell<StoreBox<S, Response<T>>>,
49}
50
51pub struct SocketHdl<'a, S, T, N>
52where
53 S: Storage<Response<T>>,
54 T: Clone + DeserializeOwned + 'static,
55 N: NetStackHandle,
56{
57 pub(crate) ptr: NonNull<Socket<S, T, N>>,
58 _lt: PhantomData<Pin<&'a mut Socket<S, T, N>>>,
59 port: u8,
60}
61
62pub struct Recv<'a, 'b, S, T, N>
63where
64 S: Storage<Response<T>>,
65 T: Clone + DeserializeOwned + 'static,
66 N: NetStackHandle,
67{
68 hdl: &'a mut SocketHdl<'b, S, T, N>,
69}
70
71struct StoreBox<S: Storage<T>, T: 'static> {
72 wait: Option<Waker>,
73 sto: S,
74 _pd: PhantomData<fn() -> T>,
75}
76
77impl<S, T, N> Socket<S, T, N>
82where
83 S: Storage<Response<T>>,
84 T: Clone + DeserializeOwned + 'static,
85 N: NetStackHandle,
86{
87 pub const fn new(
88 net: N::Target,
89 key: Key,
90 attrs: Attributes,
91 sto: S,
92 name: Option<&str>,
93 ) -> Self {
94 Self {
95 hdr: SocketHeader {
96 links: Links::new(),
97 vtable: const { &Self::vtable() },
98 port: 0,
99 attrs,
100 key,
101 nash: if let Some(n) = name {
102 Some(NameHash::new(n))
103 } else {
104 None
105 },
106 },
107 inner: UnsafeCell::new(StoreBox::new(sto)),
108 net,
109 }
110 }
111
112 pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
113 let stack = self.net.clone();
114 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
115 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
116 let port = unsafe { stack.attach_socket(ptr_erase) };
117 SocketHdl {
118 ptr: ptr_self,
119 _lt: PhantomData,
120 port,
121 }
122 }
123
124 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
125 let stack = self.net.clone();
126 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
127 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
128 unsafe { stack.attach_broadcast_socket(ptr_erase) };
129 SocketHdl {
130 ptr: ptr_self,
131 _lt: PhantomData,
132 port: 255,
133 }
134 }
135
136 const fn vtable() -> SocketVTable {
137 SocketVTable {
138 recv_owned: Some(Self::recv_owned),
139 recv_bor: None,
150 recv_raw: Self::recv_raw,
151 recv_err: Some(Self::recv_err),
152 }
153 }
154
155 pub fn stack(&self) -> N::Target {
156 self.net.clone()
157 }
158
159 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
160 let this: NonNull<Self> = this.cast();
161 let this: &Self = unsafe { this.as_ref() };
162 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
163
164 let msg = Err(HeaderMessage { hdr, t: err });
165 if mutitem.sto.push(msg).is_ok() {
166 if let Some(w) = mutitem.wait.take() {
167 w.wake();
168 }
169 }
170 }
171
172 fn recv_owned(
173 this: NonNull<()>,
174 that: NonNull<()>,
175 hdr: HeaderSeq,
176 ty: &TypeId,
177 ) -> Result<(), SocketSendError> {
178 if &TypeId::of::<T>() != ty {
179 debug_assert!(false, "Type Mismatch!");
180 return Err(SocketSendError::TypeMismatch);
181 }
182 let that: NonNull<T> = that.cast();
183 let that: &T = unsafe { that.as_ref() };
184 let this: NonNull<Self> = this.cast();
185 let this: &Self = unsafe { this.as_ref() };
186 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
187
188 let msg = Ok(HeaderMessage {
189 hdr,
190 t: that.clone(),
191 });
192
193 match mutitem.sto.push(msg) {
194 Ok(()) => {
195 if let Some(w) = mutitem.wait.take() {
196 w.wake();
197 }
198 Ok(())
199 }
200 Err(StorageFull) => Err(SocketSendError::NoSpace),
201 }
202 }
203
204 fn recv_raw(
205 this: NonNull<()>,
206 that: &[u8],
207 hdr: HeaderSeq,
208 _hdr_raw: &[u8],
209 ) -> Result<(), SocketSendError> {
210 let this: NonNull<Self> = this.cast();
211 let this: &Self = unsafe { this.as_ref() };
212 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
213
214 if mutitem.sto.is_full() {
215 return Err(SocketSendError::NoSpace);
216 }
217
218 if let Ok(t) = postcard::from_bytes::<T>(that) {
219 let msg = Ok(HeaderMessage { hdr, t });
220 let _ = mutitem.sto.push(msg);
221 if let Some(w) = mutitem.wait.take() {
222 w.wake();
223 }
224 Ok(())
225 } else {
226 Err(SocketSendError::DeserFailed)
227 }
228 }
229}
230
231impl<'a, S, T, N> SocketHdl<'a, S, T, N>
234where
235 S: Storage<Response<T>>,
236 T: Clone + DeserializeOwned + 'static,
237 N: NetStackHandle,
238{
239 pub fn port(&self) -> u8 {
240 self.port
241 }
242
243 pub fn stack(&self) -> N::Target {
244 unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
245 }
246
247 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, N> {
248 Recv { hdl: self }
249 }
250}
251
252impl<S, T, N> Drop for Socket<S, T, N>
253where
254 S: Storage<Response<T>>,
255 T: Clone + DeserializeOwned + 'static,
256 N: NetStackHandle,
257{
258 fn drop(&mut self) {
259 unsafe {
260 let this = NonNull::from(&self.hdr);
261 self.net.detach_socket(this);
262 }
263 }
264}
265
266unsafe impl<S, T, N> Send for SocketHdl<'_, S, T, N>
267where
268 S: Storage<Response<T>>,
269 T: Send,
270 T: Clone + DeserializeOwned + 'static,
271 N: NetStackHandle,
272{
273}
274
275unsafe impl<S, T, N> Sync for SocketHdl<'_, S, T, N>
276where
277 S: Storage<Response<T>>,
278 T: Send,
279 T: Clone + DeserializeOwned + 'static,
280 N: NetStackHandle,
281{
282}
283
284impl<S, T, N> Future for Recv<'_, '_, S, T, N>
287where
288 S: Storage<Response<T>>,
289 T: Clone + DeserializeOwned + 'static,
290 N: NetStackHandle,
291{
292 type Output = Response<T>;
293
294 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
295 let net: N::Target = self.hdl.stack();
296 let f = || {
297 let this_ref: &Socket<S, T, N> = unsafe { self.hdl.ptr.as_ref() };
298 let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
299
300 if let Some(resp) = box_ref.sto.try_pop() {
301 return Some(resp);
302 }
303
304 let new_wake = cx.waker();
305 if let Some(w) = box_ref.wait.take() {
306 if !w.will_wake(new_wake) {
307 w.wake();
308 }
309 }
310 box_ref.wait = Some(new_wake.clone());
313 None
314 };
315 let res = unsafe { net.with_lock(f) };
316 if let Some(t) = res {
317 Poll::Ready(t)
318 } else {
319 Poll::Pending
320 }
321 }
322}
323
324unsafe impl<S, T, N> Sync for Recv<'_, '_, S, T, N>
325where
326 S: Storage<Response<T>>,
327 T: Send,
328 T: Clone + DeserializeOwned + 'static,
329 N: NetStackHandle,
330{
331}
332
333impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
336 const fn new(sto: S) -> Self {
337 Self {
338 wait: None,
339 sto,
340 _pd: PhantomData,
341 }
342 }
343}