// multicast_discovery_socket/socket.rs

1use std::io;
2use std::io::{IoSlice, IoSliceMut, Result};
3use std::net::{IpAddr, Ipv4Addr, SocketAddrV4};
4use socket2::{Domain, Protocol, Socket, Type};
5#[cfg(unix)]
6use nix::sys::socket;
7
8/// Ipv4 udp socket with capability to send/receive packets on specific interfaces.
9pub struct MultiInterfaceSocket {
10    socket: Socket,
11    #[cfg(windows)]
12    wsa_structs: win_helper::WSAStructs
13}
14#[cfg(unix)]
15fn nix_to_io_error(e: nix::Error) -> io::Error {
16    io::Error::other(e)
17}
18
#[cfg(windows)]
mod win_helper {
    //! Winsock helpers backing `MultiInterfaceSocket` on Windows.
    //!
    //! `WSARecvMsg` and `WSASendMsg` are Winsock extension functions whose addresses are
    //! obtained at runtime through `SIO_GET_EXTENSION_FUNCTION_POINTER`. Combined with
    //! `IP_PKTINFO`, they let the socket report the receiving interface and choose the
    //! sending interface.

    use std::ffi::{c_char, c_int};
    use std::{io, mem, ptr};
    use std::net::{Ipv4Addr, SocketAddrV4};
    use std::os::windows::io::RawSocket;
    use std::os::windows::prelude::AsRawSocket;
    use socket2::Socket;
    use winapi::shared::guiddef::GUID;
    use winapi::shared::inaddr::*;
    use winapi::shared::minwindef::DWORD;
    use winapi::shared::minwindef::{INT, LPDWORD};
    use winapi::shared::ws2def::LPWSAMSG;
    use winapi::shared::ws2def::*;
    use winapi::shared::ws2ipdef::*;
    use winapi::um::winsock2;
    use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG};
    use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET};

    /// Returns the calling thread's last Winsock error (`WSAGetLastError`) as an `io::Error`.
    fn last_error() -> io::Error {
        io::Error::from_raw_os_error(unsafe { winsock2::WSAGetLastError() })
    }

    /// Sets socket option `optname` at `level` to the bytes of `payload`.
    ///
    /// # Safety
    /// `socket` must be a valid socket handle and `T` must be the value type Winsock
    /// expects for this option.
    unsafe fn setsockopt<T>(socket: RawSocket, level: c_int, optname: c_int, payload: T) -> io::Result<()>
    where
        T: Copy,
    {
        let payload = &payload as *const T as *const c_char;
        if winsock2::setsockopt(socket as _, level, optname, payload, mem::size_of::<T>() as c_int) == 0 {
            Ok(())
        } else {
            Err(last_error())
        }
    }

    type WSARecvMsgExtension = unsafe extern "system" fn(
        s: SOCKET,
        lpMsg: LPWSAMSG,
        lpdwNumberOfBytesRecvd: LPDWORD,
        lpOverlapped: LPWSAOVERLAPPED,
        lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
    ) -> INT;
    type WSASendMsgExtension = unsafe extern "system" fn(
        s: SOCKET,
        lpMsg: LPWSAMSG,
        dwFlags: DWORD,
        lpNumberOfBytesSent: LPDWORD,
        lpOverlapped: LPWSAOVERLAPPED,
        lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
    ) -> INT;

    /// Asks Winsock (`SIO_GET_EXTENSION_FUNCTION_POINTER`) for the extension function `guid`.
    ///
    /// The address is written to `fn_pointer` and the number of bytes written to
    /// `byte_len`. Returns the raw `WSAIoctl` result (`0` on success).
    ///
    /// # Safety
    /// `socket` must be a valid Winsock socket handle.
    unsafe fn get_fn_pointer(socket: RawSocket, guid: GUID, fn_pointer: &mut usize, byte_len: &mut u32) -> c_int {
        winsock2::WSAIoctl(
            socket as _,
            SIO_GET_EXTENSION_FUNCTION_POINTER,
            &guid as *const GUID as *mut _,
            mem::size_of::<GUID>() as DWORD,
            // Output slot is the `usize` behind `fn_pointer`; report its size, not a pointer's.
            fn_pointer as *mut usize as *mut _,
            mem::size_of::<usize>() as DWORD,
            byte_len,
            ptr::null_mut(),
            None,
        )
    }

    /// Resolves the `WSARecvMsg` extension function for `socket`.
    fn locate_wsarecvmsg(socket: RawSocket) -> io::Result<WSARecvMsgExtension> {
        let mut fn_pointer: usize = 0;
        let mut byte_len: u32 = 0;

        // SAFETY: `socket` comes from a live `socket2::Socket`; the out-params are valid locals.
        let r = unsafe { get_fn_pointer(socket, WSAID_WSARECVMSG, &mut fn_pointer, &mut byte_len) };
        if r != 0 {
            return Err(last_error());
        }

        if mem::size_of::<LPFN_WSARECVMSG>() != byte_len as usize {
            return Err(io::Error::other("Locating fn pointer to WSARecvMsg returned different expected bytes"));
        }
        // SAFETY: Winsock wrote a function address of exactly this size (checked above);
        // `Option<fn>` maps a null address to `None`.
        let cast_to_fn: LPFN_WSARECVMSG = unsafe { mem::transmute(fn_pointer) };

        cast_to_fn.ok_or_else(|| io::Error::other("WSARecvMsg extension not found"))
    }

    /// Resolves the `WSASendMsg` extension function for `socket`.
    fn locate_wsasendmsg(socket: RawSocket) -> io::Result<WSASendMsgExtension> {
        let mut fn_pointer: usize = 0;
        let mut byte_len: u32 = 0;

        // SAFETY: `socket` comes from a live `socket2::Socket`; the out-params are valid locals.
        let r = unsafe { get_fn_pointer(socket, WSAID_WSASENDMSG, &mut fn_pointer, &mut byte_len) };
        if r != 0 {
            return Err(last_error());
        }

        if mem::size_of::<LPFN_WSASENDMSG>() != byte_len as usize {
            return Err(io::Error::other("Locating fn pointer to WSASendMsg returned different expected bytes"));
        }
        // SAFETY: see `locate_wsarecvmsg`.
        let cast_to_fn: LPFN_WSASENDMSG = unsafe { mem::transmute(fn_pointer) };

        cast_to_fn.ok_or_else(|| io::Error::other("WSASendMsg extension not found"))
    }

    /// Extension function pointers resolved for one socket by `win_init`.
    pub struct WSAStructs {
        wsarecvmsg: WSARecvMsgExtension,
        wsasendmsg: WSASendMsgExtension,
    }

    /// Enables or disables delivery of `IP_PKTINFO` control messages on `socket`.
    fn set_pktinfo(socket: RawSocket, payload: bool) -> io::Result<()> {
        // SAFETY: `IP_PKTINFO` takes an integer boolean.
        unsafe { setsockopt(socket, IPPROTO_IP, IP_PKTINFO, payload as c_int) }
    }

    /// Converts `addr` into the `in_addr` union with its octets kept in network order.
    fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un {
        // `from_ne_bytes` keeps the octets in memory order, i.e. network byte order.
        let res = u32::from_ne_bytes(addr.octets());
        let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() };
        unsafe { *(new_addr.S_addr_mut()) = res };
        new_addr
    }

    /// Size of the control-message header that precedes the `IN_PKTINFO` payload.
    const CMSG_HEADER_SIZE: usize = mem::size_of::<WSACMSGHDR>();
    /// Size of the `IN_PKTINFO` payload.
    const PKTINFO_DATA_SIZE: usize = mem::size_of::<IN_PKTINFO>();
    /// Control buffer size for exactly one `IP_PKTINFO` message.
    /// NOTE(review): assumes the header size is already aligned for the payload
    /// (true when the header is 16 bytes on 64-bit / 12 on 32-bit) — confirm for new targets.
    const CONTROL_PKTINFO_BUFFER_SIZE: usize = CMSG_HEADER_SIZE + PKTINFO_DATA_SIZE;

    /// Prepares `socket` for interface-aware I/O: enables `IP_PKTINFO` and resolves the
    /// `WSARecvMsg` / `WSASendMsg` extension functions for it.
    ///
    /// # Errors
    /// Returns an error if the option cannot be set or either function cannot be located.
    pub fn win_init(socket: &Socket) -> io::Result<WSAStructs> {
        let raw = socket.as_raw_socket();
        set_pktinfo(raw, true)?;
        let wsarecvmsg = locate_wsarecvmsg(raw)?;
        let wsasendmsg = locate_wsasendmsg(raw)?;

        Ok(WSAStructs { wsarecvmsg, wsasendmsg })
    }

    impl WSAStructs {
        /// Receives one datagram into `data_buffer` via `WSARecvMsg`.
        ///
        /// Returns `(bytes_read, sender, interface_index)`. `sender` is `0.0.0.0:0` when
        /// the reported address is not `AF_INET`. `interface_index` is `0` when no
        /// `IP_PKTINFO` control message was delivered.
        ///
        /// # Errors
        /// Returns the Winsock error if `WSARecvMsg` fails.
        pub fn receive(&self, data_buffer: &mut [u8], socket: &Socket) -> io::Result<(usize, SocketAddrV4, u32)> {
            let mut data = WSABUF {
                buf: data_buffer.as_mut_ptr() as *mut i8,
                // Clamp instead of wrapping for buffers larger than u32::MAX.
                len: data_buffer.len().min(u32::MAX as usize) as u32,
            };

            let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];

            let mut origin_address: SOCKADDR = unsafe { mem::zeroed() };
            let mut wsa_msg = WSAMSG {
                name: &mut origin_address,
                namelen: mem::size_of_val(&origin_address) as i32,
                lpBuffers: &mut data,
                Control: WSABUF {
                    buf: control_buffer.as_mut_ptr(),
                    len: control_buffer.len() as u32,
                },
                dwBufferCount: 1,
                dwFlags: 0,
            };

            let mut read_bytes = 0;
            // SAFETY: every pointer in `wsa_msg` refers to a local buffer that outlives the
            // synchronous call (no OVERLAPPED / completion routine is passed).
            let r = unsafe {
                (self.wsarecvmsg)(
                    socket.as_raw_socket() as _,
                    &mut wsa_msg,
                    &mut read_bytes,
                    ptr::null_mut(),
                    None,
                )
            };
            if r != 0 {
                return Err(last_error());
            }

            let origin_address = if origin_address.sa_family != AF_INET as ADDRESS_FAMILY {
                SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)
            } else {
                let sa_data = origin_address.sa_data;
                // sockaddr_in layout inside sa_data: port (big-endian) then the 4 address octets.
                let port = u16::from_be_bytes([sa_data[0] as u8, sa_data[1] as u8]);
                let ip = Ipv4Addr::new(
                    sa_data[2] as u8,
                    sa_data[3] as u8,
                    sa_data[4] as u8,
                    sa_data[5] as u8,
                );
                SocketAddrV4::new(ip, port)
            };

            // WSARecvMsg rewrites `wsa_msg.Control.len` with the number of control bytes it
            // actually stored. Only parse the buffer when a full header + pktinfo arrived;
            // otherwise the bytes are just our zero initialisation.
            let mut index = 0;
            if wsa_msg.Control.len as usize >= CONTROL_PKTINFO_BUFFER_SIZE {
                // SAFETY: the buffer holds at least CONTROL_PKTINFO_BUFFER_SIZE bytes; reads are unaligned.
                let cmsg_header: WSACMSGHDR =
                    unsafe { ptr::read_unaligned(wsa_msg.Control.buf as *const WSACMSGHDR) };
                if cmsg_header.cmsg_level == IPPROTO_IP && cmsg_header.cmsg_type == IP_PKTINFO {
                    let interface_info: IN_PKTINFO = unsafe {
                        ptr::read_unaligned(wsa_msg.Control.buf.add(CMSG_HEADER_SIZE) as *const IN_PKTINFO)
                    };
                    index = interface_info.ipi_ifindex;
                }
            }

            Ok((read_bytes as usize, origin_address, index))
        }

        /// Sends `buf` to `dst_addr` via `WSASendMsg`. An `IP_PKTINFO` control message
        /// selects interface `iface_index` and source address `source_if_addr`.
        ///
        /// Returns the number of bytes sent.
        ///
        /// # Errors
        /// Returns the Winsock error if `WSASendMsg` fails.
        pub fn send(&self, buf: &[u8], dst_addr: SocketAddrV4, iface_index: u32, source_if_addr: Ipv4Addr, socket: &Socket) -> io::Result<usize> {
            let pkt_info = IN_PKTINFO {
                ipi_addr: IN_ADDR {
                    S_un: to_s_addr(&source_if_addr),
                },
                ipi_ifindex: iface_index,
            };
            let hdr = WSACMSGHDR {
                cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE,
                cmsg_level: IPPROTO_IP,
                cmsg_type: IP_PKTINFO,
            };

            let mut data = WSABUF {
                buf: buf.as_ptr() as *mut _,
                len: buf.len() as _,
            };

            let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
            // SAFETY: header + payload exactly fill `control_buffer`; writes are unaligned.
            unsafe {
                ptr::write_unaligned(control_buffer.as_mut_ptr() as *mut WSACMSGHDR, hdr);
                ptr::write_unaligned(
                    control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE) as *mut IN_PKTINFO,
                    pkt_info,
                );
            }
            let control = WSABUF {
                buf: control_buffer.as_mut_ptr(),
                len: control_buffer.len() as _,
            };

            let destination = socket2::SockAddr::from(dst_addr);
            let mut wsa_msg = WSAMSG {
                name: destination.as_ptr() as *mut _,
                namelen: destination.len(),
                lpBuffers: &mut data,
                Control: control,
                dwBufferCount: 1,
                dwFlags: 0,
            };

            let mut sent_bytes = 0;
            // SAFETY: all pointers in `wsa_msg` reference locals alive for this synchronous call;
            // Winsock does not write through the data buffer when sending.
            let r = unsafe {
                (self.wsasendmsg)(
                    socket.as_raw_socket() as _,
                    &mut wsa_msg,
                    0,
                    &mut sent_bytes,
                    ptr::null_mut(),
                    None,
                )
            };
            if r != 0 {
                return Err(last_error());
            }

            Ok(sent_bytes as _)
        }
    }
}
303
304impl MultiInterfaceSocket {
305    pub fn bind_any() -> Result<Self> {
306        Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
307    }
308    
309    pub fn bind_port(port: u16) -> Result<Self> {
310        Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port))
311    }
312    pub fn bind(addr: SocketAddrV4) -> Result<Self> {
313        let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
314        socket.bind(&addr.into())?;
315
316        #[cfg(unix)]
317        use std::os::fd::AsFd;
318        #[cfg(unix)]
319        socket::setsockopt(&socket.as_fd(), socket::sockopt::Ipv4PacketInfo, &true)
320            .map_err(nix_to_io_error)?;
321
322        #[cfg(windows)]
323        let wsa_structs = win_helper::win_init(&socket)?;
324
325        Ok(Self {
326            socket,
327            #[cfg(windows)]
328            wsa_structs
329        })
330    }
331    
332    pub fn get_bind_addr(&self) -> Result<SocketAddrV4> {
333        let addr = self.socket.local_addr()?;
334        if let Some(addr) = addr.as_socket_ipv4() {
335            Ok(addr)
336        } else {
337            Err(io::Error::other("Not an IPv4 address"))
338        }
339    }
340    
341    /// Join a multicast group on provided interface. 
342    pub fn join_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
343        self.socket.join_multicast_v4(&addr, &interface)
344    }
345
346    /// Leave a multicast group on provided interface. 
347    pub fn leave_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
348        self.socket.leave_multicast_v4(&addr, &interface)
349    }
350    
351    /// Nonblocking mode will cause read operations to return immediately with an error if no data is available.
352    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
353        self.socket.set_nonblocking(nonblocking)
354    }
355    /// Assign a timeout to read operations on the socket.
356    pub fn set_read_timeout(&self, timeout: std::time::Duration) -> Result<()> {
357        self.socket.set_read_timeout(Some(timeout))
358    }
359    /// When socket is non-blocking, this option will cause the read operation to block indefinitely until data is available.
360    pub fn set_read_timeout_inf(&self) -> Result<()> {
361        self.socket.set_read_timeout(None)
362    }
363
364    /// `recvfrom`, but with interface index.
365    #[cfg(unix)]
366    pub fn recv_from_iface<'a>(&self, buf: &'a mut [u8]) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
367        use std::os::fd::AsRawFd;
368        
369        let mut control_buffer = nix::cmsg_space!(nix::libc::in_pktinfo);
370        let mut bufs = [IoSliceMut::new(buf)];
371        let message: socket::RecvMsg<socket::SockaddrIn> = socket::recvmsg(
372            self.socket.as_raw_fd(),
373            &mut bufs,
374            Some(&mut control_buffer),
375            socket::MsgFlags::empty(),
376        )
377            .map_err(nix_to_io_error)?;
378
379        let dst_addr = message.address.map(|a| SocketAddrV4::new(a.ip(), a.port()))
380            .unwrap_or(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
381        let sz = message.bytes;
382
383        let mut index = 0;
384        for cmsg in message.cmsgs()? {
385            if let socket::ControlMessageOwned::Ipv4PacketInfo(pkt_info) = cmsg {
386                index = pkt_info.ipi_ifindex as u32;
387                break;
388            }
389        }
390        Ok((&mut buf[..sz], dst_addr, index))
391    }
392
393    /// `recvfrom`, but with interface index.
394    #[cfg(windows)]
395    pub fn recv_from_iface<'a>(&self, buf: &'a mut [u8]) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
396        let (sz, addr, iface) = self.wsa_structs.receive(buf, &self.socket)?;
397        Ok((&mut buf[..sz], addr, iface))
398    }
399    #[cfg(unix)]
400    pub fn send_to_iface(&self, buf: &[u8], addr: SocketAddrV4, iface_index: u32, _source_if_addr: IpAddr) -> Result<usize> {
401        use std::os::fd::AsRawFd;
402        
403        let bufs = [IoSlice::new(buf)];
404        let mut pkt_info: nix::libc::in_pktinfo = unsafe { std::mem::zeroed() };
405        pkt_info.ipi_ifindex = iface_index as i32;
406
407        socket::sendmsg(
408            self.socket.as_raw_fd(),
409            &bufs,
410            &[socket::ControlMessage::Ipv4PacketInfo(&pkt_info)],
411            socket::MsgFlags::empty(),
412            Some(&socket::SockaddrIn::from(addr)),
413        )
414            .map_err(nix_to_io_error)
415    }
416
417    #[cfg(windows)]
418    pub fn send_to_iface(&self, buf: &[u8], addr: SocketAddrV4, iface_index: u32, source_if_addr: IpAddr) -> Result<usize> {
419        if let IpAddr::V4(source_ip_addr) = source_if_addr {
420            self.wsa_structs.send(buf, addr, iface_index, source_ip_addr, &self.socket)
421        }
422        else {
423            Err(io::Error::other("Not an IPv4 address"))
424        }
425    }
426}