1use core::fmt::Debug;
2use core::mem::MaybeUninit;
3use core::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
4
5use embedded_io_async::{ErrorKind, ErrorType};
6
7use edge_nal::{MacAddr, RawReceive, RawSend, RawSplit, Readable, UdpReceive, UdpSend, UdpSplit};
8
9use crate as raw;
10
11#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
13pub enum Error<E> {
14 Io(E),
15 UnsupportedProtocol,
16 RawError(raw::Error),
17}
18
19impl<E> From<raw::Error> for Error<E> {
20 fn from(value: raw::Error) -> Self {
21 Self::RawError(value)
22 }
23}
24
25impl<E> embedded_io_async::Error for Error<E>
26where
27 E: embedded_io_async::Error,
28{
29 fn kind(&self) -> ErrorKind {
30 match self {
31 Self::Io(err) => err.kind(),
32 Self::UnsupportedProtocol => ErrorKind::InvalidInput,
33 Self::RawError(_) => ErrorKind::InvalidData,
34 }
35 }
36}
37
38impl<E> core::fmt::Display for Error<E>
39where
40 E: core::fmt::Display,
41{
42 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
43 match self {
44 Self::Io(err) => write!(f, "IO error: {}", err),
45 Self::UnsupportedProtocol => write!(f, "Unsupported protocol"),
46 Self::RawError(err) => write!(f, "Raw error: {}", err),
47 }
48 }
49}
50
/// `defmt` logging support, mirroring the `Display` messages.
#[cfg(feature = "defmt")]
impl<E> defmt::Format for Error<E>
where
    E: defmt::Format,
{
    fn format(&self, fmt: defmt::Formatter<'_>) {
        match self {
            Error::UnsupportedProtocol => defmt::write!(fmt, "Unsupported protocol"),
            Error::RawError(inner) => defmt::write!(fmt, "Raw error: {}", inner),
            Error::Io(inner) => defmt::write!(fmt, "IO error: {}", inner),
        }
    }
}
64
65impl<E> core::error::Error for Error<E> where E: core::error::Error {}
66
67pub struct RawSocket2Udp<T, const N: usize = 1500> {
76 socket: T,
77 filter_local: Option<SocketAddrV4>,
78 filter_remote: Option<SocketAddrV4>,
79 remote_mac: MacAddr,
80}
81
82impl<T, const N: usize> RawSocket2Udp<T, N> {
83 pub fn new(
84 socket: T,
85 filter_local: Option<SocketAddrV4>,
86 filter_remote: Option<SocketAddrV4>,
87 remote_mac: MacAddr,
88 ) -> Self {
89 Self {
90 socket,
91 filter_local,
92 filter_remote,
93 remote_mac,
94 }
95 }
96}
97
98impl<T, const N: usize> ErrorType for RawSocket2Udp<T, N>
99where
100 T: ErrorType,
101{
102 type Error = Error<T::Error>;
103}
104
105impl<T, const N: usize> UdpReceive for RawSocket2Udp<T, N>
106where
107 T: RawReceive,
108{
109 async fn receive(&mut self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
110 let (len, _local, remote, _) = udp_receive::<_, N>(
111 &mut self.socket,
112 self.filter_local,
113 self.filter_remote,
114 buffer,
115 )
116 .await?;
117
118 Ok((len, remote))
119 }
120}
121
122impl<T, const N: usize> Readable for RawSocket2Udp<T, N>
123where
124 T: Readable,
125{
126 async fn readable(&mut self) -> Result<(), Self::Error> {
127 self.socket.readable().await.map_err(Error::Io)
128 }
129}
130
131impl<T, const N: usize> UdpSend for RawSocket2Udp<T, N>
132where
133 T: RawSend,
134{
135 async fn send(&mut self, remote: SocketAddr, data: &[u8]) -> Result<(), Self::Error> {
136 let remote = match remote {
137 SocketAddr::V4(remote) => remote,
138 SocketAddr::V6(_) => Err(Error::UnsupportedProtocol)?,
139 };
140
141 udp_send::<_, N>(
142 &mut self.socket,
143 SocketAddr::V4(
144 self.filter_local
145 .unwrap_or(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
146 ),
147 SocketAddr::V4(remote),
148 self.remote_mac,
149 data,
150 )
151 .await
152 }
153}
154
155impl<T, const N: usize> UdpSplit for RawSocket2Udp<T, N>
156where
157 T: RawSplit,
158{
159 type Receive<'a>
160 = RawSocket2Udp<T::Receive<'a>, N>
161 where
162 Self: 'a;
163 type Send<'a>
164 = RawSocket2Udp<T::Send<'a>, N>
165 where
166 Self: 'a;
167
168 fn split(&mut self) -> (Self::Receive<'_>, Self::Send<'_>) {
169 let (receive, send) = self.socket.split();
170
171 (
172 RawSocket2Udp::new(
173 receive,
174 self.filter_local,
175 self.filter_remote,
176 self.remote_mac,
177 ),
178 RawSocket2Udp::new(send, self.filter_local, self.filter_remote, self.remote_mac),
179 )
180 }
181}
182
183pub async fn udp_send<T: RawSend, const N: usize>(
185 mut socket: T,
186 local: SocketAddr,
187 remote: SocketAddr,
188 remote_mac: MacAddr,
189 data: &[u8],
190) -> Result<(), Error<T::Error>> {
191 let (SocketAddr::V4(local), SocketAddr::V4(remote)) = (local, remote) else {
192 Err(Error::UnsupportedProtocol)?
193 };
194
195 let mut buf = MaybeUninit::<[u8; N]>::uninit();
196 let buf = unsafe { buf.assume_init_mut() };
197
198 let data = raw::ip_udp_encode(buf, local, remote, |buf| {
199 if data.len() <= buf.len() {
200 buf[..data.len()].copy_from_slice(data);
201
202 Ok(data.len())
203 } else {
204 Err(raw::Error::BufferOverflow)
205 }
206 })?;
207
208 socket.send(remote_mac, data).await.map_err(Error::Io)
209}
210
211pub async fn udp_receive<T: RawReceive, const N: usize>(
213 mut socket: T,
214 filter_local: Option<SocketAddrV4>,
215 filter_remote: Option<SocketAddrV4>,
216 buffer: &mut [u8],
217) -> Result<(usize, SocketAddr, SocketAddr, MacAddr), Error<T::Error>> {
218 let mut buf = MaybeUninit::<[u8; N]>::uninit();
219 let buf = unsafe { buf.assume_init_mut() };
220
221 let (len, local, remote, remote_mac) = loop {
222 let (len, remote_mac) = socket.receive(buf).await.map_err(Error::Io)?;
223
224 match raw::ip_udp_decode(&buf[..len], filter_remote, filter_local) {
225 Ok(Some((remote, local, data))) => {
226 if data.len() > buffer.len() {
227 Err(Error::RawError(raw::Error::BufferOverflow))?;
228 }
229
230 buffer[..data.len()].copy_from_slice(data);
231
232 break (data.len(), local, remote, remote_mac);
233 }
234 Ok(None) => continue,
235 Err(raw::Error::InvalidFormat) | Err(raw::Error::InvalidChecksum) => continue,
236 Err(other) => Err(other)?,
237 }
238 };
239
240 Ok((
241 len,
242 SocketAddr::V4(local),
243 SocketAddr::V4(remote),
244 remote_mac,
245 ))
246}