use std::collections::{HashMap, HashSet};
use std::ffi::CStr;
use std::io;
use std::iter::FromIterator;
use std::mem;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::os::windows::prelude::*;
use std::ptr;
use std::str::FromStr;

use socket2::{Domain, Protocol, Socket, Type};

use winapi::ctypes::{c_char, c_int};
use winapi::shared::guiddef::GUID;
use winapi::shared::inaddr::*;
use winapi::shared::minwindef::DWORD;
use winapi::shared::minwindef::{INT, LPDWORD};
use winapi::shared::winerror::ERROR_BUFFER_OVERFLOW;
use winapi::shared::winerror::{ERROR_NO_DATA, ERROR_SUCCESS};
use winapi::shared::ws2def::LPWSAMSG;
use winapi::shared::ws2def::*;
use winapi::shared::ws2ipdef::*;
use winapi::um::iptypes;
use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG};
use winapi::um::winsock2 as sock;
use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET};
25
26fn last_error() -> io::Error {
27 io::Error::from_raw_os_error(unsafe { sock::WSAGetLastError() })
28}
29
30unsafe fn setsockopt<T>(socket: RawSocket, opt: c_int, val: c_int, payload: T) -> io::Result<()>
31where
32 T: Copy,
33{
34 let payload = &payload as *const T as *const c_char;
35 if sock::setsockopt(socket as _, opt, val, payload, mem::size_of::<T>() as c_int) == 0 {
36 Ok(())
37 } else {
38 Err(last_error())
39 }
40}
41
42type WSARecvMsgExtension = unsafe extern "system" fn(
43 s: SOCKET,
44 lpMsg: LPWSAMSG,
45 lpdwNumberOfBytesRecvd: LPDWORD,
46 lpOverlapped: LPWSAOVERLAPPED,
47 lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
48) -> INT;
49
50fn locate_wsarecvmsg(socket: RawSocket) -> io::Result<WSARecvMsgExtension> {
51 let mut fn_pointer: usize = 0;
52 let mut byte_len: u32 = 0;
53
54 let r = unsafe {
55 sock::WSAIoctl(
56 socket as _,
57 SIO_GET_EXTENSION_FUNCTION_POINTER,
58 &WSAID_WSARECVMSG as *const _ as *mut _,
59 mem::size_of_val(&WSAID_WSARECVMSG) as DWORD,
60 &mut fn_pointer as *const _ as *mut _,
61 mem::size_of_val(&fn_pointer) as DWORD,
62 &mut byte_len,
63 ptr::null_mut(),
64 None,
65 )
66 };
67 if r != 0 {
68 return Err(io::Error::last_os_error());
69 }
70
71 if mem::size_of::<LPFN_WSARECVMSG>() != byte_len as _ {
72 return Err(io::Error::new(
73 io::ErrorKind::Other,
74 "Locating fn pointer to WSARecvMsg returned different expected bytes",
75 ));
76 }
77 let cast_to_fn: LPFN_WSARECVMSG = unsafe { mem::transmute(fn_pointer) };
78
79 match cast_to_fn {
80 None => Err(io::Error::new(
81 io::ErrorKind::Other,
82 "WSARecvMsg extension not foud",
83 )),
84 Some(extension) => Ok(extension),
85 }
86}
87
88type WSASendMsgExtension = unsafe extern "system" fn(
89 s: SOCKET,
90 lpMsg: LPWSAMSG,
91 dwFlags: DWORD,
92 lpNumberOfBytesSent: LPDWORD,
93 lpOverlapped: LPWSAOVERLAPPED,
94 lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
95) -> INT;
96
97fn locate_wsasendmsg(socket: RawSocket) -> io::Result<WSASendMsgExtension> {
98 let mut fn_pointer: usize = 0;
99 let mut byte_len: u32 = 0;
100
101 let r = unsafe {
102 sock::WSAIoctl(
103 socket as _,
104 SIO_GET_EXTENSION_FUNCTION_POINTER,
105 &WSAID_WSASENDMSG as *const _ as *mut _,
106 mem::size_of_val(&WSAID_WSASENDMSG) as DWORD,
107 &mut fn_pointer as *const _ as *mut _,
108 mem::size_of_val(&fn_pointer) as DWORD,
109 &mut byte_len,
110 ptr::null_mut(),
111 None,
112 )
113 };
114 if r != 0 {
115 return Err(io::Error::last_os_error());
116 }
117
118 if mem::size_of::<LPFN_WSASENDMSG>() != byte_len as _ {
119 return Err(io::Error::new(
120 io::ErrorKind::Other,
121 "Locating fn pointer to WSASendMsg returned different expected bytes",
122 ));
123 }
124 let cast_to_fn: LPFN_WSASENDMSG = unsafe { mem::transmute(fn_pointer) };
125
126 match cast_to_fn {
127 None => Err(io::Error::new(
128 io::ErrorKind::Other,
129 "WSASendMsg extension not foud",
130 )),
131 Some(extension) => Ok(extension),
132 }
133}
134
135fn set_pktinfo(socket: RawSocket, payload: bool) -> io::Result<()> {
136 unsafe { setsockopt(socket, IPPROTO_IP, IP_PKTINFO, payload as c_int) }
137}
138
139fn create_on_interfaces(
140 options: crate::MulticastOptions,
141 interfaces: Vec<Ipv4Addr>,
142 multicast_address: SocketAddrV4,
143) -> io::Result<MulticastSocket> {
144 let socket = Socket::new(Domain::ipv4(), Type::dgram(), Some(Protocol::udp()))?;
145 socket.set_read_timeout(options.read_timeout)?;
146 socket.set_multicast_loop_v4(options.loopback)?;
147 socket.set_reuse_address(true)?;
148
149 set_pktinfo(socket.as_raw_socket(), true)?;
151 let wsarecvmsg: WSARecvMsgExtension = locate_wsarecvmsg(socket.as_raw_socket())?;
152 let wsasendmsg: WSASendMsgExtension = locate_wsasendmsg(socket.as_raw_socket())?;
153
154 for interface in &interfaces {
156 socket.join_multicast_v4(multicast_address.ip(), &interface)?;
157 }
158
159 socket.bind(&SocketAddr::new(options.bind_address.into(), multicast_address.port()).into())?;
162
163 let interfaces = build_address_table(HashSet::from_iter(interfaces))?;
164
165 Ok(MulticastSocket {
166 socket,
167 wsarecvmsg,
168 wsasendmsg,
169 interfaces,
170 multicast_address,
171 buffer_size: options.buffer_size,
172 })
173}
174
175fn build_address_table(interfaces: HashSet<Ipv4Addr>) -> io::Result<HashMap<u32, Ipv4Addr>> {
176 let mut size = 0u32;
177 let r = unsafe { winapi::um::iphlpapi::GetAdaptersInfo(ptr::null_mut(), &mut size) };
178 if r != ERROR_BUFFER_OVERFLOW {
179 return Err(io::Error::last_os_error());
180 }
181
182 let mut buffer = vec![0; mem::size_of::<iptypes::IP_ADAPTER_INFO>() * (size as usize)];
183 let mut adapter_info = buffer.as_mut_ptr() as iptypes::PIP_ADAPTER_INFO;
184 let mut size = buffer.len() as u32;
185 let r = unsafe { winapi::um::iphlpapi::GetAdaptersInfo(adapter_info, &mut size) };
186
187 if r != 0 {
188 return Err(io::Error::last_os_error());
189 }
190
191 let mut table = HashMap::with_capacity(interfaces.len());
192 loop {
193 if adapter_info.is_null() {
194 break;
195 }
196
197 let current: iptypes::IP_ADAPTER_INFO = unsafe { *adapter_info };
198 let ip_address =
199 unsafe { CStr::from_ptr(current.IpAddressList.IpAddress.String.as_ptr()) }.to_str();
200 let ip_address = match ip_address {
201 Ok(i) => Ipv4Addr::from_str(&i),
202 _ => {
203 continue;
204 }
205 };
206 let ip_address = match ip_address {
207 Ok(i) => i,
208 _ => {
209 continue;
210 }
211 };
212
213 if interfaces.contains(&ip_address) {
214 table.insert(current.Index, ip_address);
215 }
216
217 adapter_info = current.Next;
218 }
219
220 Ok(table)
221}
222
223pub struct MulticastSocket {
224 socket: socket2::Socket,
225 wsarecvmsg: WSARecvMsgExtension,
226 wsasendmsg: WSASendMsgExtension,
227 interfaces: HashMap<u32, Ipv4Addr>,
228 multicast_address: SocketAddrV4,
229 buffer_size: usize,
230}
231
/// Identifies the local interface a message arrived on or should be sent from.
#[derive(Debug, Clone)]
pub enum Interface {
    /// No specific interface: when sending, no `IP_PKTINFO` control data is attached;
    /// when receiving, no `IP_PKTINFO` information was available.
    Default,
    /// An interface identified by one of its IPv4 addresses.
    Ip(Ipv4Addr),
    /// An interface identified by its adapter index.
    Index(u32),
}
238
/// A datagram received by `MulticastSocket::receive`.
#[derive(Debug, Clone)]
pub struct Message {
    /// The payload bytes that were received.
    pub data: Vec<u8>,
    /// The sender's address; `0.0.0.0:0` when the sender address was not IPv4.
    pub origin_address: SocketAddrV4,
    /// The interface the datagram arrived on, or `Interface::Default` if unknown.
    pub interface: Interface,
}
245
246const CMSG_HEADER_SIZE: usize = mem::size_of::<WSACMSGHDR>();
247const PKTINFO_DATA_SIZE: usize = mem::size_of::<IN_PKTINFO>();
248const CONTROL_PKTINFO_BUFFER_SIZE: usize = CMSG_HEADER_SIZE + PKTINFO_DATA_SIZE;
249
250pub fn all_ipv4_interfaces() -> io::Result<Vec<Ipv4Addr>> {
251 let interfaces = if_addrs::get_if_addrs()?
252 .into_iter()
253 .filter_map(|i| match i.ip() {
254 std::net::IpAddr::V4(v4) => Some(v4),
255 _ => None,
256 })
257 .collect();
258 Ok(interfaces)
259}
260
261impl MulticastSocket {
262 pub fn all_interfaces(multicast_address: SocketAddrV4) -> io::Result<Self> {
263 let interfaces = all_ipv4_interfaces()?;
264 create_on_interfaces(Default::default(), interfaces, multicast_address)
265 }
266
267 pub fn with_options(
268 multicast_address: SocketAddrV4,
269 interfaces: Vec<Ipv4Addr>,
270 options: crate::MulticastOptions,
271 ) -> io::Result<Self> {
272 create_on_interfaces(options, interfaces, multicast_address)
273 }
274}
275
276impl MulticastSocket {
277 pub fn receive(&self) -> io::Result<Message> {
278 let mut data_buffer = vec![0; self.buffer_size];
279 let mut data = WSABUF {
280 buf: data_buffer.as_mut_ptr(),
281 len: data_buffer.len() as u32,
282 };
283
284 let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
285 let control = WSABUF {
286 buf: control_buffer.as_mut_ptr(),
287 len: control_buffer.len() as u32,
288 };
289
290 let mut origin_address: SOCKADDR = unsafe { mem::zeroed() };
291 let mut wsa_msg = WSAMSG {
292 name: &mut origin_address,
293 namelen: mem::size_of_val(&origin_address) as i32,
294 lpBuffers: &mut data,
295 Control: control,
296 dwBufferCount: 1,
297 dwFlags: 0,
298 };
299
300 let mut read_bytes = 0;
301 let r = {
302 unsafe {
303 (self.wsarecvmsg)(
304 self.socket.as_raw_socket() as _,
305 &mut wsa_msg,
306 &mut read_bytes,
307 ptr::null_mut(),
308 None,
309 )
310 }
311 };
312
313 if r != 0 {
314 return Err(io::Error::last_os_error());
315 }
316
317 let origin_address = unsafe {
318 socket2::SockAddr::from_raw_parts(
319 &origin_address,
320 mem::size_of_val(&origin_address) as i32,
321 )
322 }
323 .as_std();
324
325 let origin_address = match origin_address {
326 Some(SocketAddr::V4(v4)) => v4,
327 _ => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
328 };
329
330 let mut interface = Interface::Default;
331 if control.len as usize == CONTROL_PKTINFO_BUFFER_SIZE {
333 let cmsg_header: WSACMSGHDR = unsafe { ptr::read_unaligned(control.buf as *const _) }; if cmsg_header.cmsg_level == IPPROTO_IP && cmsg_header.cmsg_type == IP_PKTINFO {
335 let interface_info: IN_PKTINFO =
336 unsafe { ptr::read_unaligned(control.buf.add(CMSG_HEADER_SIZE) as *const _) }; interface = Interface::Index(interface_info.ipi_ifindex);
338 };
339 };
340
341 Ok(Message {
342 data: data_buffer[0..read_bytes as _]
343 .iter()
344 .map(|i| *i as u8)
345 .collect(),
346 origin_address,
347 interface,
348 })
349 }
350
351 pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result<usize> {
352 let pkt_info = match interface {
353 Interface::Default => None,
354 Interface::Ip(address) => Some(IN_PKTINFO {
355 ipi_addr: IN_ADDR {
356 S_un: to_s_addr(address),
357 },
358 ipi_ifindex: 0,
359 }),
360 Interface::Index(index) => self.interfaces.get(index).map(|address| IN_PKTINFO {
361 ipi_addr: IN_ADDR {
362 S_un: to_s_addr(address),
363 },
364 ipi_ifindex: *index,
365 }),
366 };
367
368 let mut data = WSABUF {
369 buf: buf.as_ptr() as *mut _,
370 len: buf.len() as _,
371 };
372
373 let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
374 let control = if let Some(pkt_info) = pkt_info {
375 let hdr = CMSGHDR {
376 cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE,
377 cmsg_level: IPPROTO_IP,
378 cmsg_type: IP_PKTINFO,
379 };
380 unsafe {
381 ptr::copy(
382 &hdr as *const _ as *const _,
383 control_buffer.as_mut_ptr(),
384 CMSG_HEADER_SIZE,
385 );
386 ptr::copy(
387 &pkt_info as *const _ as *const _,
388 control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE),
389 PKTINFO_DATA_SIZE,
390 )
391 };
392 WSABUF {
393 buf: control_buffer.as_mut_ptr(),
394 len: control_buffer.len() as _,
395 }
396 } else {
397 WSABUF {
398 buf: [].as_mut_ptr(),
399 len: 0,
400 }
401 };
402
403 let destination = socket2::SockAddr::from(self.multicast_address);
404 let destination_address = destination.as_ptr();
405 let mut wsa_msg = WSAMSG {
406 name: destination_address as *mut _,
407 namelen: destination.len(),
408 lpBuffers: &mut data,
409 Control: control,
410 dwBufferCount: 1,
411 dwFlags: 0,
412 };
413
414 let mut sent_bytes = 0;
415 let r = unsafe {
416 (self.wsasendmsg)(
417 self.socket.as_raw_socket() as _,
418 &mut wsa_msg,
419 0,
420 &mut sent_bytes,
421 ptr::null_mut(),
422 None,
423 )
424 };
425 if r != 0 {
426 return Err(io::Error::last_os_error());
427 }
428
429 Ok(sent_bytes as _)
430 }
431
432 pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> {
433 for interface in self.interfaces.values() {
434 self.send(buf, &Interface::Ip(*interface))?;
435 }
436 Ok(())
437 }
438}
439
440fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un {
441 let octets = addr.octets();
442 let res = u32::from_ne_bytes(octets);
443 let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() };
444 unsafe { *(new_addr.S_addr_mut()) = res };
445 new_addr
446}