1#![allow(async_fn_in_trait)]
2#![warn(clippy::large_futures)]
3
4use core::pin::pin;
5
6use std::io;
7use std::net::{self, TcpStream, ToSocketAddrs, UdpSocket};
8
9use async_io::Async;
10use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
11
12use embedded_io_async::{ErrorType, Read, Write};
13
14use embedded_nal_async::{
15 AddrType, ConnectedUdp, Dns, IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6,
16 TcpConnect, UdpStack, UnconnectedUdp,
17};
18
19use embedded_nal_async_xtra::{Multicast, TcpAccept, TcpListen, TcpSplittableConnection};
20
21#[cfg(all(unix, not(target_os = "espidf")))]
22pub use raw::*;
23
/// Networking stack backed by the standard-library sockets, made asynchronous
/// through `async-io`.
///
/// The handle holds no state of its own; each socket it hands out is independent.
#[derive(Default)]
pub struct Stack(());

impl Stack {
    /// Returns a stack handle; usable in `const` contexts.
    pub const fn new() -> Self {
        Stack(())
    }
}
32
33impl TcpConnect for Stack {
34 type Error = io::Error;
35
36 type Connection<'a> = StdTcpConnection where Self: 'a;
37
38 async fn connect(&self, remote: SocketAddr) -> Result<Self::Connection<'_>, Self::Error> {
39 let connection = Async::<TcpStream>::connect(to_std_addr(remote)).await?;
40
41 Ok(StdTcpConnection(connection))
42 }
43}
44
45impl TcpListen for Stack {
46 type Error = io::Error;
47
48 type Acceptor<'m>
49 = StdTcpAccept where Self: 'm;
50
51 async fn listen(&self, remote: SocketAddr) -> Result<Self::Acceptor<'_>, Self::Error> {
52 Async::<net::TcpListener>::bind(to_std_addr(remote)).map(StdTcpAccept)
53 }
54}
55
56pub struct StdTcpAccept(Async<net::TcpListener>);
57
58impl TcpAccept for StdTcpAccept {
59 type Error = io::Error;
60
61 type Connection<'m> = StdTcpConnection;
62
63 #[cfg(not(target_os = "espidf"))]
64 async fn accept(&self) -> Result<Self::Connection<'_>, Self::Error> {
65 let connection = self.0.accept().await.map(|(socket, _)| socket)?;
66
67 Ok(StdTcpConnection(connection))
68 }
69
70 #[cfg(target_os = "espidf")]
71 async fn accept(&self) -> Result<Self::Connection<'_>, Self::Error> {
72 loop {
87 match self.0.as_ref().accept() {
88 Ok((connection, _)) => break Ok(StdTcpConnection(Async::new(connection)?)),
89 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
90 async_io::Timer::after(core::time::Duration::from_millis(5)).await;
91 }
92 Err(err) => break Err(err),
93 }
94 }
95 }
96}
97
98pub struct StdTcpConnection(Async<TcpStream>);
99
100impl ErrorType for StdTcpConnection {
101 type Error = io::Error;
102}
103
104impl Read for StdTcpConnection {
105 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
106 self.0.read(buf).await
107 }
108}
109
110impl Write for StdTcpConnection {
111 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
112 self.0.write(buf).await
113 }
114
115 async fn flush(&mut self) -> Result<(), Self::Error> {
116 self.0.flush().await
117 }
118}
119
120impl ErrorType for &StdTcpConnection {
121 type Error = io::Error;
122}
123
124impl Read for &StdTcpConnection {
125 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
126 (&self.0).read(buf).await
127 }
128}
129
130impl Write for &StdTcpConnection {
131 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
132 (&self.0).write(buf).await
133 }
134
135 async fn flush(&mut self) -> Result<(), Self::Error> {
136 (&self.0).flush().await
137 }
138}
139
140impl TcpSplittableConnection for StdTcpConnection {
141 type Read<'a> = &'a StdTcpConnection where Self: 'a;
142
143 type Write<'a> = &'a StdTcpConnection where Self: 'a;
144
145 fn split(&mut self) -> Result<(Self::Read<'_>, Self::Write<'_>), io::Error> {
146 let socket = &*self;
147
148 Ok((socket, socket))
149 }
150}
151
152impl UdpStack for Stack {
153 type Error = io::Error;
154
155 type Connected = StdUdpSocket;
156
157 type UniquelyBound = StdUdpSocket;
158
159 type MultiplyBound = StdUdpSocket;
160
161 async fn connect_from(
162 &self,
163 local: SocketAddr,
164 remote: SocketAddr,
165 ) -> Result<(SocketAddr, Self::Connected), Self::Error> {
166 let socket = Async::<UdpSocket>::bind(to_std_addr(local))?;
167
168 socket.as_ref().connect(to_std_addr(remote))?;
169
170 Ok((
171 to_nal_addr(socket.as_ref().local_addr()?),
172 StdUdpSocket(socket),
173 ))
174 }
175
176 async fn bind_single(
177 &self,
178 local: SocketAddr,
179 ) -> Result<(SocketAddr, Self::UniquelyBound), Self::Error> {
180 let socket = Async::<UdpSocket>::bind(to_std_addr(local))?;
181
182 socket.as_ref().set_broadcast(true)?;
183
184 Ok((
185 to_nal_addr(socket.as_ref().local_addr()?),
186 StdUdpSocket(socket),
187 ))
188 }
189
190 async fn bind_multiple(&self, _local: SocketAddr) -> Result<Self::MultiplyBound, Self::Error> {
191 unimplemented!() }
193}
194
195pub struct StdUdpSocket(Async<UdpSocket>);
196
197impl ConnectedUdp for StdUdpSocket {
198 type Error = io::Error;
199
200 async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> {
201 let mut offset = 0;
202
203 loop {
204 let fut = pin!(self.0.send(&data[offset..]));
205 offset += fut.await?;
206
207 if offset == data.len() {
208 break;
209 }
210 }
211
212 Ok(())
213 }
214
215 async fn receive_into(&mut self, buffer: &mut [u8]) -> Result<usize, Self::Error> {
216 let fut = pin!(self.0.recv(buffer));
217 fut.await
218 }
219}
220
221impl UnconnectedUdp for StdUdpSocket {
222 type Error = io::Error;
223
224 async fn send(
225 &mut self,
226 local: SocketAddr,
227 remote: SocketAddr,
228 data: &[u8],
229 ) -> Result<(), Self::Error> {
230 assert!(local == to_nal_addr(self.0.as_ref().local_addr()?));
231
232 let mut offset = 0;
233
234 loop {
235 let fut = pin!(self.0.send_to(data, to_std_addr(remote)));
236 offset += fut.await?;
237
238 if offset == data.len() {
239 break;
240 }
241 }
242
243 Ok(())
244 }
245
246 async fn receive_into(
247 &mut self,
248 buffer: &mut [u8],
249 ) -> Result<(usize, SocketAddr, SocketAddr), Self::Error> {
250 let fut = pin!(self.0.recv_from(buffer));
251 let (len, addr) = fut.await?;
252
253 Ok((
254 len,
255 to_nal_addr(self.0.as_ref().local_addr()?),
256 to_nal_addr(addr),
257 ))
258 }
259}
260
261impl Multicast for StdUdpSocket {
262 type Error = io::Error;
263
264 async fn join(&mut self, multicast_addr: IpAddr) -> Result<(), Self::Error> {
265 match multicast_addr {
266 IpAddr::V4(addr) => self
267 .0
268 .as_ref()
269 .join_multicast_v4(&addr.octets().into(), &std::net::Ipv4Addr::UNSPECIFIED)?,
270 IpAddr::V6(addr) => self
271 .0
272 .as_ref()
273 .join_multicast_v6(&addr.octets().into(), 0)?,
274 }
275
276 Ok(())
277 }
278
279 async fn leave(&mut self, multicast_addr: IpAddr) -> Result<(), Self::Error> {
280 match multicast_addr {
281 IpAddr::V4(addr) => self
282 .0
283 .as_ref()
284 .leave_multicast_v4(&addr.octets().into(), &std::net::Ipv4Addr::UNSPECIFIED)?,
285 IpAddr::V6(addr) => self
286 .0
287 .as_ref()
288 .leave_multicast_v6(&addr.octets().into(), 0)?,
289 }
290
291 Ok(())
292 }
293}
294
295impl Dns for Stack {
296 type Error = io::Error;
297
298 async fn get_host_by_name(
299 &self,
300 host: &str,
301 addr_type: AddrType,
302 ) -> Result<IpAddr, Self::Error> {
303 let host = host.to_string();
304
305 dns_lookup_host(&host, addr_type)
306 }
307
308 async fn get_host_by_address(
309 &self,
310 _addr: IpAddr,
311 _result: &mut [u8],
312 ) -> Result<usize, Self::Error> {
313 Err(io::ErrorKind::Unsupported.into())
314 }
315}
316
317fn dns_lookup_host(host: &str, addr_type: AddrType) -> Result<IpAddr, io::Error> {
318 (host, 0_u16)
319 .to_socket_addrs()?
320 .find(|addr| match addr_type {
321 AddrType::IPv4 => matches!(addr, std::net::SocketAddr::V4(_)),
322 AddrType::IPv6 => matches!(addr, std::net::SocketAddr::V6(_)),
323 AddrType::Either => true,
324 })
325 .map(|addr| match addr {
326 std::net::SocketAddr::V4(v4) => v4.ip().octets().into(),
327 std::net::SocketAddr::V6(v6) => v6.ip().octets().into(),
328 })
329 .ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
330}
331
332#[cfg(all(unix, not(target_os = "espidf")))]
333mod raw {
334 use core::pin::pin;
335
336 use std::io::{self, ErrorKind};
337 use std::os::fd::{AsFd, AsRawFd};
338
339 use async_io::Async;
340
341 use embedded_nal_async_xtra::{RawSocket, RawStack};
342
343 use crate::Stack;
344
345 pub struct StdRawSocket(Async<std::net::UdpSocket>, u32);
346
347 impl RawSocket for StdRawSocket {
348 type Error = io::Error;
349
350 async fn send(&mut self, mac: Option<&[u8; 6]>, data: &[u8]) -> Result<(), Self::Error> {
351 let mut sockaddr = libc::sockaddr_ll {
352 sll_family: libc::AF_PACKET as _,
353 sll_protocol: (libc::ETH_P_IP as u16).to_be() as _,
354 sll_ifindex: self.1 as _,
355 sll_hatype: 0,
356 sll_pkttype: 0,
357 sll_halen: 0,
358 sll_addr: Default::default(),
359 };
360
361 if let Some(mac) = mac {
362 sockaddr.sll_halen = mac.len() as _;
363 sockaddr.sll_addr[..mac.len()].copy_from_slice(mac);
364 }
365
366 let fut = pin!(self.0.write_with(|io| {
367 let len = core::cmp::min(data.len(), u16::MAX as usize);
368
369 let ret = cvti(unsafe {
370 libc::sendto(
371 io.as_fd().as_raw_fd(),
372 data.as_ptr() as *const _,
373 len,
374 libc::MSG_NOSIGNAL,
375 &sockaddr as *const _ as *const _,
376 core::mem::size_of::<libc::sockaddr_ll>() as _,
377 )
378 })?;
379 Ok(ret as usize)
380 }));
381
382 let len = fut.await?;
383
384 assert_eq!(len, data.len());
385
386 Ok(())
387 }
388
389 async fn receive_into(
390 &mut self,
391 buffer: &mut [u8],
392 ) -> Result<(usize, [u8; 6]), Self::Error> {
393 let fut = pin!(self.0.read_with(|io| {
394 let mut storage: libc::sockaddr_storage = unsafe { core::mem::zeroed() };
395 let mut addrlen = core::mem::size_of_val(&storage) as libc::socklen_t;
396
397 let ret = cvti(unsafe {
398 libc::recvfrom(
399 io.as_fd().as_raw_fd(),
400 buffer.as_mut_ptr() as *mut _,
401 buffer.len(),
402 0,
403 &mut storage as *mut _ as *mut _,
404 &mut addrlen,
405 )
406 })?;
407
408 let sockaddr = as_sockaddr_ll(&storage, addrlen as usize)?;
409
410 let mut mac = [0; 6];
411 mac.copy_from_slice(&sockaddr.sll_addr[..6]);
412
413 Ok((ret as usize, mac))
414 }));
415
416 fut.await
417 }
418 }
419
420 impl RawStack for Stack {
421 type Error = io::Error;
422
423 type Socket = StdRawSocket;
424
425 async fn bind(&self, interface: u32) -> Result<Self::Socket, Self::Error> {
426 let socket = cvt(unsafe {
427 libc::socket(
428 libc::PF_PACKET,
429 libc::SOCK_DGRAM,
430 (libc::ETH_P_IP as u16).to_be() as _,
431 )
432 })?;
433
434 let sockaddr = libc::sockaddr_ll {
435 sll_family: libc::AF_PACKET as _,
436 sll_protocol: (libc::ETH_P_IP as u16).to_be() as _,
437 sll_ifindex: interface as _,
438 sll_hatype: 0,
439 sll_pkttype: 0,
440 sll_halen: 0,
441 sll_addr: Default::default(),
442 };
443
444 cvt(unsafe {
445 libc::bind(
446 socket,
447 &sockaddr as *const _ as *const _,
448 core::mem::size_of::<libc::sockaddr_ll>() as _,
449 )
450 })?;
451
452 let socket = {
458 use std::os::fd::FromRawFd;
459
460 unsafe { std::net::UdpSocket::from_raw_fd(socket) }
461 };
462
463 socket.set_broadcast(true)?;
464
465 Ok(StdRawSocket(Async::new(socket)?, interface as _))
466 }
467 }
468
469 fn as_sockaddr_ll(
470 storage: &libc::sockaddr_storage,
471 len: usize,
472 ) -> io::Result<&libc::sockaddr_ll> {
473 match storage.ss_family as core::ffi::c_int {
474 libc::AF_PACKET => {
475 assert!(len >= core::mem::size_of::<libc::sockaddr_ll>());
476 Ok(unsafe { (storage as *const _ as *const libc::sockaddr_ll).as_ref() }.unwrap())
477 }
478 _ => Err(io::Error::new(ErrorKind::InvalidInput, "invalid argument")),
479 }
480 }
481
482 fn cvt<T>(res: T) -> io::Result<T>
483 where
484 T: Into<i64> + Copy,
485 {
486 let ires: i64 = res.into();
487
488 if ires == -1 {
489 Err(io::Error::last_os_error())
490 } else {
491 Ok(res)
492 }
493 }
494
495 fn cvti<T>(res: T) -> io::Result<T>
496 where
497 T: Into<isize> + Copy,
498 {
499 let ires: isize = res.into();
500
501 if ires == -1 {
502 Err(io::Error::last_os_error())
503 } else {
504 Ok(res)
505 }
506 }
507}
508
/// Converts an `embedded-nal-async` socket address into its `std::net` form.
///
/// The port is kept, and for IPv6 so are the flow info and scope id.
pub fn to_std_addr(addr: SocketAddr) -> std::net::SocketAddr {
    match addr {
        SocketAddr::V4(v4) => {
            let ip: net::Ipv4Addr = v4.ip().octets().into();
            net::SocketAddr::V4(net::SocketAddrV4::new(ip, v4.port()))
        }
        SocketAddr::V6(v6) => {
            let ip: net::Ipv6Addr = v6.ip().octets().into();
            net::SocketAddr::V6(net::SocketAddrV6::new(
                ip,
                v6.port(),
                v6.flowinfo(),
                v6.scope_id(),
            ))
        }
    }
}
523
/// Converts a `std::net` socket address into its `embedded-nal-async` form.
///
/// The port is kept, and for IPv6 so are the flow info and scope id.
pub fn to_nal_addr(addr: std::net::SocketAddr) -> SocketAddr {
    let port = addr.port();

    match addr {
        net::SocketAddr::V4(v4) => {
            SocketAddr::V4(SocketAddrV4::new(v4.ip().octets().into(), port))
        }
        net::SocketAddr::V6(v6) => SocketAddr::V6(SocketAddrV6::new(
            v6.ip().octets().into(),
            port,
            v6.flowinfo(),
            v6.scope_id(),
        )),
    }
}
537
/// Converts an `embedded-nal-async` IPv4 address into `std::net::Ipv4Addr`.
pub fn to_std_ipv4_addr(addr: Ipv4Addr) -> std::net::Ipv4Addr {
    let octets: [u8; 4] = addr.octets();
    octets.into()
}
541
/// Converts a `std::net::Ipv4Addr` into the `embedded-nal-async` IPv4 type.
pub fn to_nal_ipv4_addr(addr: std::net::Ipv4Addr) -> Ipv4Addr {
    let octets: [u8; 4] = addr.octets();
    octets.into()
}