multicast_discovery_socket/
socket.rs

1use std::io;
2use std::io::{IoSlice, IoSliceMut, Result};
3use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
4use socket2::{Domain, Protocol, SockAddr, Socket, Type};
5#[cfg(depend_nix)]
6use nix::sys::socket;
7
/// IPv4 UDP socket that can send and receive datagrams on specific network interfaces.
///
/// Depending on the platform backend, received datagrams are reported together with the
/// index of the interface they arrived on (`recv_from_iface`), and outgoing datagrams can be
/// directed at a chosen interface (`send_to_iface`).
pub struct MultiInterfaceSocket {
    // Underlying OS socket, created as IPv4 / datagram / UDP in `bind`.
    socket: Socket,
    // Windows-only helper state produced by `win_helper::win_init` and used by the
    // Windows send/receive paths. NOTE(review): contents are defined in win_specific.rs,
    // which is not visible here.
    #[cfg(windows)]
    wsa_structs: win_helper::WSAStructs
}
/// Converts a `nix` error (an errno value) into an [`io::Error`].
///
/// The raw OS error code is preserved, so `io::Error::kind` still reflects the real cause.
/// For example, `EAGAIN` returned by `recvmsg` on a non-blocking or timed-out socket maps to
/// `ErrorKind::WouldBlock`, exactly like the errors produced by the plain `socket2` calls.
/// Wrapping the errno with `io::Error::other` would collapse every failure into
/// `ErrorKind::Other` and break callers that retry on `WouldBlock` / `TimedOut`.
#[cfg(depend_nix)]
fn nix_to_io_error(e: nix::Error) -> io::Error {
    io::Error::from_raw_os_error(e as i32)
}
18
// Windows-only helper module (provides `WSAStructs` and `win_init`, used for the
// interface-aware send/receive paths). Loaded from `win_specific.rs` in the same
// directory as this file.
#[cfg(windows)]
#[path = "./win_specific.rs"]
mod win_helper;
22
23impl MultiInterfaceSocket {
24    pub fn bind_any() -> Result<Self> {
25        Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
26    }
27    
28    pub fn bind_port(port: u16) -> Result<Self> {
29        Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port))
30    }
31    pub fn bind(addr: SocketAddrV4) -> Result<Self> {
32        let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
33        socket.bind(&addr.into())?;
34
35        #[cfg(depend_nix)]
36        use std::os::fd::AsFd;
37        #[cfg(depend_nix)]
38        socket::setsockopt(&socket.as_fd(), socket::sockopt::Ipv4PacketInfo, &true)
39            .map_err(nix_to_io_error)?;
40
41        #[cfg(windows)]
42        let wsa_structs = win_helper::win_init(&socket)?;
43
44        Ok(Self {
45            socket,
46            #[cfg(windows)]
47            wsa_structs
48        })
49    }
50    
51    pub fn get_bind_addr(&self) -> Result<SocketAddrV4> {
52        let addr = self.socket.local_addr()?;
53        if let Some(addr) = addr.as_socket_ipv4() {
54            Ok(addr)
55        } else {
56            Err(io::Error::other("Not an IPv4 address"))
57        }
58    }
59    
60    /// Join a multicast group on provided interface. 
61    pub fn join_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
62        self.socket.join_multicast_v4(&addr, &interface)
63    }
64
65    /// Leave a multicast group on provided interface. 
66    pub fn leave_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
67        self.socket.leave_multicast_v4(&addr, &interface)
68    }
69    
70    /// Nonblocking mode will cause read operations to return immediately with an error if no data is available.
71    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
72        self.socket.set_nonblocking(nonblocking)
73    }
74    /// Assign a timeout to read operations on the socket.
75    pub fn set_read_timeout(&self, timeout: std::time::Duration) -> Result<()> {
76        self.socket.set_read_timeout(Some(timeout))
77    }
78    /// When socket is non-blocking, this option will cause the read operation to block indefinitely until data is available.
79    pub fn set_read_timeout_inf(&self) -> Result<()> {
80        self.socket.set_read_timeout(None)
81    }
82
83}
84
#[cfg(depend_nix)]
impl MultiInterfaceSocket {
    /// Receives one datagram, like `recvfrom`, and also reports the index of the interface
    /// it arrived on.
    ///
    /// Returns the filled prefix of `buf`, the sender's address, and the interface index
    /// taken from the packet-info control message enabled in `bind`. The index is `0` if
    /// no such control message was delivered. The sender address is `0.0.0.0:0` if the
    /// kernel did not report one.
    pub fn recv_from_iface<'a>(
        &self,
        buf: &'a mut [u8],
    ) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
        use std::os::fd::AsRawFd;

        let mut cmsg_space = nix::cmsg_space!(nix::libc::in_pktinfo);
        let mut iov = [IoSliceMut::new(buf)];
        let message: socket::RecvMsg<socket::SockaddrIn> = socket::recvmsg(
            self.socket.as_raw_fd(),
            &mut iov,
            Some(&mut cmsg_space),
            socket::MsgFlags::empty(),
        )
        .map_err(nix_to_io_error)?;

        // `message.address` is the peer that *sent* the datagram.
        let src_addr = message
            .address
            .map(|a| SocketAddrV4::new(a.ip(), a.port()))
            .unwrap_or(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
        let len = message.bytes;

        let iface_index = message
            .cmsgs()?
            .find_map(|cmsg| match cmsg {
                socket::ControlMessageOwned::Ipv4PacketInfo(pkt_info) => {
                    Some(pkt_info.ipi_ifindex as u32)
                }
                _ => None,
            })
            .unwrap_or(0);

        Ok((&mut buf[..len], src_addr, iface_index))
    }

    /// Sends `buf` to `addr`, passing `iface_index` in a packet-info control message so the
    /// kernel sends it via that interface.
    ///
    /// `_source_if_addr` is accepted for signature parity with the other backends but is
    /// not used here.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidInput` if `iface_index` does not fit the platform's
    /// `ipi_ifindex` field. Otherwise returns any error from `sendmsg`.
    pub fn send_to_iface(
        &self,
        buf: &[u8],
        addr: SocketAddrV4,
        iface_index: u32,
        _source_if_addr: IpAddr,
    ) -> Result<usize> {
        use std::os::fd::AsRawFd;

        let iov = [IoSlice::new(buf)];
        // SAFETY: `in_pktinfo` is a plain C struct of integers and `in_addr`s, for which the
        // all-zero bit pattern is a valid value. The address fields are left zero.
        let mut pkt_info: nix::libc::in_pktinfo = unsafe { std::mem::zeroed() };
        // Checked conversion: `ipi_ifindex` is a signed int on some platforms, where a plain
        // `as` cast would silently turn large indices negative.
        pkt_info.ipi_ifindex = std::convert::TryInto::try_into(iface_index).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, "interface index out of range")
        })?;

        socket::sendmsg(
            self.socket.as_raw_fd(),
            &iov,
            &[socket::ControlMessage::Ipv4PacketInfo(&pkt_info)],
            socket::MsgFlags::empty(),
            Some(&socket::SockaddrIn::from(addr)),
        )
        .map_err(nix_to_io_error)
    }
}
#[cfg(windows)]
impl MultiInterfaceSocket {
    /// Receives one datagram (like `recvfrom`) through the Windows helper, which also
    /// reports the index of the interface it arrived on.
    ///
    /// Returns the filled prefix of `buf`, the address reported by the helper, and the
    /// interface index.
    pub fn recv_from_iface<'a>(
        &self,
        buf: &'a mut [u8],
    ) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
        let (len, peer, iface_index) = self.wsa_structs.receive(buf, &self.socket)?;
        let payload = &mut buf[..len];
        Ok((payload, peer, iface_index))
    }

    /// Sends `buf` to `addr` out of interface `iface_index`, using `source_if_addr` as the
    /// local source address.
    ///
    /// # Errors
    /// Fails if `source_if_addr` is not an IPv4 address, or with any error reported by the
    /// Windows helper.
    pub fn send_to_iface(
        &self,
        buf: &[u8],
        addr: SocketAddrV4,
        iface_index: u32,
        source_if_addr: IpAddr,
    ) -> Result<usize> {
        match source_if_addr {
            IpAddr::V4(source_v4) => {
                self.wsa_structs
                    .send(buf, addr, iface_index, source_v4, &self.socket)
            }
            IpAddr::V6(_) => Err(io::Error::other("Not an IPv4 address")),
        }
    }
}
151#[cfg(use_fallback_impl)]
/// Reinterprets an initialised byte buffer as a `MaybeUninit<u8>` slice, the buffer type
/// that `socket2::Socket::recv_from` takes.
///
/// The returned slice aliases `buf` (same start pointer, same length).
fn convert_buf(buf: &mut [u8]) -> &mut [std::mem::MaybeUninit<u8>] {
    let len = buf.len();
    let start = buf.as_mut_ptr().cast::<std::mem::MaybeUninit<u8>>();
    // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`. `start`/`len` come
    // from a live, exclusively borrowed slice whose lifetime the signature ties to the
    // result. Every `u8` is a valid `MaybeUninit<u8>`, but not the other way round: writing
    // uninitialised bytes through the result would leave `buf` uninitialised. The caller in
    // this file hands it to socket2's `recv_from`, which documents that it never does so.
    unsafe { std::slice::from_raw_parts_mut(start, len) }
}
/// Portable fallback backend without per-interface control.
///
/// NOTE(review): presumably enabled by the build script when neither the `nix` backend nor
/// Windows applies — confirm. Here the receiving interface is not reported and the outgoing
/// interface cannot be chosen, so these methods degrade to plain `recv_from` / `send_to`.
#[cfg(use_fallback_impl)]
impl MultiInterfaceSocket {
    /// Receives one datagram.
    ///
    /// Returns the filled prefix of `buf`, the sender's address, and an interface index that
    /// is always `1`, because this backend cannot determine the real one.
    ///
    /// # Errors
    /// Fails on any receive error, or if the sender address is not IPv4.
    pub fn recv_from_iface<'a>(
        &self,
        buf: &'a mut [u8],
    ) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
        let (len, from) = self.socket.recv_from(convert_buf(buf))?;
        let Some(from) = from.as_socket_ipv4() else {
            return Err(io::Error::other("Not an IPv4 address"));
        };
        Ok((&mut buf[..len], from, 1))
    }

    /// Sends `buf` to `addr` using the OS's default routing.
    ///
    /// `_iface_index` and `_source_if_addr` are accepted for API parity with the other
    /// backends but are ignored. The underscores silence unused-variable warnings without
    /// affecting callers, which pass them positionally.
    pub fn send_to_iface(
        &self,
        buf: &[u8],
        addr: SocketAddrV4,
        _iface_index: u32,
        _source_if_addr: IpAddr,
    ) -> Result<usize> {
        self.socket.send_to(buf, &SockAddr::from(addr))
    }
}