1use core::{
12 any::TypeId,
13 cell::UnsafeCell,
14 marker::PhantomData,
15 pin::Pin,
16 ptr::{NonNull, addr_of},
17 task::{Context, Poll, Waker},
18};
19
20use cordyceps::list::Links;
21use serde::de::DeserializeOwned;
22
23use crate::{HeaderSeq, Key, ProtocolError, nash::NameHash, net_stack::NetStackHandle};
24
25use super::{Attributes, HeaderMessage, Response, SocketHeader, SocketSendError, SocketVTable};
26
/// Error returned by [`Storage::push`] when the item could not be stored.
///
/// This is a zero-sized marker, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StorageFull;
29
30pub trait Storage<T: 'static>: 'static {
31 fn is_full(&self) -> bool;
32 fn is_empty(&self) -> bool;
33 fn push(&mut self, t: T) -> Result<(), StorageFull>;
34 fn try_pop(&mut self) -> Option<T>;
35}
36
37#[repr(C)]
39pub struct Socket<S, T, N>
40where
41 S: Storage<Response<T>>,
42 T: Clone + DeserializeOwned + 'static,
43 N: NetStackHandle,
44{
45 hdr: SocketHeader,
47 pub(crate) net: N::Target,
48 inner: UnsafeCell<StoreBox<S, Response<T>>>,
49}
50
51pub struct SocketHdl<'a, S, T, N>
52where
53 S: Storage<Response<T>>,
54 T: Clone + DeserializeOwned + 'static,
55 N: NetStackHandle,
56{
57 pub(crate) ptr: NonNull<Socket<S, T, N>>,
58 _lt: PhantomData<Pin<&'a mut Socket<S, T, N>>>,
59 port: u8,
60}
61
62pub struct Recv<'a, 'b, S, T, N>
63where
64 S: Storage<Response<T>>,
65 T: Clone + DeserializeOwned + 'static,
66 N: NetStackHandle,
67{
68 hdl: &'a mut SocketHdl<'b, S, T, N>,
69}
70
71struct StoreBox<S: Storage<T>, T: 'static> {
72 wait: Option<Waker>,
73 sto: S,
74 _pd: PhantomData<fn() -> T>,
75}
76
77impl<S, T, N> Socket<S, T, N>
82where
83 S: Storage<Response<T>>,
84 T: Clone + DeserializeOwned + 'static,
85 N: NetStackHandle,
86{
87 pub const fn new(
88 net: N::Target,
89 key: Key,
90 attrs: Attributes,
91 sto: S,
92 name: Option<&str>,
93 ) -> Self {
94 Self {
95 hdr: SocketHeader {
96 links: Links::new(),
97 vtable: const { &Self::vtable() },
98 port: 0,
99 attrs,
100 key,
101 nash: if let Some(n) = name {
102 Some(NameHash::new(n))
103 } else {
104 None
105 },
106 },
107 inner: UnsafeCell::new(StoreBox::new(sto)),
108 net,
109 }
110 }
111
112 pub fn attach<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
113 let stack = self.net.clone();
114 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
115 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
116 let port = unsafe { stack.attach_socket(ptr_erase) };
117 SocketHdl {
118 ptr: ptr_self,
119 _lt: PhantomData,
120 port,
121 }
122 }
123
124 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> SocketHdl<'a, S, T, N> {
125 let stack = self.net.clone();
126 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
127 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
128 unsafe { stack.attach_broadcast_socket(ptr_erase) };
129 SocketHdl {
130 ptr: ptr_self,
131 _lt: PhantomData,
132 port: 255,
133 }
134 }
135
136 const fn vtable() -> SocketVTable {
137 SocketVTable {
138 recv_owned: Some(Self::recv_owned),
139 recv_bor: None,
143 recv_raw: Self::recv_raw,
144 recv_err: Some(Self::recv_err),
145 }
146 }
147
148 pub fn stack(&self) -> N::Target {
149 self.net.clone()
150 }
151
152 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
153 let this: NonNull<Self> = this.cast();
154 let this: &Self = unsafe { this.as_ref() };
155 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
156
157 let msg = Err(HeaderMessage { hdr, t: err });
158 if mutitem.sto.push(msg).is_ok() {
159 if let Some(w) = mutitem.wait.take() {
160 w.wake();
161 }
162 }
163 }
164
165 fn recv_owned(
166 this: NonNull<()>,
167 that: NonNull<()>,
168 hdr: HeaderSeq,
169 ty: &TypeId,
170 ) -> Result<(), SocketSendError> {
171 if &TypeId::of::<T>() != ty {
172 debug_assert!(false, "Type Mismatch!");
173 return Err(SocketSendError::TypeMismatch);
174 }
175 let that: NonNull<T> = that.cast();
176 let that: &T = unsafe { that.as_ref() };
177 let this: NonNull<Self> = this.cast();
178 let this: &Self = unsafe { this.as_ref() };
179 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
180
181 let msg = Ok(HeaderMessage {
182 hdr,
183 t: that.clone(),
184 });
185
186 match mutitem.sto.push(msg) {
187 Ok(()) => {
188 if let Some(w) = mutitem.wait.take() {
189 w.wake();
190 }
191 Ok(())
192 }
193 Err(StorageFull) => Err(SocketSendError::NoSpace),
194 }
195 }
196
197 fn recv_raw(
198 this: NonNull<()>,
199 that: &[u8],
200 hdr: HeaderSeq,
201 _hdr_raw: &[u8],
202 ) -> Result<(), SocketSendError> {
203 let this: NonNull<Self> = this.cast();
204 let this: &Self = unsafe { this.as_ref() };
205 let mutitem: &mut StoreBox<S, Response<T>> = unsafe { &mut *this.inner.get() };
206
207 if mutitem.sto.is_full() {
208 return Err(SocketSendError::NoSpace);
209 }
210
211 if let Ok(t) = postcard::from_bytes::<T>(that) {
212 let msg = Ok(HeaderMessage { hdr, t });
213 let _ = mutitem.sto.push(msg);
214 if let Some(w) = mutitem.wait.take() {
215 w.wake();
216 }
217 Ok(())
218 } else {
219 Err(SocketSendError::DeserFailed)
220 }
221 }
222}
223
224impl<'a, S, T, N> SocketHdl<'a, S, T, N>
227where
228 S: Storage<Response<T>>,
229 T: Clone + DeserializeOwned + 'static,
230 N: NetStackHandle,
231{
232 pub fn port(&self) -> u8 {
233 self.port
234 }
235
236 pub fn stack(&self) -> N::Target {
237 unsafe { (*addr_of!((*self.ptr.as_ptr()).net)).clone() }
238 }
239
240 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, S, T, N> {
241 Recv { hdl: self }
242 }
243}
244
245impl<S, T, N> Drop for Socket<S, T, N>
246where
247 S: Storage<Response<T>>,
248 T: Clone + DeserializeOwned + 'static,
249 N: NetStackHandle,
250{
251 fn drop(&mut self) {
252 unsafe {
253 let this = NonNull::from(&self.hdr);
254 self.net.detach_socket(this);
255 }
256 }
257}
258
259unsafe impl<S, T, N> Send for SocketHdl<'_, S, T, N>
260where
261 S: Storage<Response<T>>,
262 T: Send,
263 T: Clone + DeserializeOwned + 'static,
264 N: NetStackHandle,
265{
266}
267
268unsafe impl<S, T, N> Sync for SocketHdl<'_, S, T, N>
269where
270 S: Storage<Response<T>>,
271 T: Send,
272 T: Clone + DeserializeOwned + 'static,
273 N: NetStackHandle,
274{
275}
276
277impl<S, T, N> Future for Recv<'_, '_, S, T, N>
280where
281 S: Storage<Response<T>>,
282 T: Clone + DeserializeOwned + 'static,
283 N: NetStackHandle,
284{
285 type Output = Response<T>;
286
287 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288 let net: N::Target = self.hdl.stack();
289 let f = || {
290 let this_ref: &Socket<S, T, N> = unsafe { self.hdl.ptr.as_ref() };
291 let box_ref: &mut StoreBox<S, Response<T>> = unsafe { &mut *this_ref.inner.get() };
292
293 if let Some(resp) = box_ref.sto.try_pop() {
294 return Some(resp);
295 }
296
297 let new_wake = cx.waker();
298 if let Some(w) = box_ref.wait.take() {
299 if !w.will_wake(new_wake) {
300 w.wake();
301 }
302 }
303 box_ref.wait = Some(new_wake.clone());
306 None
307 };
308 let res = unsafe { net.with_lock(f) };
309 if let Some(t) = res {
310 Poll::Ready(t)
311 } else {
312 Poll::Pending
313 }
314 }
315}
316
317unsafe impl<S, T, N> Sync for Recv<'_, '_, S, T, N>
318where
319 S: Storage<Response<T>>,
320 T: Send,
321 T: Clone + DeserializeOwned + 'static,
322 N: NetStackHandle,
323{
324}
325
326impl<S: Storage<T>, T: 'static> StoreBox<S, T> {
329 const fn new(sto: S) -> Self {
330 Self {
331 wait: None,
332 sto,
333 _pd: PhantomData,
334 }
335 }
336}