erbium_net/
socket.rs

1/*   Copyright 2023 Perry Lorier
2 *
3 *  Licensed under the Apache License, Version 2.0 (the "License");
4 *  you may not use this file except in compliance with the License.
5 *  You may obtain a copy of the License at
6 *
7 *      http://www.apache.org/licenses/LICENSE-2.0
8 *
9 *  Unless required by applicable law or agreed to in writing, software
10 *  distributed under the License is distributed on an "AS IS" BASIS,
11 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 *  See the License for the specific language governing permissions and
13 *  limitations under the License.
14 *
15 *  SPDX-License-Identifier: Apache-2.0
16 *
17 *  Common Socket Traits
18 */
19
20use crate::addr::NetAddr;
21use std::os::unix::io::{OwnedFd, RawFd};
22
23pub fn std_to_libc_in_addr(addr: std::net::Ipv4Addr) -> libc::in_addr {
24    libc::in_addr {
25        s_addr: addr
26            .octets()
27            .iter()
28            .fold(0, |acc, x| ((acc << 8) | (*x as u32))),
29    }
30}
31
32pub const fn std_to_libc_in6_addr(addr: std::net::Ipv6Addr) -> libc::in6_addr {
33    libc::in6_addr {
34        s6_addr: addr.octets(),
35    }
36}
37
38pub type MsgFlags = nix::sys::socket::MsgFlags;
39pub use std::io::{IoSlice, IoSliceMut};
40
41use nix::libc;
42
/// Ancillary settings that accompany an outgoing packet.
///
/// `send_msg` and `convert_to_cmsg` translate these settings into nix
/// `Ipv4PacketInfo` / `Ipv6PacketInfo` control messages.
#[derive(Debug)]
pub struct ControlMessage {
    /// Local address the packet should be sent from, if any.  When set, it is
    /// written into the matching packet-info structure below before sending.
    pub send_from: Option<std::net::IpAddr>,
    /* private, used to hold memory after conversions */
    // IPv4 packet info; its interface index is set via `set_src4_intf`.
    pktinfo4: libc::in_pktinfo,
    // IPv6 packet info; its interface index is set via `set_src6_intf`.
    pktinfo6: libc::in6_pktinfo,
}
50
51impl ControlMessage {
52    pub fn new() -> Self {
53        Self {
54            send_from: None,
55            pktinfo4: libc::in_pktinfo {
56                ipi_ifindex: 0, /* Unspecified interface */
57                ipi_addr: std_to_libc_in_addr(std::net::Ipv4Addr::UNSPECIFIED),
58                ipi_spec_dst: std_to_libc_in_addr(std::net::Ipv4Addr::UNSPECIFIED),
59            },
60            pktinfo6: libc::in6_pktinfo {
61                ipi6_ifindex: 0, /* Unspecified interface */
62                ipi6_addr: std_to_libc_in6_addr(std::net::Ipv6Addr::UNSPECIFIED),
63            },
64        }
65    }
66    #[must_use]
67    pub const fn set_send_from(mut self, send_from: Option<std::net::IpAddr>) -> Self {
68        self.send_from = send_from;
69        self
70    }
71    #[must_use]
72    pub const fn set_src4_intf(mut self, intf: u32) -> Self {
73        self.pktinfo4.ipi_ifindex = intf as i32;
74        self
75    }
76    #[must_use]
77    pub const fn set_src6_intf(mut self, intf: u32) -> Self {
78        self.pktinfo6.ipi6_ifindex = intf;
79        self
80    }
81    pub fn convert_to_cmsg(&mut self) -> Vec<nix::sys::socket::ControlMessage> {
82        let mut cmsgs: Vec<nix::sys::socket::ControlMessage> = vec![];
83
84        if let Some(addr) = self.send_from {
85            match addr {
86                std::net::IpAddr::V4(ip) => {
87                    self.pktinfo4.ipi_spec_dst = std_to_libc_in_addr(ip);
88                    cmsgs.push(nix::sys::socket::ControlMessage::Ipv4PacketInfo(
89                        &self.pktinfo4,
90                    ))
91                }
92                std::net::IpAddr::V6(ip) => {
93                    self.pktinfo6.ipi6_addr = std_to_libc_in6_addr(ip);
94                    cmsgs.push(nix::sys::socket::ControlMessage::Ipv6PacketInfo(
95                        &self.pktinfo6,
96                    ))
97                }
98            }
99        }
100
101        cmsgs
102    }
103}
104
105impl Default for ControlMessage {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
/// A received message together with the ancillary data that came with it.
#[derive(Debug)]
pub struct RecvMsg {
    /// The bytes that were received.
    pub buffer: Vec<u8>,
    /// The address reported by `recvmsg` for this message, if any
    /// (presumably the sender -- verify against `NetAddr`'s users).
    pub address: Option<NetAddr>,
    /* TODO: These should probably return std types */
    /* Or possibly have accessors that convert them for you */
    /* either way, we shouldn't be exporting nix types here */
    // Set from an `ScmTimestamp` control message, when one was received.
    timestamp: Option<nix::sys::time::TimeVal>,
    // Set from an `Ipv4PacketInfo` control message, when one was received.
    ipv4pktinfo: Option<libc::in_pktinfo>,
    // Set from an `Ipv6PacketInfo` control message, when one was received.
    ipv6pktinfo: Option<libc::in6_pktinfo>,
}
122
123impl RecvMsg {
124    fn new(m: nix::sys::socket::RecvMsg<NetAddr>, buffer: Vec<u8>) -> RecvMsg {
125        let mut r = RecvMsg {
126            buffer,
127            address: m.address,
128            timestamp: None,
129            ipv4pktinfo: None,
130            ipv6pktinfo: None,
131        };
132
133        for cmsg in m.cmsgs() {
134            use nix::sys::socket::ControlMessageOwned;
135            match cmsg {
136                ControlMessageOwned::ScmTimestamp(rtime) => {
137                    r.timestamp = Some(rtime);
138                }
139                ControlMessageOwned::Ipv4PacketInfo(pi) => {
140                    r.ipv4pktinfo = Some(pi);
141                }
142                ControlMessageOwned::Ipv6PacketInfo(pi) => {
143                    r.ipv6pktinfo = Some(pi);
144                }
145                x => log::warn!("Unknown control message {:?}", x),
146            }
147        }
148
149        r
150    }
151
152    /// Returns the local address of the packet.
153    ///
154    /// This is primarily used by UDP sockets to tell you which address a packet arrived on when
155    /// the UDP socket is bound to INADDR_ANY or IN6ADDR_ANY.
156    pub const fn local_ip(&self) -> Option<std::net::IpAddr> {
157        // This function can be overridden to provide different implementations for different
158        // platforms.
159        //
160        if let Some(pi) = self.ipv6pktinfo {
161            // Oh come on, this conversion is even more ridiculous than the last one!
162            Some(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
163                (pi.ipi6_addr.s6_addr[0] as u16) << 8 | (pi.ipi6_addr.s6_addr[1] as u16),
164                (pi.ipi6_addr.s6_addr[2] as u16) << 8 | (pi.ipi6_addr.s6_addr[3] as u16),
165                (pi.ipi6_addr.s6_addr[4] as u16) << 8 | (pi.ipi6_addr.s6_addr[5] as u16),
166                (pi.ipi6_addr.s6_addr[6] as u16) << 8 | (pi.ipi6_addr.s6_addr[7] as u16),
167                (pi.ipi6_addr.s6_addr[8] as u16) << 8 | (pi.ipi6_addr.s6_addr[9] as u16),
168                (pi.ipi6_addr.s6_addr[10] as u16) << 8 | (pi.ipi6_addr.s6_addr[11] as u16),
169                (pi.ipi6_addr.s6_addr[12] as u16) << 8 | (pi.ipi6_addr.s6_addr[13] as u16),
170                (pi.ipi6_addr.s6_addr[14] as u16) << 8 | (pi.ipi6_addr.s6_addr[15] as u16),
171            )))
172        } else if let Some(pi) = self.ipv4pktinfo {
173            let ip = pi.ipi_addr.s_addr.to_ne_bytes(); // This is already in big endian form, don't try and perform a conversion.
174                                                       // It is a pity I haven't found a nicer way to do this conversion.
175            Some(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
176                ip[0], ip[1], ip[2], ip[3],
177            )))
178        } else {
179            None
180        }
181    }
182
183    pub fn local_intf(&self) -> Option<i32> {
184        if let Some(pi) = self.ipv6pktinfo {
185            Some(pi.ipi6_ifindex as i32)
186        } else {
187            self.ipv4pktinfo.map(|pi| pi.ipi_ifindex)
188        }
189    }
190}
191
/// An owned socket file descriptor.
///
/// The descriptor is held as an [`OwnedFd`], so it is closed when the
/// `SocketFd` is dropped.
#[derive(Debug)]
pub struct SocketFd {
    fd: OwnedFd,
}

impl std::os::unix::io::AsRawFd for SocketFd {
    /// Returns the raw descriptor of the wrapped socket without giving up
    /// ownership of it.
    fn as_raw_fd(&self) -> RawFd {
        // Fully-qualified call: method-call syntax on the concrete `OwnedFd`
        // would require the `AsRawFd` trait to be imported into scope.
        std::os::unix::io::AsRawFd::as_raw_fd(&self.fd)
    }
}
202
203pub fn new_socket(
204    domain: libc::c_int,
205    ty: libc::c_int,
206    protocol: libc::c_int,
207) -> Result<SocketFd, std::io::Error> {
208    // I would love to use the nix socket() wrapper, except, uh, it has a closed enum.
209    // See https://github.com/nix-rust/nix/issues/854
210    //
211    // So I have to use the libc version directly.
212    unsafe {
213        use std::os::unix::io::FromRawFd as _;
214        let fd = libc::socket(
215            domain,
216            ty | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
217            protocol,
218        );
219        if fd == -1 {
220            return Err(std::io::Error::last_os_error());
221        }
222        Ok(SocketFd {
223            fd: OwnedFd::from_raw_fd(fd as RawFd),
224        })
225    }
226}
227
228pub async fn recv_msg<F: std::os::unix::io::AsRawFd>(
229    sock: &tokio::io::unix::AsyncFd<F>,
230    bufsize: usize,
231    flags: MsgFlags,
232) -> Result<RecvMsg, std::io::Error> {
233    let mut ev = sock.readable().await?;
234
235    let mut buf = Vec::new();
236    buf.resize_with(bufsize, Default::default);
237    let iov = &mut [IoSliceMut::new(buf.as_mut_slice())];
238
239    let mut cmsg = Vec::new();
240    cmsg.resize_with(65536, Default::default); /* TODO: Calculate a more reasonable size */
241
242    let mut flags = flags;
243    flags.set(MsgFlags::MSG_DONTWAIT, true);
244
245    match nix::sys::socket::recvmsg(sock.get_ref().as_raw_fd(), iov, Some(&mut cmsg), flags) {
246        Ok(rm) => {
247            let buf = rm.iovs().next().unwrap();
248            ev.retain_ready();
249            Ok(RecvMsg::new(rm, buf.into()))
250        }
251        Err(e) if e == nix::errno::Errno::EAGAIN => {
252            ev.clear_ready();
253            Err(e.into())
254        }
255        Err(e) => {
256            ev.retain_ready();
257            Err(e.into())
258        }
259    }
260}
261
262/// This function makes an async stream out of calling recv_msg repeatedly.
263pub fn recv_msg_stream<F>(
264    sock: &tokio::io::unix::AsyncFd<F>,
265    bufsize: usize,
266    flags: MsgFlags,
267) -> impl futures::stream::Stream<Item = Result<RecvMsg, std::io::Error>> + '_
268where
269    F: std::os::unix::io::AsRawFd,
270{
271    futures::stream::unfold((), move |()| async move {
272        Some((recv_msg::<F>(sock, bufsize, flags).await, ()))
273    })
274}
275
276pub async fn send_msg<F: std::os::unix::io::AsRawFd>(
277    sock: &tokio::io::unix::AsyncFd<F>,
278    buffer: &[u8],
279    cmsg: &ControlMessage,
280    flags: MsgFlags,
281    from: Option<&NetAddr>,
282) -> std::io::Result<()> {
283    let mut ev = sock.writable().await?;
284
285    let iov = &[IoSlice::new(buffer)];
286    let mut cmsgs: Vec<nix::sys::socket::ControlMessage> = vec![];
287    let mut in_pktinfo = cmsg.pktinfo4;
288    let mut in6_pktinfo = cmsg.pktinfo6;
289
290    if let Some(addr) = cmsg.send_from {
291        match addr {
292            std::net::IpAddr::V4(ip) => {
293                in_pktinfo.ipi_spec_dst = std_to_libc_in_addr(ip);
294                cmsgs.push(nix::sys::socket::ControlMessage::Ipv4PacketInfo(
295                    &in_pktinfo,
296                ))
297            }
298            std::net::IpAddr::V6(ip) => {
299                in6_pktinfo.ipi6_addr = std_to_libc_in6_addr(ip);
300                cmsgs.push(nix::sys::socket::ControlMessage::Ipv6PacketInfo(
301                    &in6_pktinfo,
302                ))
303            }
304        }
305    } else if in6_pktinfo.ipi6_ifindex != 0 {
306        cmsgs.push(nix::sys::socket::ControlMessage::Ipv6PacketInfo(
307            &in6_pktinfo,
308        ));
309    } else if in_pktinfo.ipi_ifindex != 0 {
310        cmsgs.push(nix::sys::socket::ControlMessage::Ipv4PacketInfo(
311            &in_pktinfo,
312        ));
313    }
314
315    match nix::sys::socket::sendmsg(sock.get_ref().as_raw_fd(), iov, &cmsgs, flags, from) {
316        Ok(_) => {
317            ev.retain_ready();
318            Ok(())
319        }
320        Err(nix::errno::Errno::EAGAIN) => {
321            ev.clear_ready();
322            Err(nix::errno::Errno::EAGAIN.into())
323        }
324        Err(e) => {
325            ev.retain_ready();
326            Err(e.into())
327        }
328    }
329}
330
331pub fn set_ipv6_unicast_hoplimit(fd: RawFd, val: i32) -> Result<(), nix::Error> {
332    unsafe {
333        let res = libc::setsockopt(
334            fd,
335            libc::IPPROTO_IPV6,
336            libc::IPV6_UNICAST_HOPS,
337            &val as *const i32 as *const libc::c_void,
338            std::mem::size_of::<i32>() as libc::socklen_t,
339        );
340        nix::errno::Errno::result(res).map(drop)
341    }
342}
343
344pub fn set_ipv6_multicast_hoplimit(fd: RawFd, val: i32) -> Result<(), nix::Error> {
345    unsafe {
346        let res = libc::setsockopt(
347            fd,
348            libc::IPPROTO_IPV6,
349            libc::IPV6_MULTICAST_HOPS,
350            &val as *const i32 as *const libc::c_void,
351            std::mem::size_of::<i32>() as libc::socklen_t,
352        );
353        nix::errno::Errno::result(res).map(drop)
354    }
355}