//! `nex_socket/udp/sync_impl.rs`: synchronous UDP socket implementation.

1use crate::udp::UdpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::IpAddr;
5use std::net::{SocketAddr, UdpSocket as StdUdpSocket};
6
7/// Synchronous low level UDP socket.
8#[derive(Debug)]
9pub struct UdpSocket {
10    socket: Socket,
11}
12
/// Result of a `recv_msg` call: payload length plus addressing details.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UdpRecvMeta {
    /// How many bytes were written into the caller's data buffer.
    pub bytes_read: usize,
    /// Address the datagram was sent from.
    pub source_addr: SocketAddr,
    /// Local address the datagram was addressed to, when packet-info
    /// ancillary data was delivered with it.
    pub destination_addr: Option<IpAddr>,
    /// Index of the receiving interface, when packet-info ancillary data
    /// was delivered with the datagram.
    pub interface_index: Option<u32>,
}
25
/// Optional per-datagram overrides accepted by `send_msg`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct UdpSendMeta {
    /// Source IP address to transmit from, on platforms that support it.
    pub source_addr: Option<IpAddr>,
    /// Index of the interface to transmit on, on platforms that support it.
    pub interface_index: Option<u32>,
}
34
35impl UdpSocket {
36    /// Create a socket from the provided configuration.
37    pub fn from_config(config: &UdpConfig) -> io::Result<Self> {
38        config.validate()?;
39
40        let socket = Socket::new(
41            config.socket_family.to_domain(),
42            config.socket_type.to_sock_type(),
43            Some(Protocol::UDP),
44        )?;
45
46        socket.set_nonblocking(false)?;
47
48        // Set socket options based on configuration
49        if let Some(flag) = config.reuseaddr {
50            socket.set_reuse_address(flag)?;
51        }
52        #[cfg(any(
53            target_os = "android",
54            target_os = "dragonfly",
55            target_os = "freebsd",
56            target_os = "fuchsia",
57            target_os = "ios",
58            target_os = "linux",
59            target_os = "macos",
60            target_os = "netbsd",
61            target_os = "openbsd",
62            target_os = "tvos",
63            target_os = "visionos",
64            target_os = "watchos"
65        ))]
66        if let Some(flag) = config.reuseport {
67            socket.set_reuse_port(flag)?;
68        }
69        if let Some(flag) = config.broadcast {
70            socket.set_broadcast(flag)?;
71        }
72        if let Some(ttl) = config.ttl {
73            socket.set_ttl(ttl)?;
74        }
75        if let Some(hoplimit) = config.hoplimit {
76            socket.set_unicast_hops_v6(hoplimit)?;
77        }
78        if let Some(timeout) = config.read_timeout {
79            socket.set_read_timeout(Some(timeout))?;
80        }
81        if let Some(timeout) = config.write_timeout {
82            socket.set_write_timeout(Some(timeout))?;
83        }
84        if let Some(size) = config.recv_buffer_size {
85            socket.set_recv_buffer_size(size)?;
86        }
87        if let Some(size) = config.send_buffer_size {
88            socket.set_send_buffer_size(size)?;
89        }
90        if let Some(tos) = config.tos {
91            socket.set_tos(tos)?;
92        }
93        #[cfg(any(
94            target_os = "android",
95            target_os = "dragonfly",
96            target_os = "freebsd",
97            target_os = "fuchsia",
98            target_os = "ios",
99            target_os = "linux",
100            target_os = "macos",
101            target_os = "netbsd",
102            target_os = "openbsd",
103            target_os = "tvos",
104            target_os = "visionos",
105            target_os = "watchos"
106        ))]
107        if let Some(tclass) = config.tclass_v6 {
108            socket.set_tclass_v6(tclass)?;
109        }
110        if let Some(only_v6) = config.only_v6 {
111            socket.set_only_v6(only_v6)?;
112        }
113        if let Some(on) = config.recv_pktinfo {
114            crate::udp::set_recv_pktinfo(&socket, config.socket_family, on)?;
115        }
116
117        // Linux: optional interface name
118        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
119        if let Some(iface) = &config.bind_device {
120            socket.bind_device(Some(iface.as_bytes()))?;
121        }
122
123        // bind to the specified address if provided
124        if let Some(addr) = config.bind_addr {
125            socket.bind(&addr.into())?;
126        }
127
128        Ok(Self { socket })
129    }
130
131    /// Create a socket of arbitrary type (DGRAM or RAW).
132    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
133        let socket = Socket::new(domain, sock_type, Some(Protocol::UDP))?;
134        socket.set_nonblocking(false)?;
135        Ok(Self { socket })
136    }
137
138    /// Convenience constructor for IPv4 DGRAM.
139    pub fn v4_dgram() -> io::Result<Self> {
140        Self::new(Domain::IPV4, SockType::DGRAM)
141    }
142
143    /// Convenience constructor for IPv6 DGRAM.
144    pub fn v6_dgram() -> io::Result<Self> {
145        Self::new(Domain::IPV6, SockType::DGRAM)
146    }
147
148    /// IPv4 RAW UDP. Requires administrator privileges.
149    pub fn raw_v4() -> io::Result<Self> {
150        Self::new(Domain::IPV4, SockType::RAW)
151    }
152
153    /// IPv6 RAW UDP. Requires administrator privileges.
154    pub fn raw_v6() -> io::Result<Self> {
155        Self::new(Domain::IPV6, SockType::RAW)
156    }
157
158    /// Send data.
159    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
160        self.socket.send_to(buf, &target.into())
161    }
162
163    /// Send data with ancillary metadata (`sendmsg` on Unix).
164    ///
165    /// When supported by the current OS, source address and interface index are
166    /// propagated using packet-info control messages.
167    #[cfg(unix)]
168    pub fn send_msg(
169        &self,
170        buf: &[u8],
171        target: SocketAddr,
172        meta: Option<&UdpSendMeta>,
173    ) -> io::Result<usize> {
174        use nix::sys::socket::{ControlMessage, MsgFlags, SockaddrIn, SockaddrIn6, sendmsg};
175        use std::io::IoSlice;
176        use std::os::fd::AsRawFd;
177
178        let iov = [IoSlice::new(buf)];
179        let raw_fd = self.socket.as_raw_fd();
180
181        match target {
182            SocketAddr::V4(addr) => {
183                let sockaddr = SockaddrIn::from(addr);
184                #[cfg(any(
185                    target_os = "android",
186                    target_os = "linux",
187                    target_os = "netbsd",
188                    target_vendor = "apple"
189                ))]
190                {
191                    if let Some(meta) = meta {
192                        if meta.source_addr.is_some() || meta.interface_index.is_some() {
193                            if let Some(src) = meta.source_addr {
194                                if !src.is_ipv4() {
195                                    return Err(io::Error::new(
196                                        io::ErrorKind::InvalidInput,
197                                        "source_addr family does not match target",
198                                    ));
199                                }
200                            }
201                            let mut pktinfo: libc::in_pktinfo = unsafe { std::mem::zeroed() };
202                            if let Some(src) = meta.source_addr.and_then(|ip| match ip {
203                                IpAddr::V4(v4) => Some(v4),
204                                IpAddr::V6(_) => None,
205                            }) {
206                                pktinfo.ipi_spec_dst.s_addr = u32::from_ne_bytes(src.octets());
207                            }
208                            if let Some(ifindex) = meta.interface_index {
209                                pktinfo.ipi_ifindex = ifindex.try_into().map_err(|_| {
210                                    io::Error::new(
211                                        io::ErrorKind::InvalidInput,
212                                        "interface_index is out of range for this platform",
213                                    )
214                                })?;
215                            }
216                            let cmsgs = [ControlMessage::Ipv4PacketInfo(&pktinfo)];
217                            return sendmsg(
218                                raw_fd,
219                                &iov,
220                                &cmsgs,
221                                MsgFlags::empty(),
222                                Some(&sockaddr),
223                            )
224                            .map_err(|e| io::Error::from_raw_os_error(e as i32));
225                        }
226                    }
227                }
228                if let Some(meta) = meta {
229                    if meta.source_addr.is_some() || meta.interface_index.is_some() {
230                        return Err(io::Error::new(
231                            io::ErrorKind::Unsupported,
232                            "send_msg packet-info metadata is not supported on this platform",
233                        ));
234                    }
235                }
236                sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr))
237                    .map_err(|e| io::Error::from_raw_os_error(e as i32))
238            }
239            SocketAddr::V6(addr) => {
240                let sockaddr = SockaddrIn6::from(addr);
241                #[cfg(any(
242                    target_os = "android",
243                    target_os = "freebsd",
244                    target_os = "linux",
245                    target_os = "netbsd",
246                    target_vendor = "apple"
247                ))]
248                {
249                    if let Some(meta) = meta {
250                        if meta.source_addr.is_some() || meta.interface_index.is_some() {
251                            if let Some(src) = meta.source_addr {
252                                if !src.is_ipv6() {
253                                    return Err(io::Error::new(
254                                        io::ErrorKind::InvalidInput,
255                                        "source_addr family does not match target",
256                                    ));
257                                }
258                            }
259                            let mut pktinfo: libc::in6_pktinfo = unsafe { std::mem::zeroed() };
260                            if let Some(src) = meta.source_addr.and_then(|ip| match ip {
261                                IpAddr::V4(_) => None,
262                                IpAddr::V6(v6) => Some(v6),
263                            }) {
264                                pktinfo.ipi6_addr.s6_addr = src.octets();
265                            }
266                            if let Some(ifindex) = meta.interface_index {
267                                pktinfo.ipi6_ifindex = ifindex.try_into().map_err(|_| {
268                                    io::Error::new(
269                                        io::ErrorKind::InvalidInput,
270                                        "interface_index is out of range for this platform",
271                                    )
272                                })?;
273                            }
274                            let cmsgs = [ControlMessage::Ipv6PacketInfo(&pktinfo)];
275                            return sendmsg(
276                                raw_fd,
277                                &iov,
278                                &cmsgs,
279                                MsgFlags::empty(),
280                                Some(&sockaddr),
281                            )
282                            .map_err(|e| io::Error::from_raw_os_error(e as i32));
283                        }
284                    }
285                }
286                if let Some(meta) = meta {
287                    if meta.source_addr.is_some() || meta.interface_index.is_some() {
288                        return Err(io::Error::new(
289                            io::ErrorKind::Unsupported,
290                            "send_msg packet-info metadata is not supported on this platform",
291                        ));
292                    }
293                }
294                sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr))
295                    .map_err(|e| io::Error::from_raw_os_error(e as i32))
296            }
297        }
298    }
299
300    /// Send data with ancillary metadata (`sendmsg` is not available on this platform build).
301    #[cfg(not(unix))]
302    pub fn send_msg(
303        &self,
304        _buf: &[u8],
305        _target: SocketAddr,
306        _meta: Option<&UdpSendMeta>,
307    ) -> io::Result<usize> {
308        Err(io::Error::new(
309            io::ErrorKind::Unsupported,
310            "send_msg is only supported on Unix",
311        ))
312    }
313
314    /// Receive data.
315    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
316        // Safety: `MaybeUninit<u8>` has the same layout as `u8`.
317        let buf_maybe = unsafe {
318            std::slice::from_raw_parts_mut(
319                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
320                buf.len(),
321            )
322        };
323
324        let (n, addr) = self.socket.recv_from(buf_maybe)?;
325        let addr = addr
326            .as_socket()
327            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
328
329        Ok((n, addr))
330    }
331
332    /// Receive data with ancillary metadata (`recvmsg` on Unix).
333    ///
334    /// This allows extracting packet-info control messages such as destination
335    /// address and incoming interface index when enabled with
336    /// `set_recv_pktinfo_v4` / `set_recv_pktinfo_v6`.
337    #[cfg(unix)]
338    pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result<UdpRecvMeta> {
339        use nix::sys::socket::{ControlMessageOwned, MsgFlags, SockaddrStorage, recvmsg};
340        use std::io::IoSliceMut;
341        use std::os::fd::AsRawFd;
342
343        let mut iov = [IoSliceMut::new(buf)];
344        #[cfg(any(
345            target_os = "android",
346            target_os = "fuchsia",
347            target_os = "linux",
348            target_vendor = "apple",
349            target_os = "netbsd"
350        ))]
351        let mut cmsgspace = nix::cmsg_space!(libc::in_pktinfo, libc::in6_pktinfo);
352        #[cfg(all(
353            not(any(
354                target_os = "android",
355                target_os = "fuchsia",
356                target_os = "linux",
357                target_vendor = "apple",
358                target_os = "netbsd"
359            )),
360            any(target_os = "freebsd", target_os = "openbsd")
361        ))]
362        let mut cmsgspace = nix::cmsg_space!(libc::in6_pktinfo);
363        #[cfg(all(
364            not(any(
365                target_os = "android",
366                target_os = "fuchsia",
367                target_os = "linux",
368                target_vendor = "apple",
369                target_os = "netbsd"
370            )),
371            not(any(target_os = "freebsd", target_os = "openbsd"))
372        ))]
373        let mut cmsgspace = nix::cmsg_space!(libc::c_int);
374        let msg = recvmsg::<SockaddrStorage>(
375            self.socket.as_raw_fd(),
376            &mut iov,
377            Some(&mut cmsgspace),
378            MsgFlags::empty(),
379        )
380        .map_err(|e| io::Error::from_raw_os_error(e as i32))?;
381
382        let source_addr = msg
383            .address
384            .and_then(|addr: SockaddrStorage| {
385                if let Some(v4) = addr.as_sockaddr_in() {
386                    return Some(SocketAddr::from(*v4));
387                }
388                if let Some(v6) = addr.as_sockaddr_in6() {
389                    return Some(SocketAddr::from(*v6));
390                }
391                None
392            })
393            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid source address"))?;
394
395        let mut destination_addr = None;
396        let mut interface_index = None;
397
398        if let Ok(cmsgs) = msg.cmsgs() {
399            for cmsg in cmsgs {
400                match cmsg {
401                    #[cfg(any(
402                        target_os = "android",
403                        target_os = "fuchsia",
404                        target_os = "linux",
405                        target_vendor = "apple",
406                        target_os = "netbsd"
407                    ))]
408                    ControlMessageOwned::Ipv4PacketInfo(info) => {
409                        destination_addr = Some(IpAddr::V4(std::net::Ipv4Addr::from(
410                            info.ipi_addr.s_addr.to_ne_bytes(),
411                        )));
412                        interface_index = Some(info.ipi_ifindex.try_into().map_err(|_| {
413                            io::Error::new(
414                                io::ErrorKind::InvalidData,
415                                "received invalid interface index",
416                            )
417                        })?);
418                    }
419                    #[cfg(any(
420                        target_os = "android",
421                        target_os = "freebsd",
422                        target_os = "linux",
423                        target_os = "macos",
424                        target_os = "ios",
425                        target_os = "tvos",
426                        target_os = "visionos",
427                        target_os = "watchos",
428                        target_os = "netbsd",
429                        target_os = "openbsd"
430                    ))]
431                    ControlMessageOwned::Ipv6PacketInfo(info) => {
432                        destination_addr =
433                            Some(IpAddr::V6(std::net::Ipv6Addr::from(info.ipi6_addr.s6_addr)));
434                        interface_index = Some(info.ipi6_ifindex.try_into().map_err(|_| {
435                            io::Error::new(
436                                io::ErrorKind::InvalidData,
437                                "received invalid interface index",
438                            )
439                        })?);
440                    }
441                    _ => {}
442                }
443            }
444        }
445
446        Ok(UdpRecvMeta {
447            bytes_read: msg.bytes,
448            source_addr,
449            destination_addr,
450            interface_index,
451        })
452    }
453
454    /// Receive data with ancillary metadata (`recvmsg` is not available on this platform build).
455    #[cfg(not(unix))]
456    pub fn recv_msg(&self, _buf: &mut [u8]) -> io::Result<UdpRecvMeta> {
457        Err(io::Error::new(
458            io::ErrorKind::Unsupported,
459            "recv_msg is only supported on Unix",
460        ))
461    }
462
463    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
464        self.socket.set_ttl(ttl)
465    }
466
467    pub fn ttl(&self) -> io::Result<u32> {
468        self.socket.ttl()
469    }
470
471    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
472        self.socket.set_unicast_hops_v6(hops)
473    }
474
475    pub fn hoplimit(&self) -> io::Result<u32> {
476        self.socket.unicast_hops_v6()
477    }
478
479    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
480        self.socket.set_reuse_address(on)
481    }
482
483    pub fn reuseaddr(&self) -> io::Result<bool> {
484        self.socket.reuse_address()
485    }
486
487    #[cfg(any(
488        target_os = "android",
489        target_os = "dragonfly",
490        target_os = "freebsd",
491        target_os = "fuchsia",
492        target_os = "ios",
493        target_os = "linux",
494        target_os = "macos",
495        target_os = "netbsd",
496        target_os = "openbsd",
497        target_os = "tvos",
498        target_os = "visionos",
499        target_os = "watchos"
500    ))]
501    pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
502        self.socket.set_reuse_port(on)
503    }
504
505    #[cfg(any(
506        target_os = "android",
507        target_os = "dragonfly",
508        target_os = "freebsd",
509        target_os = "fuchsia",
510        target_os = "ios",
511        target_os = "linux",
512        target_os = "macos",
513        target_os = "netbsd",
514        target_os = "openbsd",
515        target_os = "tvos",
516        target_os = "visionos",
517        target_os = "watchos"
518    ))]
519    pub fn reuseport(&self) -> io::Result<bool> {
520        self.socket.reuse_port()
521    }
522
523    pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
524        self.socket.set_broadcast(on)
525    }
526
527    pub fn broadcast(&self) -> io::Result<bool> {
528        self.socket.broadcast()
529    }
530
531    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
532        self.socket.set_recv_buffer_size(size)
533    }
534
535    pub fn recv_buffer_size(&self) -> io::Result<usize> {
536        self.socket.recv_buffer_size()
537    }
538
539    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
540        self.socket.set_send_buffer_size(size)
541    }
542
543    pub fn send_buffer_size(&self) -> io::Result<usize> {
544        self.socket.send_buffer_size()
545    }
546
547    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
548        self.socket.set_tos(tos)
549    }
550
551    pub fn tos(&self) -> io::Result<u32> {
552        self.socket.tos()
553    }
554
555    #[cfg(any(
556        target_os = "android",
557        target_os = "dragonfly",
558        target_os = "freebsd",
559        target_os = "fuchsia",
560        target_os = "ios",
561        target_os = "linux",
562        target_os = "macos",
563        target_os = "netbsd",
564        target_os = "openbsd",
565        target_os = "tvos",
566        target_os = "visionos",
567        target_os = "watchos"
568    ))]
569    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
570        self.socket.set_tclass_v6(tclass)
571    }
572
573    #[cfg(any(
574        target_os = "android",
575        target_os = "dragonfly",
576        target_os = "freebsd",
577        target_os = "fuchsia",
578        target_os = "ios",
579        target_os = "linux",
580        target_os = "macos",
581        target_os = "netbsd",
582        target_os = "openbsd",
583        target_os = "tvos",
584        target_os = "visionos",
585        target_os = "watchos"
586    ))]
587    pub fn tclass_v6(&self) -> io::Result<u32> {
588        self.socket.tclass_v6()
589    }
590
591    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
592        self.socket.set_only_v6(only_v6)
593    }
594
595    pub fn only_v6(&self) -> io::Result<bool> {
596        self.socket.only_v6()
597    }
598
599    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
600        self.socket.set_keepalive(on)
601    }
602
603    pub fn keepalive(&self) -> io::Result<bool> {
604        self.socket.keepalive()
605    }
606
607    /// Enable IPv4 packet-info ancillary data receiving (`IP_PKTINFO`) where supported.
608    pub fn set_recv_pktinfo_v4(&self, on: bool) -> io::Result<()> {
609        crate::udp::set_recv_pktinfo_v4(&self.socket, on)
610    }
611
612    /// Enable IPv6 packet-info ancillary data receiving (`IPV6_RECVPKTINFO`) where supported.
613    pub fn set_recv_pktinfo_v6(&self, on: bool) -> io::Result<()> {
614        crate::udp::set_recv_pktinfo_v6(&self.socket, on)
615    }
616
617    /// Query whether IPv4 packet-info ancillary data receiving is enabled.
618    pub fn recv_pktinfo_v4(&self) -> io::Result<bool> {
619        crate::udp::recv_pktinfo_v4(&self.socket)
620    }
621
622    /// Query whether IPv6 packet-info ancillary data receiving is enabled.
623    pub fn recv_pktinfo_v6(&self) -> io::Result<bool> {
624        crate::udp::recv_pktinfo_v6(&self.socket)
625    }
626
627    /// Retrieve the local socket address.
628    pub fn local_addr(&self) -> io::Result<SocketAddr> {
629        self.socket
630            .local_addr()?
631            .as_socket()
632            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
633    }
634
635    /// Convert into a raw `std::net::UdpSocket`.
636    pub fn to_std(self) -> io::Result<StdUdpSocket> {
637        Ok(self.socket.into())
638    }
639
640    /// Construct from a raw `socket2::Socket`.
641    pub fn from_socket(socket: Socket) -> Self {
642        Self { socket }
643    }
644
645    /// Borrow the inner `socket2::Socket`.
646    pub fn socket(&self) -> &Socket {
647        &self.socket
648    }
649
650    /// Consume and return the inner `socket2::Socket`.
651    pub fn into_socket(self) -> Socket {
652        self.socket
653    }
654
655    #[cfg(unix)]
656    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
657        use std::os::fd::AsRawFd;
658        self.socket.as_raw_fd()
659    }
660
661    #[cfg(windows)]
662    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
663        use std::os::windows::io::AsRawSocket;
664        self.socket.as_raw_socket()
665    }
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// An IPv4 datagram socket bound to the wildcard address with port 0
    /// reports an IPv4 local address.
    #[test]
    fn create_v4_socket() {
        let sock = UdpSocket::v4_dgram().expect("create socket");
        let wildcard: SocketAddr = "0.0.0.0:0".parse().unwrap();
        sock.socket.bind(&wildcard.into()).expect("bind");
        let addr = sock.local_addr().expect("addr");
        assert!(addr.is_ipv4());
    }
}