1use core::{
2 any::TypeId,
3 cell::UnsafeCell,
4 marker::PhantomData,
5 pin::Pin,
6 ptr::{NonNull, addr_of},
7 task::{Context, Poll, Waker},
8};
9
10use cordyceps::list::Links;
11use mutex::ScopedRawMutex;
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{FrameKind, HeaderSeq, Key, NetStack, interface_manager::InterfaceManager};
15
16use super::{OwnedMessage, SocketHeader, SocketSendError, SocketVTable};
17
18#[repr(C)]
20pub struct OwnedSocket<T, R, M>
21where
22 T: Serialize + DeserializeOwned + 'static,
23 R: ScopedRawMutex + 'static,
24 M: InterfaceManager + 'static,
25{
26 hdr: SocketHeader,
28 pub(crate) net: &'static NetStack<R, M>,
29 inner: UnsafeCell<OneBox<T>>,
32}
33
34pub struct OwnedSocketHdl<'a, T, R, M>
35where
36 T: Serialize + DeserializeOwned + 'static,
37 R: ScopedRawMutex + 'static,
38 M: InterfaceManager + 'static,
39{
40 pub(crate) ptr: NonNull<OwnedSocket<T, R, M>>,
41 _lt: PhantomData<Pin<&'a mut OwnedSocket<T, R, M>>>,
42 port: u8,
43}
44
45pub struct Recv<'a, 'b, T, R, M>
46where
47 T: Serialize + DeserializeOwned + 'static,
48 R: ScopedRawMutex + 'static,
49 M: InterfaceManager + 'static,
50{
51 hdl: &'a mut OwnedSocketHdl<'b, T, R, M>,
52}
53
54struct OneBox<T: 'static> {
55 wait: Option<Waker>,
56 t: Option<OwnedMessage<T>>,
57}
58
59impl<T, R, M> OwnedSocket<T, R, M>
68where
69 T: Serialize + DeserializeOwned + 'static,
70 R: ScopedRawMutex + 'static,
71 M: InterfaceManager + 'static,
72{
73 pub const fn new(net: &'static NetStack<R, M>, key: Key, kind: FrameKind) -> Self {
74 Self {
75 hdr: SocketHeader {
76 links: Links::new(),
77 vtable: const { &Self::vtable() },
78 port: 0,
79 kind,
80 key,
81 },
82 inner: UnsafeCell::new(OneBox::new()),
83 net,
84 }
85 }
86
87 pub fn attach<'a>(self: Pin<&'a mut Self>) -> OwnedSocketHdl<'a, T, R, M> {
88 let stack = self.net;
89 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
90 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
91 let port = unsafe { stack.attach_socket(ptr_erase) };
92 OwnedSocketHdl {
93 ptr: ptr_self,
94 _lt: PhantomData,
95 port,
96 }
97 }
99
100 const fn vtable() -> SocketVTable {
101 SocketVTable {
102 send_owned: Some(Self::send_owned),
103 send_bor: None,
107 send_raw: Self::send_raw,
108 }
109 }
110
111 pub fn stack(&self) -> &'static NetStack<R, M> {
112 self.net
113 }
114
115 fn send_owned(
116 this: NonNull<()>,
117 that: NonNull<()>,
118 hdr: HeaderSeq,
119 ty: &TypeId,
120 ) -> Result<(), SocketSendError> {
121 if &TypeId::of::<T>() != ty {
122 debug_assert!(false, "Type Mismatch!");
123 return Err(SocketSendError::TypeMismatch);
124 }
125 let that: NonNull<T> = that.cast();
126 let this: NonNull<Self> = this.cast();
127 let this: &Self = unsafe { this.as_ref() };
128 let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
129
130 if mutitem.t.is_some() {
131 return Err(SocketSendError::NoSpace);
132 }
133
134 mutitem.t = Some(OwnedMessage {
135 hdr,
136 t: unsafe { that.read() },
137 });
138 if let Some(w) = mutitem.wait.take() {
139 w.wake();
140 }
141
142 Ok(())
143 }
144
145 fn send_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
156 let this: NonNull<Self> = this.cast();
157 let this: &Self = unsafe { this.as_ref() };
158 let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
159
160 if mutitem.t.is_some() {
161 return Err(SocketSendError::NoSpace);
162 }
163
164 if let Ok(t) = postcard::from_bytes::<T>(that) {
165 mutitem.t = Some(OwnedMessage { hdr, t });
166 if let Some(w) = mutitem.wait.take() {
167 w.wake();
168 }
169 Ok(())
170 } else {
171 Err(SocketSendError::DeserFailed)
172 }
173 }
174}
175
176impl<'a, T, R, M> OwnedSocketHdl<'a, T, R, M>
180where
181 T: Serialize + DeserializeOwned + 'static,
182 R: ScopedRawMutex + 'static,
183 M: InterfaceManager + 'static,
184{
185 pub fn port(&self) -> u8 {
186 self.port
187 }
188
189 pub fn stack(&self) -> &'static NetStack<R, M> {
190 unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
191 }
192
193 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
197 Recv { hdl: self }
198 }
199}
200
201impl<T, R, M> Drop for OwnedSocket<T, R, M>
202where
203 T: Serialize + DeserializeOwned + 'static,
204 R: ScopedRawMutex + 'static,
205 M: InterfaceManager + 'static,
206{
207 fn drop(&mut self) {
208 println!("Dropping OwnedSocket!");
209 unsafe {
210 let this = NonNull::from(&self.hdr);
211 self.net.detach_socket(this);
212 }
213 }
214}
215
216unsafe impl<T, R, M> Send for OwnedSocketHdl<'_, T, R, M>
217where
218 T: Send,
219 T: Serialize + DeserializeOwned + 'static,
220 R: ScopedRawMutex + 'static,
221 M: InterfaceManager + 'static,
222{
223}
224
225unsafe impl<T, R, M> Sync for OwnedSocketHdl<'_, T, R, M>
226where
227 T: Send,
228 T: Serialize + DeserializeOwned + 'static,
229 R: ScopedRawMutex + 'static,
230 M: InterfaceManager + 'static,
231{
232}
233
234impl<T, R, M> Future for Recv<'_, '_, T, R, M>
237where
238 T: Serialize + DeserializeOwned + 'static,
239 R: ScopedRawMutex + 'static,
240 M: InterfaceManager + 'static,
241{
242 type Output = OwnedMessage<T>;
243
244 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245 let net: &'static NetStack<R, M> = self.hdl.stack();
246 let f = || {
247 let this_ref: &OwnedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
248 let box_ref: &mut OneBox<T> = unsafe { &mut *this_ref.inner.get() };
249 if let Some(t) = box_ref.t.take() {
250 Some(t)
251 } else {
252 let new_wake = cx.waker();
253 if let Some(w) = box_ref.wait.take() {
254 if !w.will_wake(new_wake) {
255 w.wake();
256 }
257 }
258 box_ref.wait = Some(new_wake.clone());
261 None
262 }
263 };
264 let res = unsafe { net.with_lock(f) };
265 if let Some(t) = res {
266 Poll::Ready(t)
267 } else {
268 Poll::Pending
269 }
270 }
271}
272
273unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
274where
275 T: Send,
276 T: Serialize + DeserializeOwned + 'static,
277 R: ScopedRawMutex + 'static,
278 M: InterfaceManager + 'static,
279{
280}
281
282impl<T: 'static> OneBox<T> {
285 const fn new() -> Self {
286 Self {
287 wait: None,
288 t: None,
289 }
290 }
291}
292
293impl<T: 'static> Default for OneBox<T> {
294 fn default() -> Self {
295 Self::new()
296 }
297}