1use core::{
2 any::TypeId,
3 cell::UnsafeCell,
4 marker::PhantomData,
5 pin::Pin,
6 ptr::{NonNull, addr_of},
7 task::{Context, Poll, Waker},
8};
9
10use cordyceps::list::Links;
11use mutex::ScopedRawMutex;
12use serde::{Serialize, de::DeserializeOwned};
13
14use crate::{HeaderSeq, Key, NetStack, ProtocolError, interface_manager::InterfaceManager};
15
16use super::{
17 Attributes, Contents, OwnedMessage, Response, SocketHeader, SocketSendError, SocketVTable,
18};
19
/// A socket that owns storage for at most one received message of type `T`.
///
/// The struct is `#[repr(C)]` with `hdr` as its first field: `attach` and
/// `attach_broadcast` cast a pointer to the whole socket into a pointer to its
/// `SocketHeader`, which is only meaningful while the header sits at offset 0.
/// Do not reorder the fields.
#[repr(C)]
pub struct OwnedSocket<T, R, M>
where
    T: Serialize + Clone + DeserializeOwned + 'static,
    R: ScopedRawMutex + 'static,
    M: InterfaceManager + 'static,
{
    /// Type-erased header (intrusive list links, vtable, port, attributes, key).
    /// Must remain the first field; see the struct-level note.
    hdr: SocketHeader,
    /// The network stack this socket attaches to, and detaches from on drop.
    pub(crate) net: &'static NetStack<R, M>,
    /// Single-slot mailbox plus the waker of the pending receiver. It is
    /// mutated through the `UnsafeCell` by the vtable callbacks and by
    /// `Recv::poll`; the latter does so inside `NetStack::with_lock`.
    // NOTE(review): the vtable callbacks are presumably also invoked with the
    // stack lock held — confirm against `NetStack`.
    inner: UnsafeCell<OneBox<T>>,
}
35
/// Handle to an attached [`OwnedSocket`], produced by `OwnedSocket::attach` or
/// `OwnedSocket::attach_broadcast`.
///
/// The handle stores a raw pointer to the pinned socket; the `'a` lifetime
/// ties it to the pinned mutable borrow it was created from.
pub struct OwnedSocketHdl<'a, T, R, M>
where
    T: Serialize + Clone + DeserializeOwned + 'static,
    R: ScopedRawMutex + 'static,
    M: InterfaceManager + 'static,
{
    /// Pointer to the pinned socket this handle refers to.
    pub(crate) ptr: NonNull<OwnedSocket<T, R, M>>,
    /// Marker that makes the handle behave, for borrow checking, as if it held
    /// the `Pin<&'a mut OwnedSocket>` it was created from.
    _lt: PhantomData<Pin<&'a mut OwnedSocket<T, R, M>>>,
    /// Port reported by `port()`: the value returned by `attach_socket`, or
    /// 255 for handles created by `attach_broadcast`.
    port: u8,
}
46
/// Future returned by `OwnedSocketHdl::recv`.
///
/// Resolves to a `Response<T>` once the socket's mailbox holds either a
/// message or a protocol error.
pub struct Recv<'a, 'b, T, R, M>
where
    T: Serialize + Clone + DeserializeOwned + 'static,
    R: ScopedRawMutex + 'static,
    M: InterfaceManager + 'static,
{
    /// Exclusive borrow of the handle being received on, so only one `Recv`
    /// can exist per handle at a time.
    hdl: &'a mut OwnedSocketHdl<'b, T, R, M>,
}
55
/// Single-slot mailbox shared between the socket's receive callbacks (which
/// fill it) and `Recv::poll` (which empties it).
struct OneBox<T: 'static> {
    /// Waker of the task currently waiting for the slot to be filled, if any.
    wait: Option<Waker>,
    /// Slot contents: empty (`Contents::None`), a message, or an error.
    t: Contents<T>,
}
60
61impl<T, R, M> OwnedSocket<T, R, M>
70where
71 T: Serialize + Clone + DeserializeOwned + 'static,
72 R: ScopedRawMutex + 'static,
73 M: InterfaceManager + 'static,
74{
75 pub const fn new(net: &'static NetStack<R, M>, key: Key, attrs: Attributes) -> Self {
76 Self {
77 hdr: SocketHeader {
78 links: Links::new(),
79 vtable: const { &Self::vtable() },
80 port: 0,
81 attrs,
82 key,
83 },
84 inner: UnsafeCell::new(OneBox::new()),
85 net,
86 }
87 }
88
89 pub fn attach<'a>(self: Pin<&'a mut Self>) -> OwnedSocketHdl<'a, T, R, M> {
90 let stack = self.net;
91 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
92 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
93 let port = unsafe { stack.attach_socket(ptr_erase) };
94 OwnedSocketHdl {
95 ptr: ptr_self,
96 _lt: PhantomData,
97 port,
98 }
99 }
100
101 pub fn attach_broadcast<'a>(self: Pin<&'a mut Self>) -> OwnedSocketHdl<'a, T, R, M> {
102 let stack = self.net;
103 let ptr_self: NonNull<Self> = NonNull::from(unsafe { self.get_unchecked_mut() });
104 let ptr_erase: NonNull<SocketHeader> = ptr_self.cast();
105 unsafe { stack.attach_broadcast_socket(ptr_erase) };
106 OwnedSocketHdl {
107 ptr: ptr_self,
108 _lt: PhantomData,
109 port: 255,
110 }
111 }
112
113 const fn vtable() -> SocketVTable {
114 SocketVTable {
115 recv_owned: Some(Self::recv_owned),
116 recv_bor: None,
120 recv_raw: Self::recv_raw,
121 recv_err: Some(Self::recv_err),
122 }
123 }
124
125 pub fn stack(&self) -> &'static NetStack<R, M> {
126 self.net
127 }
128
129 fn recv_err(this: NonNull<()>, hdr: HeaderSeq, err: ProtocolError) {
130 let this: NonNull<Self> = this.cast();
131 let this: &Self = unsafe { this.as_ref() };
132 let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
133
134 if !matches!(mutitem.t, Contents::None) {
135 return;
136 }
137
138 mutitem.t = Contents::Err(OwnedMessage { hdr, t: err });
139 if let Some(w) = mutitem.wait.take() {
140 w.wake();
141 }
142 }
143
144 fn recv_owned(
145 this: NonNull<()>,
146 that: NonNull<()>,
147 hdr: HeaderSeq,
148 ty: &TypeId,
149 ) -> Result<(), SocketSendError> {
150 if &TypeId::of::<T>() != ty {
151 debug_assert!(false, "Type Mismatch!");
152 return Err(SocketSendError::TypeMismatch);
153 }
154 let that: NonNull<T> = that.cast();
155 let that: &T = unsafe { that.as_ref() };
156 let this: NonNull<Self> = this.cast();
157 let this: &Self = unsafe { this.as_ref() };
158 let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
159
160 if !matches!(mutitem.t, Contents::None) {
161 return Err(SocketSendError::NoSpace);
162 }
163
164 mutitem.t = Contents::Mesg(OwnedMessage {
165 hdr,
166 t: that.clone(),
167 });
168 if let Some(w) = mutitem.wait.take() {
169 w.wake();
170 }
171
172 Ok(())
173 }
174
175 fn recv_raw(this: NonNull<()>, that: &[u8], hdr: HeaderSeq) -> Result<(), SocketSendError> {
186 let this: NonNull<Self> = this.cast();
187 let this: &Self = unsafe { this.as_ref() };
188 let mutitem: &mut OneBox<T> = unsafe { &mut *this.inner.get() };
189
190 if !matches!(mutitem.t, Contents::None) {
191 return Err(SocketSendError::NoSpace);
192 }
193
194 if let Ok(t) = postcard::from_bytes::<T>(that) {
195 mutitem.t = Contents::Mesg(OwnedMessage { hdr, t });
196 if let Some(w) = mutitem.wait.take() {
197 w.wake();
198 }
199 Ok(())
200 } else {
201 Err(SocketSendError::DeserFailed)
202 }
203 }
204}
205
206impl<'a, T, R, M> OwnedSocketHdl<'a, T, R, M>
210where
211 T: Serialize + Clone + DeserializeOwned + 'static,
212 R: ScopedRawMutex + 'static,
213 M: InterfaceManager + 'static,
214{
215 pub fn port(&self) -> u8 {
216 self.port
217 }
218
219 pub fn stack(&self) -> &'static NetStack<R, M> {
220 unsafe { *addr_of!((*self.ptr.as_ptr()).net) }
221 }
222
223 pub fn recv<'b>(&'b mut self) -> Recv<'b, 'a, T, R, M> {
227 Recv { hdl: self }
228 }
229}
230
231impl<T, R, M> Drop for OwnedSocket<T, R, M>
232where
233 T: Serialize + Clone + DeserializeOwned + 'static,
234 R: ScopedRawMutex + 'static,
235 M: InterfaceManager + 'static,
236{
237 fn drop(&mut self) {
238 println!("Dropping OwnedSocket!");
239 unsafe {
240 let this = NonNull::from(&self.hdr);
241 self.net.detach_socket(this);
242 }
243 }
244}
245
246unsafe impl<T, R, M> Send for OwnedSocketHdl<'_, T, R, M>
247where
248 T: Send,
249 T: Serialize + Clone + DeserializeOwned + 'static,
250 R: ScopedRawMutex + 'static,
251 M: InterfaceManager + 'static,
252{
253}
254
255unsafe impl<T, R, M> Sync for OwnedSocketHdl<'_, T, R, M>
256where
257 T: Send,
258 T: Serialize + Clone + DeserializeOwned + 'static,
259 R: ScopedRawMutex + 'static,
260 M: InterfaceManager + 'static,
261{
262}
263
264impl<T, R, M> Future for Recv<'_, '_, T, R, M>
267where
268 T: Serialize + Clone + DeserializeOwned + 'static,
269 R: ScopedRawMutex + 'static,
270 M: InterfaceManager + 'static,
271{
272 type Output = Response<T>;
273
274 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
275 let net: &'static NetStack<R, M> = self.hdl.stack();
276 let f = || {
277 let this_ref: &OwnedSocket<T, R, M> = unsafe { self.hdl.ptr.as_ref() };
278 let box_ref: &mut OneBox<T> = unsafe { &mut *this_ref.inner.get() };
279 match core::mem::replace(&mut box_ref.t, Contents::None) {
280 Contents::Mesg(owned_message) => return Some(Ok(owned_message)),
281 Contents::Err(owned_message) => return Some(Err(owned_message)),
282 Contents::None => {}
283 }
284
285 let new_wake = cx.waker();
286 if let Some(w) = box_ref.wait.take() {
287 if !w.will_wake(new_wake) {
288 w.wake();
289 }
290 }
291 box_ref.wait = Some(new_wake.clone());
294 None
295 };
296 let res = unsafe { net.with_lock(f) };
297 if let Some(t) = res {
298 Poll::Ready(t)
299 } else {
300 Poll::Pending
301 }
302 }
303}
304
305unsafe impl<T, R, M> Sync for Recv<'_, '_, T, R, M>
306where
307 T: Send,
308 T: Serialize + Clone + DeserializeOwned + 'static,
309 R: ScopedRawMutex + 'static,
310 M: InterfaceManager + 'static,
311{
312}
313
314impl<T: 'static> OneBox<T> {
317 const fn new() -> Self {
318 Self {
319 wait: None,
320 t: Contents::None,
321 }
322 }
323}
324
325impl<T: 'static> Default for OneBox<T> {
326 fn default() -> Self {
327 Self::new()
328 }
329}