1use core::future::{poll_fn, Future};
4use core::mem;
5use core::task::{Context, Poll};
6
7use embassy_net_driver::Driver;
8use smoltcp::iface::{Interface, SocketHandle};
9use smoltcp::socket::raw;
10pub use smoltcp::socket::raw::PacketMetadata;
11pub use smoltcp::wire::{IpProtocol, IpVersion};
12
13use crate::Stack;
14
/// Error returned by [`RawSocket::recv`] and [`RawSocket::poll_recv`].
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RecvError {
    /// smoltcp reported `raw::RecvError::Truncated` while copying a received packet into
    /// the caller's buffer, i.e. the buffer was too small for the packet.
    Truncated,
}

/// Human-readable description, so `RecvError` can be printed and wrapped like other
/// error types.
impl core::fmt::Display for RecvError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Truncated => f.write_str("received packet was truncated"),
        }
    }
}
22
23pub struct RawSocket<'a> {
25 stack: Stack<'a>,
26 handle: SocketHandle,
27}
28
29impl<'a> RawSocket<'a> {
30 pub fn new<D: Driver>(
32 stack: Stack<'a>,
33 ip_version: IpVersion,
34 ip_protocol: IpProtocol,
35 rx_meta: &'a mut [PacketMetadata],
36 rx_buffer: &'a mut [u8],
37 tx_meta: &'a mut [PacketMetadata],
38 tx_buffer: &'a mut [u8],
39 ) -> Self {
40 let handle = stack.with_mut(|i| {
41 let rx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(rx_meta) };
42 let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
43 let tx_meta: &'static mut [PacketMetadata] = unsafe { mem::transmute(tx_meta) };
44 let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
45 i.sockets.add(raw::Socket::new(
46 ip_version,
47 ip_protocol,
48 raw::PacketBuffer::new(rx_meta, rx_buffer),
49 raw::PacketBuffer::new(tx_meta, tx_buffer),
50 ))
51 });
52
53 Self { stack, handle }
54 }
55
56 fn with_mut<R>(&self, f: impl FnOnce(&mut raw::Socket, &mut Interface) -> R) -> R {
57 self.stack.with_mut(|i| {
58 let socket = i.sockets.get_mut::<raw::Socket>(self.handle);
59 let res = f(socket, &mut i.iface);
60 i.waker.wake();
61 res
62 })
63 }
64
65 pub fn wait_recv_ready(&self) -> impl Future<Output = ()> + '_ {
70 poll_fn(move |cx| self.poll_recv_ready(cx))
71 }
72
73 pub async fn recv(&self, buf: &mut [u8]) -> Result<usize, RecvError> {
77 poll_fn(move |cx| self.poll_recv(buf, cx)).await
78 }
79
80 pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
87 self.with_mut(|s, _| {
88 if s.can_recv() {
89 Poll::Ready(())
90 } else {
91 s.register_recv_waker(cx.waker());
93 Poll::Pending
94 }
95 })
96 }
97
98 pub fn poll_recv(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<Result<usize, RecvError>> {
103 self.with_mut(|s, _| match s.recv_slice(buf) {
104 Ok(n) => Poll::Ready(Ok(n)),
105 Err(raw::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)),
107 Err(raw::RecvError::Exhausted) => {
108 s.register_recv_waker(cx.waker());
109 Poll::Pending
110 }
111 })
112 }
113
114 pub fn wait_send_ready(&self) -> impl Future<Output = ()> + '_ {
119 poll_fn(move |cx| self.poll_send_ready(cx))
120 }
121
122 pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
130 self.with_mut(|s, _| {
131 if s.can_send() {
132 Poll::Ready(())
133 } else {
134 s.register_send_waker(cx.waker());
136 Poll::Pending
137 }
138 })
139 }
140
141 pub fn send<'s>(&'s self, buf: &'s [u8]) -> impl Future<Output = ()> + 's {
145 poll_fn(|cx| self.poll_send(buf, cx))
146 }
147
148 pub fn poll_send(&self, buf: &[u8], cx: &mut Context<'_>) -> Poll<()> {
155 self.with_mut(|s, _| match s.send_slice(buf) {
156 Ok(()) => Poll::Ready(()),
158 Err(raw::SendError::BufferFull) => {
159 s.register_send_waker(cx.waker());
160 Poll::Pending
161 }
162 })
163 }
164
165 pub fn flush(&mut self) -> impl Future<Output = ()> + '_ {
169 poll_fn(|cx| {
170 self.with_mut(|s, _| {
171 if s.send_queue() == 0 {
172 Poll::Ready(())
173 } else {
174 s.register_send_waker(cx.waker());
175 Poll::Pending
176 }
177 })
178 })
179 }
180}
181
182impl Drop for RawSocket<'_> {
183 fn drop(&mut self) {
184 self.stack.with_mut(|i| i.sockets.remove(self.handle));
185 }
186}
187
188fn _assert_covariant<'a, 'b: 'a>(x: RawSocket<'b>) -> RawSocket<'a> {
189 x
190}