Skip to main content

nex_socket/udp/
sync_impl.rs

1use crate::udp::UdpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::IpAddr;
5use std::net::{SocketAddr, UdpSocket as StdUdpSocket};
6
7/// Synchronous low level UDP socket.
8#[derive(Debug)]
9pub struct UdpSocket {
10    socket: Socket,
11}
12
/// Per-datagram information produced by `recv_msg`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UdpRecvMeta {
    /// How many bytes of the datagram were written into the caller's buffer.
    pub bytes_read: usize,
    /// Address of the peer that sent the datagram.
    pub source_addr: SocketAddr,
    /// Local IP address the datagram was addressed to; `Some` only when a
    /// packet-info control message accompanied the datagram.
    pub destination_addr: Option<IpAddr>,
    /// Index of the interface the datagram arrived on; `Some` only when a
    /// packet-info control message accompanied the datagram.
    pub interface_index: Option<u32>,
}
25
/// Optional per-send overrides accepted by `send_msg`.
///
/// The default value (both fields `None`) requests no overrides.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UdpSendMeta {
    /// Source IP to request via a packet-info control message, where supported.
    pub source_addr: Option<IpAddr>,
    /// Outgoing interface index to request via a packet-info control message,
    /// where supported.
    pub interface_index: Option<u32>,
}
34
35impl UdpSocket {
36    /// Create a socket from the provided configuration.
37    pub fn from_config(config: &UdpConfig) -> io::Result<Self> {
38        let socket = Socket::new(
39            config.socket_family.to_domain(),
40            config.socket_type.to_sock_type(),
41            Some(Protocol::UDP),
42        )?;
43
44        socket.set_nonblocking(false)?;
45
46        // Set socket options based on configuration
47        if let Some(flag) = config.reuseaddr {
48            socket.set_reuse_address(flag)?;
49        }
50        #[cfg(any(
51            target_os = "android",
52            target_os = "dragonfly",
53            target_os = "freebsd",
54            target_os = "fuchsia",
55            target_os = "ios",
56            target_os = "linux",
57            target_os = "macos",
58            target_os = "netbsd",
59            target_os = "openbsd",
60            target_os = "tvos",
61            target_os = "visionos",
62            target_os = "watchos"
63        ))]
64        if let Some(flag) = config.reuseport {
65            socket.set_reuse_port(flag)?;
66        }
67        if let Some(flag) = config.broadcast {
68            socket.set_broadcast(flag)?;
69        }
70        if let Some(ttl) = config.ttl {
71            socket.set_ttl(ttl)?;
72        }
73        if let Some(hoplimit) = config.hoplimit {
74            socket.set_unicast_hops_v6(hoplimit)?;
75        }
76        if let Some(timeout) = config.read_timeout {
77            socket.set_read_timeout(Some(timeout))?;
78        }
79        if let Some(timeout) = config.write_timeout {
80            socket.set_write_timeout(Some(timeout))?;
81        }
82        if let Some(size) = config.recv_buffer_size {
83            socket.set_recv_buffer_size(size)?;
84        }
85        if let Some(size) = config.send_buffer_size {
86            socket.set_send_buffer_size(size)?;
87        }
88        if let Some(tos) = config.tos {
89            socket.set_tos(tos)?;
90        }
91        #[cfg(any(
92            target_os = "android",
93            target_os = "dragonfly",
94            target_os = "freebsd",
95            target_os = "fuchsia",
96            target_os = "ios",
97            target_os = "linux",
98            target_os = "macos",
99            target_os = "netbsd",
100            target_os = "openbsd",
101            target_os = "tvos",
102            target_os = "visionos",
103            target_os = "watchos"
104        ))]
105        if let Some(tclass) = config.tclass_v6 {
106            socket.set_tclass_v6(tclass)?;
107        }
108        if let Some(only_v6) = config.only_v6 {
109            socket.set_only_v6(only_v6)?;
110        }
111        if let Some(on) = config.recv_pktinfo {
112            crate::udp::set_recv_pktinfo(&socket, config.socket_family, on)?;
113        }
114
115        // Linux: optional interface name
116        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
117        if let Some(iface) = &config.bind_device {
118            socket.bind_device(Some(iface.as_bytes()))?;
119        }
120
121        // bind to the specified address if provided
122        if let Some(addr) = config.bind_addr {
123            socket.bind(&addr.into())?;
124        }
125
126        Ok(Self { socket })
127    }
128
129    /// Create a socket of arbitrary type (DGRAM or RAW).
130    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
131        let socket = Socket::new(domain, sock_type, Some(Protocol::UDP))?;
132        socket.set_nonblocking(false)?;
133        Ok(Self { socket })
134    }
135
136    /// Convenience constructor for IPv4 DGRAM.
137    pub fn v4_dgram() -> io::Result<Self> {
138        Self::new(Domain::IPV4, SockType::DGRAM)
139    }
140
141    /// Convenience constructor for IPv6 DGRAM.
142    pub fn v6_dgram() -> io::Result<Self> {
143        Self::new(Domain::IPV6, SockType::DGRAM)
144    }
145
146    /// IPv4 RAW UDP. Requires administrator privileges.
147    pub fn raw_v4() -> io::Result<Self> {
148        Self::new(Domain::IPV4, SockType::RAW)
149    }
150
151    /// IPv6 RAW UDP. Requires administrator privileges.
152    pub fn raw_v6() -> io::Result<Self> {
153        Self::new(Domain::IPV6, SockType::RAW)
154    }
155
156    /// Send data.
157    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
158        self.socket.send_to(buf, &target.into())
159    }
160
161    /// Send data with ancillary metadata (`sendmsg` on Unix).
162    ///
163    /// When supported by the current OS, source address and interface index are
164    /// propagated using packet-info control messages.
165    #[cfg(unix)]
166    pub fn send_msg(
167        &self,
168        buf: &[u8],
169        target: SocketAddr,
170        meta: Option<&UdpSendMeta>,
171    ) -> io::Result<usize> {
172        use nix::sys::socket::{ControlMessage, MsgFlags, SockaddrIn, SockaddrIn6, sendmsg};
173        use std::io::IoSlice;
174        use std::os::fd::AsRawFd;
175
176        let iov = [IoSlice::new(buf)];
177        let raw_fd = self.socket.as_raw_fd();
178
179        match target {
180            SocketAddr::V4(addr) => {
181                let sockaddr = SockaddrIn::from(addr);
182                #[cfg(any(
183                    target_os = "android",
184                    target_os = "linux",
185                    target_os = "netbsd",
186                    target_vendor = "apple"
187                ))]
188                {
189                    if let Some(meta) = meta {
190                        if meta.source_addr.is_some() || meta.interface_index.is_some() {
191                            if let Some(src) = meta.source_addr {
192                                if !src.is_ipv4() {
193                                    return Err(io::Error::new(
194                                        io::ErrorKind::InvalidInput,
195                                        "source_addr family does not match target",
196                                    ));
197                                }
198                            }
199                            let mut pktinfo: libc::in_pktinfo = unsafe { std::mem::zeroed() };
200                            if let Some(src) = meta.source_addr.and_then(|ip| match ip {
201                                IpAddr::V4(v4) => Some(v4),
202                                IpAddr::V6(_) => None,
203                            }) {
204                                pktinfo.ipi_spec_dst.s_addr = u32::from_ne_bytes(src.octets());
205                            }
206                            if let Some(ifindex) = meta.interface_index {
207                                pktinfo.ipi_ifindex = ifindex.try_into().map_err(|_| {
208                                    io::Error::new(
209                                        io::ErrorKind::InvalidInput,
210                                        "interface_index is out of range for this platform",
211                                    )
212                                })?;
213                            }
214                            let cmsgs = [ControlMessage::Ipv4PacketInfo(&pktinfo)];
215                            return sendmsg(
216                                raw_fd,
217                                &iov,
218                                &cmsgs,
219                                MsgFlags::empty(),
220                                Some(&sockaddr),
221                            )
222                            .map_err(|e| io::Error::from_raw_os_error(e as i32));
223                        }
224                    }
225                }
226                if let Some(meta) = meta {
227                    if meta.source_addr.is_some() || meta.interface_index.is_some() {
228                        return Err(io::Error::new(
229                            io::ErrorKind::Unsupported,
230                            "send_msg packet-info metadata is not supported on this platform",
231                        ));
232                    }
233                }
234                sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr))
235                    .map_err(|e| io::Error::from_raw_os_error(e as i32))
236            }
237            SocketAddr::V6(addr) => {
238                let sockaddr = SockaddrIn6::from(addr);
239                #[cfg(any(
240                    target_os = "android",
241                    target_os = "freebsd",
242                    target_os = "linux",
243                    target_os = "netbsd",
244                    target_vendor = "apple"
245                ))]
246                {
247                    if let Some(meta) = meta {
248                        if meta.source_addr.is_some() || meta.interface_index.is_some() {
249                            if let Some(src) = meta.source_addr {
250                                if !src.is_ipv6() {
251                                    return Err(io::Error::new(
252                                        io::ErrorKind::InvalidInput,
253                                        "source_addr family does not match target",
254                                    ));
255                                }
256                            }
257                            let mut pktinfo: libc::in6_pktinfo = unsafe { std::mem::zeroed() };
258                            if let Some(src) = meta.source_addr.and_then(|ip| match ip {
259                                IpAddr::V4(_) => None,
260                                IpAddr::V6(v6) => Some(v6),
261                            }) {
262                                pktinfo.ipi6_addr.s6_addr = src.octets();
263                            }
264                            if let Some(ifindex) = meta.interface_index {
265                                pktinfo.ipi6_ifindex = ifindex.try_into().map_err(|_| {
266                                    io::Error::new(
267                                        io::ErrorKind::InvalidInput,
268                                        "interface_index is out of range for this platform",
269                                    )
270                                })?;
271                            }
272                            let cmsgs = [ControlMessage::Ipv6PacketInfo(&pktinfo)];
273                            return sendmsg(
274                                raw_fd,
275                                &iov,
276                                &cmsgs,
277                                MsgFlags::empty(),
278                                Some(&sockaddr),
279                            )
280                            .map_err(|e| io::Error::from_raw_os_error(e as i32));
281                        }
282                    }
283                }
284                if let Some(meta) = meta {
285                    if meta.source_addr.is_some() || meta.interface_index.is_some() {
286                        return Err(io::Error::new(
287                            io::ErrorKind::Unsupported,
288                            "send_msg packet-info metadata is not supported on this platform",
289                        ));
290                    }
291                }
292                sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr))
293                    .map_err(|e| io::Error::from_raw_os_error(e as i32))
294            }
295        }
296    }
297
298    /// Send data with ancillary metadata (`sendmsg` is not available on this platform build).
299    #[cfg(not(unix))]
300    pub fn send_msg(
301        &self,
302        _buf: &[u8],
303        _target: SocketAddr,
304        _meta: Option<&UdpSendMeta>,
305    ) -> io::Result<usize> {
306        Err(io::Error::new(
307            io::ErrorKind::Unsupported,
308            "send_msg is only supported on Unix",
309        ))
310    }
311
312    /// Receive data.
313    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
314        // Safety: `MaybeUninit<u8>` has the same layout as `u8`.
315        let buf_maybe = unsafe {
316            std::slice::from_raw_parts_mut(
317                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
318                buf.len(),
319            )
320        };
321
322        let (n, addr) = self.socket.recv_from(buf_maybe)?;
323        let addr = addr
324            .as_socket()
325            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
326
327        Ok((n, addr))
328    }
329
330    /// Receive data with ancillary metadata (`recvmsg` on Unix).
331    ///
332    /// This allows extracting packet-info control messages such as destination
333    /// address and incoming interface index when enabled with
334    /// `set_recv_pktinfo_v4` / `set_recv_pktinfo_v6`.
335    #[cfg(unix)]
336    pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result<UdpRecvMeta> {
337        use nix::sys::socket::{ControlMessageOwned, MsgFlags, SockaddrStorage, recvmsg};
338        use std::io::IoSliceMut;
339        use std::os::fd::AsRawFd;
340
341        let mut iov = [IoSliceMut::new(buf)];
342        #[cfg(any(
343            target_os = "android",
344            target_os = "fuchsia",
345            target_os = "linux",
346            target_vendor = "apple",
347            target_os = "netbsd"
348        ))]
349        let mut cmsgspace = nix::cmsg_space!(libc::in_pktinfo, libc::in6_pktinfo);
350        #[cfg(all(
351            not(any(
352                target_os = "android",
353                target_os = "fuchsia",
354                target_os = "linux",
355                target_vendor = "apple",
356                target_os = "netbsd"
357            )),
358            any(target_os = "freebsd", target_os = "openbsd")
359        ))]
360        let mut cmsgspace = nix::cmsg_space!(libc::in6_pktinfo);
361        #[cfg(all(
362            not(any(
363                target_os = "android",
364                target_os = "fuchsia",
365                target_os = "linux",
366                target_vendor = "apple",
367                target_os = "netbsd"
368            )),
369            not(any(target_os = "freebsd", target_os = "openbsd"))
370        ))]
371        let mut cmsgspace = nix::cmsg_space!(libc::c_int);
372        let msg = recvmsg::<SockaddrStorage>(
373            self.socket.as_raw_fd(),
374            &mut iov,
375            Some(&mut cmsgspace),
376            MsgFlags::empty(),
377        )
378        .map_err(|e| io::Error::from_raw_os_error(e as i32))?;
379
380        let source_addr = msg
381            .address
382            .and_then(|addr: SockaddrStorage| {
383                if let Some(v4) = addr.as_sockaddr_in() {
384                    return Some(SocketAddr::from(*v4));
385                }
386                if let Some(v6) = addr.as_sockaddr_in6() {
387                    return Some(SocketAddr::from(*v6));
388                }
389                None
390            })
391            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid source address"))?;
392
393        let mut destination_addr = None;
394        let mut interface_index = None;
395
396        if let Ok(cmsgs) = msg.cmsgs() {
397            for cmsg in cmsgs {
398                match cmsg {
399                    #[cfg(any(
400                        target_os = "android",
401                        target_os = "fuchsia",
402                        target_os = "linux",
403                        target_vendor = "apple",
404                        target_os = "netbsd"
405                    ))]
406                    ControlMessageOwned::Ipv4PacketInfo(info) => {
407                        destination_addr = Some(IpAddr::V4(std::net::Ipv4Addr::from(
408                            info.ipi_addr.s_addr.to_ne_bytes(),
409                        )));
410                        interface_index = Some(info.ipi_ifindex.try_into().map_err(|_| {
411                            io::Error::new(
412                                io::ErrorKind::InvalidData,
413                                "received invalid interface index",
414                            )
415                        })?);
416                    }
417                    #[cfg(any(
418                        target_os = "android",
419                        target_os = "freebsd",
420                        target_os = "linux",
421                        target_os = "macos",
422                        target_os = "ios",
423                        target_os = "tvos",
424                        target_os = "visionos",
425                        target_os = "watchos",
426                        target_os = "netbsd",
427                        target_os = "openbsd"
428                    ))]
429                    ControlMessageOwned::Ipv6PacketInfo(info) => {
430                        destination_addr =
431                            Some(IpAddr::V6(std::net::Ipv6Addr::from(info.ipi6_addr.s6_addr)));
432                        interface_index = Some(info.ipi6_ifindex.try_into().map_err(|_| {
433                            io::Error::new(
434                                io::ErrorKind::InvalidData,
435                                "received invalid interface index",
436                            )
437                        })?);
438                    }
439                    _ => {}
440                }
441            }
442        }
443
444        Ok(UdpRecvMeta {
445            bytes_read: msg.bytes,
446            source_addr,
447            destination_addr,
448            interface_index,
449        })
450    }
451
452    /// Receive data with ancillary metadata (`recvmsg` is not available on this platform build).
453    #[cfg(not(unix))]
454    pub fn recv_msg(&self, _buf: &mut [u8]) -> io::Result<UdpRecvMeta> {
455        Err(io::Error::new(
456            io::ErrorKind::Unsupported,
457            "recv_msg is only supported on Unix",
458        ))
459    }
460
461    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
462        self.socket.set_ttl(ttl)
463    }
464
465    pub fn ttl(&self) -> io::Result<u32> {
466        self.socket.ttl()
467    }
468
469    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
470        self.socket.set_unicast_hops_v6(hops)
471    }
472
473    pub fn hoplimit(&self) -> io::Result<u32> {
474        self.socket.unicast_hops_v6()
475    }
476
477    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
478        self.socket.set_reuse_address(on)
479    }
480
481    pub fn reuseaddr(&self) -> io::Result<bool> {
482        self.socket.reuse_address()
483    }
484
485    #[cfg(any(
486        target_os = "android",
487        target_os = "dragonfly",
488        target_os = "freebsd",
489        target_os = "fuchsia",
490        target_os = "ios",
491        target_os = "linux",
492        target_os = "macos",
493        target_os = "netbsd",
494        target_os = "openbsd",
495        target_os = "tvos",
496        target_os = "visionos",
497        target_os = "watchos"
498    ))]
499    pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
500        self.socket.set_reuse_port(on)
501    }
502
503    #[cfg(any(
504        target_os = "android",
505        target_os = "dragonfly",
506        target_os = "freebsd",
507        target_os = "fuchsia",
508        target_os = "ios",
509        target_os = "linux",
510        target_os = "macos",
511        target_os = "netbsd",
512        target_os = "openbsd",
513        target_os = "tvos",
514        target_os = "visionos",
515        target_os = "watchos"
516    ))]
517    pub fn reuseport(&self) -> io::Result<bool> {
518        self.socket.reuse_port()
519    }
520
521    pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
522        self.socket.set_broadcast(on)
523    }
524
525    pub fn broadcast(&self) -> io::Result<bool> {
526        self.socket.broadcast()
527    }
528
529    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
530        self.socket.set_recv_buffer_size(size)
531    }
532
533    pub fn recv_buffer_size(&self) -> io::Result<usize> {
534        self.socket.recv_buffer_size()
535    }
536
537    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
538        self.socket.set_send_buffer_size(size)
539    }
540
541    pub fn send_buffer_size(&self) -> io::Result<usize> {
542        self.socket.send_buffer_size()
543    }
544
545    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
546        self.socket.set_tos(tos)
547    }
548
549    pub fn tos(&self) -> io::Result<u32> {
550        self.socket.tos()
551    }
552
553    #[cfg(any(
554        target_os = "android",
555        target_os = "dragonfly",
556        target_os = "freebsd",
557        target_os = "fuchsia",
558        target_os = "ios",
559        target_os = "linux",
560        target_os = "macos",
561        target_os = "netbsd",
562        target_os = "openbsd",
563        target_os = "tvos",
564        target_os = "visionos",
565        target_os = "watchos"
566    ))]
567    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
568        self.socket.set_tclass_v6(tclass)
569    }
570
571    #[cfg(any(
572        target_os = "android",
573        target_os = "dragonfly",
574        target_os = "freebsd",
575        target_os = "fuchsia",
576        target_os = "ios",
577        target_os = "linux",
578        target_os = "macos",
579        target_os = "netbsd",
580        target_os = "openbsd",
581        target_os = "tvos",
582        target_os = "visionos",
583        target_os = "watchos"
584    ))]
585    pub fn tclass_v6(&self) -> io::Result<u32> {
586        self.socket.tclass_v6()
587    }
588
589    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
590        self.socket.set_only_v6(only_v6)
591    }
592
593    pub fn only_v6(&self) -> io::Result<bool> {
594        self.socket.only_v6()
595    }
596
597    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
598        self.socket.set_keepalive(on)
599    }
600
601    pub fn keepalive(&self) -> io::Result<bool> {
602        self.socket.keepalive()
603    }
604
605    /// Enable IPv4 packet-info ancillary data receiving (`IP_PKTINFO`) where supported.
606    pub fn set_recv_pktinfo_v4(&self, on: bool) -> io::Result<()> {
607        crate::udp::set_recv_pktinfo_v4(&self.socket, on)
608    }
609
610    /// Enable IPv6 packet-info ancillary data receiving (`IPV6_RECVPKTINFO`) where supported.
611    pub fn set_recv_pktinfo_v6(&self, on: bool) -> io::Result<()> {
612        crate::udp::set_recv_pktinfo_v6(&self.socket, on)
613    }
614
615    /// Query whether IPv4 packet-info ancillary data receiving is enabled.
616    pub fn recv_pktinfo_v4(&self) -> io::Result<bool> {
617        crate::udp::recv_pktinfo_v4(&self.socket)
618    }
619
620    /// Query whether IPv6 packet-info ancillary data receiving is enabled.
621    pub fn recv_pktinfo_v6(&self) -> io::Result<bool> {
622        crate::udp::recv_pktinfo_v6(&self.socket)
623    }
624
625    /// Retrieve the local socket address.
626    pub fn local_addr(&self) -> io::Result<SocketAddr> {
627        self.socket
628            .local_addr()?
629            .as_socket()
630            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
631    }
632
633    /// Convert into a raw `std::net::UdpSocket`.
634    pub fn to_std(self) -> io::Result<StdUdpSocket> {
635        Ok(self.socket.into())
636    }
637
638    /// Construct from a raw `socket2::Socket`.
639    pub fn from_socket(socket: Socket) -> Self {
640        Self { socket }
641    }
642
643    /// Borrow the inner `socket2::Socket`.
644    pub fn socket(&self) -> &Socket {
645        &self.socket
646    }
647
648    /// Consume and return the inner `socket2::Socket`.
649    pub fn into_socket(self) -> Socket {
650        self.socket
651    }
652
653    #[cfg(unix)]
654    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
655        use std::os::fd::AsRawFd;
656        self.socket.as_raw_fd()
657    }
658
659    #[cfg(windows)]
660    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
661        use std::os::windows::io::AsRawSocket;
662        self.socket.as_raw_socket()
663    }
664}
665
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::time::Duration;

    /// Create an IPv4 socket bound to an ephemeral loopback port, with a read
    /// timeout so a lost datagram fails the test instead of hanging it.
    fn bound_loopback_v4() -> UdpSocket {
        let sock = UdpSocket::v4_dgram().expect("create socket");
        let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
        sock.socket().bind(&addr.into()).expect("bind");
        sock.socket()
            .set_read_timeout(Some(Duration::from_secs(5)))
            .expect("set read timeout");
        sock
    }

    #[test]
    fn create_v4_socket() {
        // Bind before querying: `getsockname` on an unbound socket fails on
        // Windows (WSAEINVAL), so querying an unbound socket is not portable.
        let sock = bound_loopback_v4();
        let addr = sock.local_addr().expect("addr");
        assert!(addr.is_ipv4());
        assert_ne!(addr.port(), 0);
    }

    #[test]
    fn send_to_recv_from_loopback_v4() {
        let rx = bound_loopback_v4();
        let tx = bound_loopback_v4();
        let dst = rx.local_addr().expect("rx addr");

        assert_eq!(tx.send_to(b"ping", dst).expect("send_to"), 4);

        let mut buf = [0u8; 16];
        let (n, from) = rx.recv_from(&mut buf).expect("recv_from");
        assert_eq!(&buf[..n], b"ping");
        assert_eq!(from, tx.local_addr().expect("tx addr"));
    }

    #[cfg(unix)]
    #[test]
    fn send_msg_recv_msg_loopback_v4() {
        let rx = bound_loopback_v4();
        let tx = bound_loopback_v4();
        let dst = rx.local_addr().expect("rx addr");

        assert_eq!(tx.send_msg(b"hello", dst, None).expect("send_msg"), 5);
        // Metadata with no fields set must behave exactly like `None`.
        let empty = UdpSendMeta::default();
        assert_eq!(tx.send_msg(b"x", dst, Some(&empty)).expect("send_msg"), 1);

        let mut buf = [0u8; 16];
        let meta = rx.recv_msg(&mut buf).expect("recv_msg");
        assert_eq!(meta.bytes_read, 5);
        assert_eq!(&buf[..5], b"hello");
        assert_eq!(meta.source_addr, tx.local_addr().expect("tx addr"));

        let meta = rx.recv_msg(&mut buf).expect("recv_msg");
        assert_eq!(meta.bytes_read, 1);
        assert_eq!(&buf[..1], b"x");
    }
}