// multicast_socket/win.rs

1use std::collections::{HashMap, HashSet};
2use std::ffi::CStr;
3use std::io;
4use std::iter::FromIterator;
5use std::mem;
6use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
7use std::os::windows::prelude::*;
8use std::ptr;
9use std::str::FromStr;
10
11use socket2::{Domain, Protocol, Socket, Type};
12
13use winapi::ctypes::{c_char, c_int};
14use winapi::shared::inaddr::*;
15use winapi::shared::minwindef::DWORD;
16use winapi::shared::minwindef::{INT, LPDWORD};
17use winapi::shared::winerror::ERROR_BUFFER_OVERFLOW;
18use winapi::shared::ws2def::LPWSAMSG;
19use winapi::shared::ws2def::*;
20use winapi::shared::ws2ipdef::*;
21use winapi::um::iptypes;
22use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG};
23use winapi::um::winsock2 as sock;
24use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET};
25
26fn last_error() -> io::Error {
27    io::Error::from_raw_os_error(unsafe { sock::WSAGetLastError() })
28}
29
30unsafe fn setsockopt<T>(socket: RawSocket, opt: c_int, val: c_int, payload: T) -> io::Result<()>
31where
32    T: Copy,
33{
34    let payload = &payload as *const T as *const c_char;
35    if sock::setsockopt(socket as _, opt, val, payload, mem::size_of::<T>() as c_int) == 0 {
36        Ok(())
37    } else {
38        Err(last_error())
39    }
40}
41
42type WSARecvMsgExtension = unsafe extern "system" fn(
43    s: SOCKET,
44    lpMsg: LPWSAMSG,
45    lpdwNumberOfBytesRecvd: LPDWORD,
46    lpOverlapped: LPWSAOVERLAPPED,
47    lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
48) -> INT;
49
50fn locate_wsarecvmsg(socket: RawSocket) -> io::Result<WSARecvMsgExtension> {
51    let mut fn_pointer: usize = 0;
52    let mut byte_len: u32 = 0;
53
54    let r = unsafe {
55        sock::WSAIoctl(
56            socket as _,
57            SIO_GET_EXTENSION_FUNCTION_POINTER,
58            &WSAID_WSARECVMSG as *const _ as *mut _,
59            mem::size_of_val(&WSAID_WSARECVMSG) as DWORD,
60            &mut fn_pointer as *const _ as *mut _,
61            mem::size_of_val(&fn_pointer) as DWORD,
62            &mut byte_len,
63            ptr::null_mut(),
64            None,
65        )
66    };
67    if r != 0 {
68        return Err(io::Error::last_os_error());
69    }
70
71    if mem::size_of::<LPFN_WSARECVMSG>() != byte_len as _ {
72        return Err(io::Error::new(
73            io::ErrorKind::Other,
74            "Locating fn pointer to WSARecvMsg returned different expected bytes",
75        ));
76    }
77    let cast_to_fn: LPFN_WSARECVMSG = unsafe { mem::transmute(fn_pointer) };
78
79    match cast_to_fn {
80        None => Err(io::Error::new(
81            io::ErrorKind::Other,
82            "WSARecvMsg extension not foud",
83        )),
84        Some(extension) => Ok(extension),
85    }
86}
87
88type WSASendMsgExtension = unsafe extern "system" fn(
89    s: SOCKET,
90    lpMsg: LPWSAMSG,
91    dwFlags: DWORD,
92    lpNumberOfBytesSent: LPDWORD,
93    lpOverlapped: LPWSAOVERLAPPED,
94    lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
95) -> INT;
96
97fn locate_wsasendmsg(socket: RawSocket) -> io::Result<WSASendMsgExtension> {
98    let mut fn_pointer: usize = 0;
99    let mut byte_len: u32 = 0;
100
101    let r = unsafe {
102        sock::WSAIoctl(
103            socket as _,
104            SIO_GET_EXTENSION_FUNCTION_POINTER,
105            &WSAID_WSASENDMSG as *const _ as *mut _,
106            mem::size_of_val(&WSAID_WSASENDMSG) as DWORD,
107            &mut fn_pointer as *const _ as *mut _,
108            mem::size_of_val(&fn_pointer) as DWORD,
109            &mut byte_len,
110            ptr::null_mut(),
111            None,
112        )
113    };
114    if r != 0 {
115        return Err(io::Error::last_os_error());
116    }
117
118    if mem::size_of::<LPFN_WSASENDMSG>() != byte_len as _ {
119        return Err(io::Error::new(
120            io::ErrorKind::Other,
121            "Locating fn pointer to WSASendMsg returned different expected bytes",
122        ));
123    }
124    let cast_to_fn: LPFN_WSASENDMSG = unsafe { mem::transmute(fn_pointer) };
125
126    match cast_to_fn {
127        None => Err(io::Error::new(
128            io::ErrorKind::Other,
129            "WSASendMsg extension not foud",
130        )),
131        Some(extension) => Ok(extension),
132    }
133}
134
135fn set_pktinfo(socket: RawSocket, payload: bool) -> io::Result<()> {
136    unsafe { setsockopt(socket, IPPROTO_IP, IP_PKTINFO, payload as c_int) }
137}
138
139fn create_on_interfaces(
140    options: crate::MulticastOptions,
141    interfaces: Vec<Ipv4Addr>,
142    multicast_address: SocketAddrV4,
143) -> io::Result<MulticastSocket> {
144    let socket = Socket::new(Domain::ipv4(), Type::dgram(), Some(Protocol::udp()))?;
145    socket.set_read_timeout(options.read_timeout)?;
146    socket.set_multicast_loop_v4(options.loopback)?;
147    socket.set_reuse_address(true)?;
148
149    // enable fetching interface information and locate the extension function
150    set_pktinfo(socket.as_raw_socket(), true)?;
151    let wsarecvmsg: WSARecvMsgExtension = locate_wsarecvmsg(socket.as_raw_socket())?;
152    let wsasendmsg: WSASendMsgExtension = locate_wsasendmsg(socket.as_raw_socket())?;
153
154    // Join multicast listeners on every interface passed
155    for interface in &interfaces {
156        socket.join_multicast_v4(multicast_address.ip(), &interface)?;
157    }
158
159    // On Windows, unlike all Unix variants, it is improper to bind to the multicast address
160    // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms737550(v=vs.85).aspx
161    socket.bind(&SocketAddr::new(options.bind_address.into(), multicast_address.port()).into())?;
162
163    let interfaces = build_address_table(HashSet::from_iter(interfaces))?;
164
165    Ok(MulticastSocket {
166        socket,
167        wsarecvmsg,
168        wsasendmsg,
169        interfaces,
170        multicast_address,
171        buffer_size: options.buffer_size,
172    })
173}
174
175fn build_address_table(interfaces: HashSet<Ipv4Addr>) -> io::Result<HashMap<u32, Ipv4Addr>> {
176    let mut size = 0u32;
177    let r = unsafe { winapi::um::iphlpapi::GetAdaptersInfo(ptr::null_mut(), &mut size) };
178    if r != ERROR_BUFFER_OVERFLOW {
179        return Err(io::Error::last_os_error());
180    }
181
182    let mut buffer = vec![0; mem::size_of::<iptypes::IP_ADAPTER_INFO>() * (size as usize)];
183    let mut adapter_info = buffer.as_mut_ptr() as iptypes::PIP_ADAPTER_INFO;
184    let mut size = buffer.len() as u32;
185    let r = unsafe { winapi::um::iphlpapi::GetAdaptersInfo(adapter_info, &mut size) };
186
187    if r != 0 {
188        return Err(io::Error::last_os_error());
189    }
190
191    let mut table = HashMap::with_capacity(interfaces.len());
192    loop {
193        if adapter_info.is_null() {
194            break;
195        }
196
197        let current: iptypes::IP_ADAPTER_INFO = unsafe { *adapter_info };
198        let ip_address =
199            unsafe { CStr::from_ptr(current.IpAddressList.IpAddress.String.as_ptr()) }.to_str();
200        let ip_address = match ip_address {
201            Ok(i) => Ipv4Addr::from_str(&i),
202            _ => {
203                continue;
204            }
205        };
206        let ip_address = match ip_address {
207            Ok(i) => i,
208            _ => {
209                continue;
210            }
211        };
212
213        if interfaces.contains(&ip_address) {
214            table.insert(current.Index, ip_address);
215        }
216
217        adapter_info = current.Next;
218    }
219
220    Ok(table)
221}
222
223pub struct MulticastSocket {
224    socket: socket2::Socket,
225    wsarecvmsg: WSARecvMsgExtension,
226    wsasendmsg: WSASendMsgExtension,
227    interfaces: HashMap<u32, Ipv4Addr>,
228    multicast_address: SocketAddrV4,
229    buffer_size: usize,
230}
231
/// Selects the network interface a datagram is sent on, or reports the one it
/// arrived on.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare received interfaces
/// or use them as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Interface {
    /// Let the operating system choose (when sending), or the arrival
    /// interface was not reported (when receiving).
    Default,
    /// The interface that owns this local IPv4 address.
    Ip(Ipv4Addr),
    /// The interface with this Windows adapter index.
    Index(u32),
}
238
239#[derive(Debug, Clone)]
240pub struct Message {
241    pub data: Vec<u8>,
242    pub origin_address: SocketAddrV4,
243    pub interface: Interface,
244}
245
246const CMSG_HEADER_SIZE: usize = mem::size_of::<WSACMSGHDR>();
247const PKTINFO_DATA_SIZE: usize = mem::size_of::<IN_PKTINFO>();
248const CONTROL_PKTINFO_BUFFER_SIZE: usize = CMSG_HEADER_SIZE + PKTINFO_DATA_SIZE;
249
250pub fn all_ipv4_interfaces() -> io::Result<Vec<Ipv4Addr>> {
251    let interfaces = if_addrs::get_if_addrs()?
252        .into_iter()
253        .filter_map(|i| match i.ip() {
254            std::net::IpAddr::V4(v4) => Some(v4),
255            _ => None,
256        })
257        .collect();
258    Ok(interfaces)
259}
260
261impl MulticastSocket {
262    pub fn all_interfaces(multicast_address: SocketAddrV4) -> io::Result<Self> {
263        let interfaces = all_ipv4_interfaces()?;
264        create_on_interfaces(Default::default(), interfaces, multicast_address)
265    }
266
267    pub fn with_options(
268        multicast_address: SocketAddrV4,
269        interfaces: Vec<Ipv4Addr>,
270        options: crate::MulticastOptions,
271    ) -> io::Result<Self> {
272        create_on_interfaces(options, interfaces, multicast_address)
273    }
274}
275
276impl MulticastSocket {
277    pub fn receive(&self) -> io::Result<Message> {
278        let mut data_buffer = vec![0; self.buffer_size];
279        let mut data = WSABUF {
280            buf: data_buffer.as_mut_ptr(),
281            len: data_buffer.len() as u32,
282        };
283
284        let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
285        let control = WSABUF {
286            buf: control_buffer.as_mut_ptr(),
287            len: control_buffer.len() as u32,
288        };
289
290        let mut origin_address: SOCKADDR = unsafe { mem::zeroed() };
291        let mut wsa_msg = WSAMSG {
292            name: &mut origin_address,
293            namelen: mem::size_of_val(&origin_address) as i32,
294            lpBuffers: &mut data,
295            Control: control,
296            dwBufferCount: 1,
297            dwFlags: 0,
298        };
299
300        let mut read_bytes = 0;
301        let r = {
302            unsafe {
303                (self.wsarecvmsg)(
304                    self.socket.as_raw_socket() as _,
305                    &mut wsa_msg,
306                    &mut read_bytes,
307                    ptr::null_mut(),
308                    None,
309                )
310            }
311        };
312
313        if r != 0 {
314            return Err(io::Error::last_os_error());
315        }
316
317        let origin_address = unsafe {
318            socket2::SockAddr::from_raw_parts(
319                &origin_address,
320                mem::size_of_val(&origin_address) as i32,
321            )
322        }
323        .as_std();
324
325        let origin_address = match origin_address {
326            Some(SocketAddr::V4(v4)) => v4,
327            _ => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
328        };
329
330        let mut interface = Interface::Default;
331        // Ensures that the control buffer is the size of the CSMG_HEADER + the pkinto data
332        if control.len as usize == CONTROL_PKTINFO_BUFFER_SIZE {
333            let cmsg_header: WSACMSGHDR = unsafe { ptr::read_unaligned(control.buf as *const _) }; // TODO fix clippy warning without breaking the code
334            if cmsg_header.cmsg_level == IPPROTO_IP && cmsg_header.cmsg_type == IP_PKTINFO {
335                let interface_info: IN_PKTINFO =
336                    unsafe { ptr::read_unaligned(control.buf.add(CMSG_HEADER_SIZE) as *const _) }; // TODO fix clippy warning without breaking the code
337                interface = Interface::Index(interface_info.ipi_ifindex);
338            };
339        };
340
341        Ok(Message {
342            data: data_buffer[0..read_bytes as _]
343                .iter()
344                .map(|i| *i as u8)
345                .collect(),
346            origin_address,
347            interface,
348        })
349    }
350
351    pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result<usize> {
352        let pkt_info = match interface {
353            Interface::Default => None,
354            Interface::Ip(address) => Some(IN_PKTINFO {
355                ipi_addr: IN_ADDR {
356                    S_un: to_s_addr(address),
357                },
358                ipi_ifindex: 0,
359            }),
360            Interface::Index(index) => self.interfaces.get(index).map(|address| IN_PKTINFO {
361                ipi_addr: IN_ADDR {
362                    S_un: to_s_addr(address),
363                },
364                ipi_ifindex: *index,
365            }),
366        };
367
368        let mut data = WSABUF {
369            buf: buf.as_ptr() as *mut _,
370            len: buf.len() as _,
371        };
372
373        let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
374        let control = if let Some(pkt_info) = pkt_info {
375            let hdr = CMSGHDR {
376                cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE,
377                cmsg_level: IPPROTO_IP,
378                cmsg_type: IP_PKTINFO,
379            };
380            unsafe {
381                ptr::copy(
382                    &hdr as *const _ as *const _,
383                    control_buffer.as_mut_ptr(),
384                    CMSG_HEADER_SIZE,
385                );
386                ptr::copy(
387                    &pkt_info as *const _ as *const _,
388                    control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE),
389                    PKTINFO_DATA_SIZE,
390                )
391            };
392            WSABUF {
393                buf: control_buffer.as_mut_ptr(),
394                len: control_buffer.len() as _,
395            }
396        } else {
397            WSABUF {
398                buf: [].as_mut_ptr(),
399                len: 0,
400            }
401        };
402
403        let destination = socket2::SockAddr::from(self.multicast_address);
404        let destination_address = destination.as_ptr();
405        let mut wsa_msg = WSAMSG {
406            name: destination_address as *mut _,
407            namelen: destination.len(),
408            lpBuffers: &mut data,
409            Control: control,
410            dwBufferCount: 1,
411            dwFlags: 0,
412        };
413
414        let mut sent_bytes = 0;
415        let r = unsafe {
416            (self.wsasendmsg)(
417                self.socket.as_raw_socket() as _,
418                &mut wsa_msg,
419                0,
420                &mut sent_bytes,
421                ptr::null_mut(),
422                None,
423            )
424        };
425        if r != 0 {
426            return Err(io::Error::last_os_error());
427        }
428
429        Ok(sent_bytes as _)
430    }
431
432    pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> {
433        for interface in self.interfaces.values() {
434            self.send(buf, &Interface::Ip(*interface))?;
435        }
436        Ok(())
437    }
438}
439
440fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un {
441    let octets = addr.octets();
442    let res = u32::from_ne_bytes(octets);
443    let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() };
444    unsafe { *(new_addr.S_addr_mut()) = res };
445    new_addr
446}