edge_std_nal_async/
lib.rs

1#![allow(async_fn_in_trait)]
2#![warn(clippy::large_futures)]
3
4use core::pin::pin;
5
6use std::io;
7use std::net::{self, TcpStream, ToSocketAddrs, UdpSocket};
8
9use async_io::Async;
10use futures_lite::io::{AsyncReadExt, AsyncWriteExt};
11
12use embedded_io_async::{ErrorType, Read, Write};
13
14use embedded_nal_async::{
15    AddrType, ConnectedUdp, Dns, IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6,
16    TcpConnect, UdpStack, UnconnectedUdp,
17};
18
19use embedded_nal_async_xtra::{Multicast, TcpAccept, TcpListen, TcpSplittableConnection};
20
21#[cfg(all(unix, not(target_os = "espidf")))]
22pub use raw::*;
23
/// Zero-sized handle on which this crate implements its networking traits.
///
/// It carries no state; every socket is created on demand through the host OS.
#[derive(Default)]
pub struct Stack(());

impl Stack {
    /// Creates a new handle. Usable in `const` contexts.
    pub const fn new() -> Self {
        Stack(())
    }
}
32
33impl TcpConnect for Stack {
34    type Error = io::Error;
35
36    type Connection<'a> = StdTcpConnection where Self: 'a;
37
38    async fn connect(&self, remote: SocketAddr) -> Result<Self::Connection<'_>, Self::Error> {
39        let connection = Async::<TcpStream>::connect(to_std_addr(remote)).await?;
40
41        Ok(StdTcpConnection(connection))
42    }
43}
44
45impl TcpListen for Stack {
46    type Error = io::Error;
47
48    type Acceptor<'m>
49    = StdTcpAccept where Self: 'm;
50
51    async fn listen(&self, remote: SocketAddr) -> Result<Self::Acceptor<'_>, Self::Error> {
52        Async::<net::TcpListener>::bind(to_std_addr(remote)).map(StdTcpAccept)
53    }
54}
55
56pub struct StdTcpAccept(Async<net::TcpListener>);
57
58impl TcpAccept for StdTcpAccept {
59    type Error = io::Error;
60
61    type Connection<'m> = StdTcpConnection;
62
63    #[cfg(not(target_os = "espidf"))]
64    async fn accept(&self) -> Result<Self::Connection<'_>, Self::Error> {
65        let connection = self.0.accept().await.map(|(socket, _)| socket)?;
66
67        Ok(StdTcpConnection(connection))
68    }
69
70    #[cfg(target_os = "espidf")]
71    async fn accept(&self) -> Result<Self::Connection<'_>, Self::Error> {
72        // ESP IDF (lwIP actually) does not really support `select`-ing on
73        // socket accept: https://groups.google.com/g/osdeve_mirror_tcpip_lwip/c/Vsz7SVa6a2M
74        //
75        // If we do this, `select` would block and not return with our accepting socket `fd`
76        // marked as ready even if our accepting socket has a pending connection.
77        //
78        // (Note also that since the time when the above link was posted on the internet,
79        // the lwIP `accept` API has improved a bit in that it would now return `EWOULDBLOCK`
80        // instead of blocking indefinitely
81        // - and we take advantage of that in the "async" implementation below.)
82        //
83        // The workaround below is not ideal in that
84        // it uses a timer to poll the socket, but it avoids spinning a hidden,
85        // separate thread just to accept connections - which would be the alternative.
86        loop {
87            match self.0.as_ref().accept() {
88                Ok((connection, _)) => break Ok(StdTcpConnection(Async::new(connection)?)),
89                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
90                    async_io::Timer::after(core::time::Duration::from_millis(5)).await;
91                }
92                Err(err) => break Err(err),
93            }
94        }
95    }
96}
97
98pub struct StdTcpConnection(Async<TcpStream>);
99
100impl ErrorType for StdTcpConnection {
101    type Error = io::Error;
102}
103
104impl Read for StdTcpConnection {
105    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
106        self.0.read(buf).await
107    }
108}
109
110impl Write for StdTcpConnection {
111    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
112        self.0.write(buf).await
113    }
114
115    async fn flush(&mut self) -> Result<(), Self::Error> {
116        self.0.flush().await
117    }
118}
119
120impl ErrorType for &StdTcpConnection {
121    type Error = io::Error;
122}
123
124impl Read for &StdTcpConnection {
125    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
126        (&self.0).read(buf).await
127    }
128}
129
130impl Write for &StdTcpConnection {
131    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
132        (&self.0).write(buf).await
133    }
134
135    async fn flush(&mut self) -> Result<(), Self::Error> {
136        (&self.0).flush().await
137    }
138}
139
140impl TcpSplittableConnection for StdTcpConnection {
141    type Read<'a> = &'a StdTcpConnection where Self: 'a;
142
143    type Write<'a> = &'a StdTcpConnection where Self: 'a;
144
145    fn split(&mut self) -> Result<(Self::Read<'_>, Self::Write<'_>), io::Error> {
146        let socket = &*self;
147
148        Ok((socket, socket))
149    }
150}
151
152impl UdpStack for Stack {
153    type Error = io::Error;
154
155    type Connected = StdUdpSocket;
156
157    type UniquelyBound = StdUdpSocket;
158
159    type MultiplyBound = StdUdpSocket;
160
161    async fn connect_from(
162        &self,
163        local: SocketAddr,
164        remote: SocketAddr,
165    ) -> Result<(SocketAddr, Self::Connected), Self::Error> {
166        let socket = Async::<UdpSocket>::bind(to_std_addr(local))?;
167
168        socket.as_ref().connect(to_std_addr(remote))?;
169
170        Ok((
171            to_nal_addr(socket.as_ref().local_addr()?),
172            StdUdpSocket(socket),
173        ))
174    }
175
176    async fn bind_single(
177        &self,
178        local: SocketAddr,
179    ) -> Result<(SocketAddr, Self::UniquelyBound), Self::Error> {
180        let socket = Async::<UdpSocket>::bind(to_std_addr(local))?;
181
182        socket.as_ref().set_broadcast(true)?;
183
184        Ok((
185            to_nal_addr(socket.as_ref().local_addr()?),
186            StdUdpSocket(socket),
187        ))
188    }
189
190    async fn bind_multiple(&self, _local: SocketAddr) -> Result<Self::MultiplyBound, Self::Error> {
191        unimplemented!() // TODO
192    }
193}
194
195pub struct StdUdpSocket(Async<UdpSocket>);
196
197impl ConnectedUdp for StdUdpSocket {
198    type Error = io::Error;
199
200    async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> {
201        let mut offset = 0;
202
203        loop {
204            let fut = pin!(self.0.send(&data[offset..]));
205            offset += fut.await?;
206
207            if offset == data.len() {
208                break;
209            }
210        }
211
212        Ok(())
213    }
214
215    async fn receive_into(&mut self, buffer: &mut [u8]) -> Result<usize, Self::Error> {
216        let fut = pin!(self.0.recv(buffer));
217        fut.await
218    }
219}
220
221impl UnconnectedUdp for StdUdpSocket {
222    type Error = io::Error;
223
224    async fn send(
225        &mut self,
226        local: SocketAddr,
227        remote: SocketAddr,
228        data: &[u8],
229    ) -> Result<(), Self::Error> {
230        assert!(local == to_nal_addr(self.0.as_ref().local_addr()?));
231
232        let mut offset = 0;
233
234        loop {
235            let fut = pin!(self.0.send_to(data, to_std_addr(remote)));
236            offset += fut.await?;
237
238            if offset == data.len() {
239                break;
240            }
241        }
242
243        Ok(())
244    }
245
246    async fn receive_into(
247        &mut self,
248        buffer: &mut [u8],
249    ) -> Result<(usize, SocketAddr, SocketAddr), Self::Error> {
250        let fut = pin!(self.0.recv_from(buffer));
251        let (len, addr) = fut.await?;
252
253        Ok((
254            len,
255            to_nal_addr(self.0.as_ref().local_addr()?),
256            to_nal_addr(addr),
257        ))
258    }
259}
260
261impl Multicast for StdUdpSocket {
262    type Error = io::Error;
263
264    async fn join(&mut self, multicast_addr: IpAddr) -> Result<(), Self::Error> {
265        match multicast_addr {
266            IpAddr::V4(addr) => self
267                .0
268                .as_ref()
269                .join_multicast_v4(&addr.octets().into(), &std::net::Ipv4Addr::UNSPECIFIED)?,
270            IpAddr::V6(addr) => self
271                .0
272                .as_ref()
273                .join_multicast_v6(&addr.octets().into(), 0)?,
274        }
275
276        Ok(())
277    }
278
279    async fn leave(&mut self, multicast_addr: IpAddr) -> Result<(), Self::Error> {
280        match multicast_addr {
281            IpAddr::V4(addr) => self
282                .0
283                .as_ref()
284                .leave_multicast_v4(&addr.octets().into(), &std::net::Ipv4Addr::UNSPECIFIED)?,
285            IpAddr::V6(addr) => self
286                .0
287                .as_ref()
288                .leave_multicast_v6(&addr.octets().into(), 0)?,
289        }
290
291        Ok(())
292    }
293}
294
295impl Dns for Stack {
296    type Error = io::Error;
297
298    async fn get_host_by_name(
299        &self,
300        host: &str,
301        addr_type: AddrType,
302    ) -> Result<IpAddr, Self::Error> {
303        let host = host.to_string();
304
305        dns_lookup_host(&host, addr_type)
306    }
307
308    async fn get_host_by_address(
309        &self,
310        _addr: IpAddr,
311        _result: &mut [u8],
312    ) -> Result<usize, Self::Error> {
313        Err(io::ErrorKind::Unsupported.into())
314    }
315}
316
317fn dns_lookup_host(host: &str, addr_type: AddrType) -> Result<IpAddr, io::Error> {
318    (host, 0_u16)
319        .to_socket_addrs()?
320        .find(|addr| match addr_type {
321            AddrType::IPv4 => matches!(addr, std::net::SocketAddr::V4(_)),
322            AddrType::IPv6 => matches!(addr, std::net::SocketAddr::V6(_)),
323            AddrType::Either => true,
324        })
325        .map(|addr| match addr {
326            std::net::SocketAddr::V4(v4) => v4.ip().octets().into(),
327            std::net::SocketAddr::V6(v6) => v6.ip().octets().into(),
328        })
329        .ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
330}
331
332#[cfg(all(unix, not(target_os = "espidf")))]
333mod raw {
334    use core::pin::pin;
335
336    use std::io::{self, ErrorKind};
337    use std::os::fd::{AsFd, AsRawFd};
338
339    use async_io::Async;
340
341    use embedded_nal_async_xtra::{RawSocket, RawStack};
342
343    use crate::Stack;
344
345    pub struct StdRawSocket(Async<std::net::UdpSocket>, u32);
346
347    impl RawSocket for StdRawSocket {
348        type Error = io::Error;
349
350        async fn send(&mut self, mac: Option<&[u8; 6]>, data: &[u8]) -> Result<(), Self::Error> {
351            let mut sockaddr = libc::sockaddr_ll {
352                sll_family: libc::AF_PACKET as _,
353                sll_protocol: (libc::ETH_P_IP as u16).to_be() as _,
354                sll_ifindex: self.1 as _,
355                sll_hatype: 0,
356                sll_pkttype: 0,
357                sll_halen: 0,
358                sll_addr: Default::default(),
359            };
360
361            if let Some(mac) = mac {
362                sockaddr.sll_halen = mac.len() as _;
363                sockaddr.sll_addr[..mac.len()].copy_from_slice(mac);
364            }
365
366            let fut = pin!(self.0.write_with(|io| {
367                let len = core::cmp::min(data.len(), u16::MAX as usize);
368
369                let ret = cvti(unsafe {
370                    libc::sendto(
371                        io.as_fd().as_raw_fd(),
372                        data.as_ptr() as *const _,
373                        len,
374                        libc::MSG_NOSIGNAL,
375                        &sockaddr as *const _ as *const _,
376                        core::mem::size_of::<libc::sockaddr_ll>() as _,
377                    )
378                })?;
379                Ok(ret as usize)
380            }));
381
382            let len = fut.await?;
383
384            assert_eq!(len, data.len());
385
386            Ok(())
387        }
388
389        async fn receive_into(
390            &mut self,
391            buffer: &mut [u8],
392        ) -> Result<(usize, [u8; 6]), Self::Error> {
393            let fut = pin!(self.0.read_with(|io| {
394                let mut storage: libc::sockaddr_storage = unsafe { core::mem::zeroed() };
395                let mut addrlen = core::mem::size_of_val(&storage) as libc::socklen_t;
396
397                let ret = cvti(unsafe {
398                    libc::recvfrom(
399                        io.as_fd().as_raw_fd(),
400                        buffer.as_mut_ptr() as *mut _,
401                        buffer.len(),
402                        0,
403                        &mut storage as *mut _ as *mut _,
404                        &mut addrlen,
405                    )
406                })?;
407
408                let sockaddr = as_sockaddr_ll(&storage, addrlen as usize)?;
409
410                let mut mac = [0; 6];
411                mac.copy_from_slice(&sockaddr.sll_addr[..6]);
412
413                Ok((ret as usize, mac))
414            }));
415
416            fut.await
417        }
418    }
419
420    impl RawStack for Stack {
421        type Error = io::Error;
422
423        type Socket = StdRawSocket;
424
425        async fn bind(&self, interface: u32) -> Result<Self::Socket, Self::Error> {
426            let socket = cvt(unsafe {
427                libc::socket(
428                    libc::PF_PACKET,
429                    libc::SOCK_DGRAM,
430                    (libc::ETH_P_IP as u16).to_be() as _,
431                )
432            })?;
433
434            let sockaddr = libc::sockaddr_ll {
435                sll_family: libc::AF_PACKET as _,
436                sll_protocol: (libc::ETH_P_IP as u16).to_be() as _,
437                sll_ifindex: interface as _,
438                sll_hatype: 0,
439                sll_pkttype: 0,
440                sll_halen: 0,
441                sll_addr: Default::default(),
442            };
443
444            cvt(unsafe {
445                libc::bind(
446                    socket,
447                    &sockaddr as *const _ as *const _,
448                    core::mem::size_of::<libc::sockaddr_ll>() as _,
449                )
450            })?;
451
452            // TODO
453            // cvt(unsafe {
454            //     libc::setsockopt(socket, libc::SOL_PACKET, libc::PACKET_AUXDATA, &1_u32 as *const _ as *const _, 4)
455            // })?;
456
457            let socket = {
458                use std::os::fd::FromRawFd;
459
460                unsafe { std::net::UdpSocket::from_raw_fd(socket) }
461            };
462
463            socket.set_broadcast(true)?;
464
465            Ok(StdRawSocket(Async::new(socket)?, interface as _))
466        }
467    }
468
469    fn as_sockaddr_ll(
470        storage: &libc::sockaddr_storage,
471        len: usize,
472    ) -> io::Result<&libc::sockaddr_ll> {
473        match storage.ss_family as core::ffi::c_int {
474            libc::AF_PACKET => {
475                assert!(len >= core::mem::size_of::<libc::sockaddr_ll>());
476                Ok(unsafe { (storage as *const _ as *const libc::sockaddr_ll).as_ref() }.unwrap())
477            }
478            _ => Err(io::Error::new(ErrorKind::InvalidInput, "invalid argument")),
479        }
480    }
481
482    fn cvt<T>(res: T) -> io::Result<T>
483    where
484        T: Into<i64> + Copy,
485    {
486        let ires: i64 = res.into();
487
488        if ires == -1 {
489            Err(io::Error::last_os_error())
490        } else {
491            Ok(res)
492        }
493    }
494
495    fn cvti<T>(res: T) -> io::Result<T>
496    where
497        T: Into<isize> + Copy,
498    {
499        let ires: isize = res.into();
500
501        if ires == -1 {
502            Err(io::Error::last_os_error())
503        } else {
504            Ok(res)
505        }
506    }
507}
508
509pub fn to_std_addr(addr: SocketAddr) -> std::net::SocketAddr {
510    match addr {
511        SocketAddr::V4(addr) => net::SocketAddr::V4(net::SocketAddrV4::new(
512            addr.ip().octets().into(),
513            addr.port(),
514        )),
515        SocketAddr::V6(addr) => net::SocketAddr::V6(net::SocketAddrV6::new(
516            addr.ip().octets().into(),
517            addr.port(),
518            addr.flowinfo(),
519            addr.scope_id(),
520        )),
521    }
522}
523
524pub fn to_nal_addr(addr: std::net::SocketAddr) -> SocketAddr {
525    match addr {
526        net::SocketAddr::V4(addr) => {
527            SocketAddr::V4(SocketAddrV4::new(addr.ip().octets().into(), addr.port()))
528        }
529        net::SocketAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(
530            addr.ip().octets().into(),
531            addr.port(),
532            addr.flowinfo(),
533            addr.scope_id(),
534        )),
535    }
536}
537
538pub fn to_std_ipv4_addr(addr: Ipv4Addr) -> std::net::Ipv4Addr {
539    addr.octets().into()
540}
541
542pub fn to_nal_ipv4_addr(addr: std::net::Ipv4Addr) -> Ipv4Addr {
543    addr.octets().into()
544}