1use std::io;
2use std::io::{IoSlice, IoSliceMut, Result};
3use std::net::{IpAddr, Ipv4Addr, SocketAddrV4};
4use socket2::{Domain, Protocol, Socket, Type};
5#[cfg(unix)]
6use nix::sys::socket;
7
8pub struct MultiInterfaceSocket {
10 socket: Socket,
11 #[cfg(windows)]
12 wsa_structs: win_helper::WSAStructs
13}
14#[cfg(unix)]
15fn nix_to_io_error(e: nix::Error) -> io::Error {
16 io::Error::other(e)
17}
18
#[cfg(windows)]
mod win_helper {
    use std::ffi::{c_char, c_int};
    use std::net::{Ipv4Addr, SocketAddrV4};
    use std::os::windows::io::RawSocket;
    use std::os::windows::prelude::AsRawSocket;
    use std::{io, mem, ptr};

    use socket2::Socket;
    use winapi::shared::guiddef::GUID;
    use winapi::shared::inaddr::*;
    use winapi::shared::minwindef::DWORD;
    use winapi::shared::minwindef::{INT, LPDWORD};
    use winapi::shared::ws2def::LPWSAMSG;
    use winapi::shared::ws2def::*;
    use winapi::shared::ws2ipdef::*;
    use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG};
    use winapi::um::winsock2;
    use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET};

    /// Size of the control-message header placed before the `IP_PKTINFO` payload.
    const CMSG_HEADER_SIZE: usize = mem::size_of::<WSACMSGHDR>();
    /// Size of the `IP_PKTINFO` payload itself.
    const PKTINFO_DATA_SIZE: usize = mem::size_of::<IN_PKTINFO>();
    /// Control buffer laid out as one header immediately followed by one `IN_PKTINFO`.
    const CONTROL_PKTINFO_BUFFER_SIZE: usize = CMSG_HEADER_SIZE + PKTINFO_DATA_SIZE;

    /// Returns the calling thread's last Winsock error (`WSAGetLastError`) as an `io::Error`.
    fn last_error() -> io::Error {
        // SAFETY: WSAGetLastError has no preconditions.
        io::Error::from_raw_os_error(unsafe { winsock2::WSAGetLastError() })
    }

    /// Sets socket option `val` at level `opt` to the bytes of `payload`.
    ///
    /// # Safety
    /// `socket` must be a valid socket handle, and `T` must be the value type
    /// Winsock expects for that option.
    unsafe fn setsockopt<T>(socket: RawSocket, opt: c_int, val: c_int, payload: T) -> io::Result<()>
    where
        T: Copy,
    {
        let payload = &payload as *const T as *const c_char;
        if winsock2::setsockopt(socket as _, opt, val, payload, mem::size_of::<T>() as c_int) == 0 {
            Ok(())
        } else {
            Err(last_error())
        }
    }

    type WSARecvMsgExtension = unsafe extern "system" fn(
        s: SOCKET,
        lpMsg: LPWSAMSG,
        lpdwNumberOfBytesRecvd: LPDWORD,
        lpOverlapped: LPWSAOVERLAPPED,
        lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
    ) -> INT;

    type WSASendMsgExtension = unsafe extern "system" fn(
        s: SOCKET,
        lpMsg: LPWSAMSG,
        dwFlags: DWORD,
        lpNumberOfBytesSent: LPDWORD,
        lpOverlapped: LPWSAOVERLAPPED,
        lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
    ) -> INT;

    /// Asks Winsock (`SIO_GET_EXTENSION_FUNCTION_POINTER`) for the address of the
    /// extension function identified by `guid`.
    ///
    /// Returns the raw address and the number of bytes Winsock reported writing.
    fn get_fn_pointer(socket: RawSocket, guid: GUID) -> io::Result<(usize, DWORD)> {
        let mut fn_pointer: usize = 0;
        let mut byte_len: DWORD = 0;

        // SAFETY: the call is synchronous (no OVERLAPPED / completion routine),
        // `guid` and `fn_pointer` outlive it, and each size matches its buffer.
        let r = unsafe {
            winsock2::WSAIoctl(
                socket as _,
                SIO_GET_EXTENSION_FUNCTION_POINTER,
                &guid as *const GUID as *mut _,
                mem::size_of::<GUID>() as DWORD,
                &mut fn_pointer as *mut usize as *mut _,
                mem::size_of::<usize>() as DWORD,
                &mut byte_len,
                ptr::null_mut(),
                None,
            )
        };
        if r != 0 {
            return Err(last_error());
        }
        Ok((fn_pointer, byte_len))
    }

    /// Resolves the `WSARecvMsg` extension function for `socket`.
    fn locate_wsarecvmsg(socket: RawSocket) -> io::Result<WSARecvMsgExtension> {
        let (fn_pointer, byte_len) = get_fn_pointer(socket, WSAID_WSARECVMSG)?;

        if mem::size_of::<LPFN_WSARECVMSG>() != byte_len as usize {
            return Err(io::Error::other("Locating fn pointer to WSARecvMsg returned different expected bytes"));
        }
        // SAFETY: Winsock wrote the address of WSARecvMsg, whose signature is
        // `LPFN_WSARECVMSG`; the `Option` maps a null address to `None`.
        let cast_to_fn: LPFN_WSARECVMSG = unsafe { mem::transmute(fn_pointer) };

        cast_to_fn.ok_or_else(|| io::Error::other("WSARecvMsg extension not found"))
    }

    /// Resolves the `WSASendMsg` extension function for `socket`.
    fn locate_wsasendmsg(socket: RawSocket) -> io::Result<WSASendMsgExtension> {
        let (fn_pointer, byte_len) = get_fn_pointer(socket, WSAID_WSASENDMSG)?;

        if mem::size_of::<LPFN_WSASENDMSG>() != byte_len as usize {
            return Err(io::Error::other("Locating fn pointer to WSASendMsg returned different expected bytes"));
        }
        // SAFETY: Winsock wrote the address of WSASendMsg, whose signature is
        // `LPFN_WSASENDMSG`; the `Option` maps a null address to `None`.
        let cast_to_fn: LPFN_WSASENDMSG = unsafe { mem::transmute(fn_pointer) };

        cast_to_fn.ok_or_else(|| io::Error::other("WSASendMsg extension not found"))
    }

    /// Winsock extension entry points resolved once per socket by [`win_init`].
    pub struct WSAStructs {
        wsarecvmsg: WSARecvMsgExtension,
        wsasendmsg: WSASendMsgExtension,
    }

    /// Enables or disables `IP_PKTINFO`, which makes received datagrams carry an
    /// `IN_PKTINFO` control message (parsed in [`WSAStructs::receive`]).
    fn set_pktinfo(socket: RawSocket, payload: bool) -> io::Result<()> {
        // SAFETY: `socket` comes from a live `Socket`, and IP_PKTINFO takes an int flag.
        unsafe { setsockopt(socket, IPPROTO_IP, IP_PKTINFO, payload as c_int) }
    }

    /// Converts `addr` to the `in_addr` union, keeping the octets in network byte order.
    fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un {
        // from_ne_bytes keeps the in-memory byte order equal to the octet order.
        let res = u32::from_ne_bytes(addr.octets());
        // SAFETY: in_addr_S_un is a plain C union; all-zero is a valid value.
        let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() };
        // SAFETY: writing the S_addr view of a union we own.
        unsafe { *(new_addr.S_addr_mut()) = res };
        new_addr
    }

    /// Prepares `socket` for per-interface I/O.
    ///
    /// Enables `IP_PKTINFO` and resolves the `WSARecvMsg`/`WSASendMsg` extensions.
    pub fn win_init(socket: &Socket) -> io::Result<WSAStructs> {
        set_pktinfo(socket.as_raw_socket(), true)?;
        let wsarecvmsg = locate_wsarecvmsg(socket.as_raw_socket())?;
        let wsasendmsg = locate_wsasendmsg(socket.as_raw_socket())?;

        Ok(WSAStructs { wsarecvmsg, wsasendmsg })
    }

    impl WSAStructs {
        /// Receives one datagram into `data_buffer`.
        ///
        /// Returns `(bytes_read, sender, interface_index)`.
        /// - The sender is `0.0.0.0:0` if it was not an IPv4 address.
        /// - The interface index is 0 if no `IP_PKTINFO` control message was delivered.
        pub fn receive(&self, data_buffer: &mut [u8], socket: &Socket) -> io::Result<(usize, SocketAddrV4, u32)> {
            let mut data = WSABUF {
                buf: data_buffer.as_mut_ptr() as *mut c_char,
                // Clamp instead of silently truncating buffers larger than u32::MAX.
                len: u32::try_from(data_buffer.len()).unwrap_or(u32::MAX),
            };

            let mut control_buffer = [0 as c_char; CONTROL_PKTINFO_BUFFER_SIZE];
            // SAFETY: SOCKADDR is a plain C struct; all-zero is a valid value.
            let mut origin_address: SOCKADDR = unsafe { mem::zeroed() };
            let mut wsa_msg = WSAMSG {
                name: &mut origin_address,
                namelen: mem::size_of::<SOCKADDR>() as INT,
                lpBuffers: &mut data,
                dwBufferCount: 1,
                Control: WSABUF {
                    buf: control_buffer.as_mut_ptr(),
                    len: CONTROL_PKTINFO_BUFFER_SIZE as u32,
                },
                dwFlags: 0,
            };

            let mut read_bytes: DWORD = 0;
            // SAFETY: every pointer in `wsa_msg` refers to locals or `data_buffer`,
            // all of which outlive this synchronous call (no OVERLAPPED).
            let r = unsafe {
                (self.wsarecvmsg)(
                    socket.as_raw_socket() as _,
                    &mut wsa_msg,
                    &mut read_bytes,
                    ptr::null_mut(),
                    None,
                )
            };
            if r != 0 {
                return Err(last_error());
            }

            let origin_address = if origin_address.sa_family != AF_INET as ADDRESS_FAMILY {
                SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)
            } else {
                // sockaddr_in layout inside sa_data: [port (BE, 2 bytes), IPv4 (4 bytes), ...]
                let sa_data = origin_address.sa_data;
                let port = u16::from_be_bytes([sa_data[0] as u8, sa_data[1] as u8]);
                let ip = Ipv4Addr::new(
                    sa_data[2] as u8,
                    sa_data[3] as u8,
                    sa_data[4] as u8,
                    sa_data[5] as u8,
                );
                SocketAddrV4::new(ip, port)
            };

            // WSARecvMsg stores the number of control bytes it actually wrote in
            // `wsa_msg.Control.len`. Any separate copy of the WSABUF would still
            // hold the original capacity.
            let control_len = wsa_msg.Control.len as usize;
            let mut index = 0;
            if control_len >= CONTROL_PKTINFO_BUFFER_SIZE {
                let base = control_buffer.as_ptr();
                // SAFETY: the buffer has CONTROL_PKTINFO_BUFFER_SIZE initialized bytes.
                // read_unaligned is used because a c_char array has no alignment guarantee.
                let cmsg_header: WSACMSGHDR = unsafe { ptr::read_unaligned(base as *const WSACMSGHDR) };
                if cmsg_header.cmsg_level == IPPROTO_IP && cmsg_header.cmsg_type == IP_PKTINFO {
                    // SAFETY: the payload occupies the bytes right after the header, in bounds.
                    let interface_info: IN_PKTINFO =
                        unsafe { ptr::read_unaligned(base.add(CMSG_HEADER_SIZE) as *const IN_PKTINFO) };
                    index = interface_info.ipi_ifindex;
                }
            }

            Ok((read_bytes as usize, origin_address, index))
        }

        /// Sends `buf` to `dst_addr` with an `IP_PKTINFO` control message.
        ///
        /// The control message carries `iface_index` and `source_if_addr`.
        /// Returns the number of bytes sent.
        pub fn send(&self, buf: &[u8], dst_addr: SocketAddrV4, iface_index: u32, source_if_addr: Ipv4Addr, socket: &Socket) -> io::Result<usize> {
            let pkt_info = IN_PKTINFO {
                ipi_addr: IN_ADDR {
                    S_un: to_s_addr(&source_if_addr),
                },
                ipi_ifindex: iface_index,
            };

            let mut data = WSABUF {
                // WSASendMsg only reads the payload; `*mut` is required by the struct type.
                buf: buf.as_ptr() as *mut _,
                len: u32::try_from(buf.len()).unwrap_or(u32::MAX),
            };

            let hdr = CMSGHDR {
                cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE,
                cmsg_level: IPPROTO_IP,
                cmsg_type: IP_PKTINFO,
            };
            let mut control_buffer = [0 as c_char; CONTROL_PKTINFO_BUFFER_SIZE];
            // SAFETY: the header fills bytes [0, CMSG_HEADER_SIZE) and the payload fills
            // [CMSG_HEADER_SIZE, CONTROL_PKTINFO_BUFFER_SIZE); sources and destination
            // are distinct locals.
            unsafe {
                ptr::copy_nonoverlapping(
                    &hdr as *const CMSGHDR as *const c_char,
                    control_buffer.as_mut_ptr(),
                    CMSG_HEADER_SIZE,
                );
                ptr::copy_nonoverlapping(
                    &pkt_info as *const IN_PKTINFO as *const c_char,
                    control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE),
                    PKTINFO_DATA_SIZE,
                );
            }

            let destination = socket2::SockAddr::from(dst_addr);
            let mut wsa_msg = WSAMSG {
                // WSASendMsg only reads the destination address.
                name: destination.as_ptr() as *mut _,
                namelen: destination.len(),
                lpBuffers: &mut data,
                dwBufferCount: 1,
                Control: WSABUF {
                    buf: control_buffer.as_mut_ptr(),
                    len: CONTROL_PKTINFO_BUFFER_SIZE as u32,
                },
                dwFlags: 0,
            };

            let mut sent_bytes: DWORD = 0;
            // SAFETY: every pointer in `wsa_msg` refers to locals or `buf`, all of which
            // outlive this synchronous call (no OVERLAPPED).
            let r = unsafe {
                (self.wsasendmsg)(
                    socket.as_raw_socket() as _,
                    &mut wsa_msg,
                    0,
                    &mut sent_bytes,
                    ptr::null_mut(),
                    None,
                )
            };
            if r != 0 {
                return Err(last_error());
            }

            Ok(sent_bytes as usize)
        }
    }
}
303
304impl MultiInterfaceSocket {
305 pub fn bind_any() -> Result<Self> {
306 Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
307 }
308
309 pub fn bind_port(port: u16) -> Result<Self> {
310 Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port))
311 }
312 pub fn bind(addr: SocketAddrV4) -> Result<Self> {
313 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
314 socket.bind(&addr.into())?;
315
316 #[cfg(unix)]
317 use std::os::fd::AsFd;
318 #[cfg(unix)]
319 socket::setsockopt(&socket.as_fd(), socket::sockopt::Ipv4PacketInfo, &true)
320 .map_err(nix_to_io_error)?;
321
322 #[cfg(windows)]
323 let wsa_structs = win_helper::win_init(&socket)?;
324
325 Ok(Self {
326 socket,
327 #[cfg(windows)]
328 wsa_structs
329 })
330 }
331
332 pub fn get_bind_addr(&self) -> Result<SocketAddrV4> {
333 let addr = self.socket.local_addr()?;
334 if let Some(addr) = addr.as_socket_ipv4() {
335 Ok(addr)
336 } else {
337 Err(io::Error::other("Not an IPv4 address"))
338 }
339 }
340
341 pub fn join_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
343 self.socket.join_multicast_v4(&addr, &interface)
344 }
345
346 pub fn leave_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
348 self.socket.leave_multicast_v4(&addr, &interface)
349 }
350
351 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
353 self.socket.set_nonblocking(nonblocking)
354 }
355 pub fn set_read_timeout(&self, timeout: std::time::Duration) -> Result<()> {
357 self.socket.set_read_timeout(Some(timeout))
358 }
359 pub fn set_read_timeout_inf(&self) -> Result<()> {
361 self.socket.set_read_timeout(None)
362 }
363
364 #[cfg(unix)]
366 pub fn recv_from_iface<'a>(&self, buf: &'a mut [u8]) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
367 use std::os::fd::AsRawFd;
368
369 let mut control_buffer = nix::cmsg_space!(nix::libc::in_pktinfo);
370 let mut bufs = [IoSliceMut::new(buf)];
371 let message: socket::RecvMsg<socket::SockaddrIn> = socket::recvmsg(
372 self.socket.as_raw_fd(),
373 &mut bufs,
374 Some(&mut control_buffer),
375 socket::MsgFlags::empty(),
376 )
377 .map_err(nix_to_io_error)?;
378
379 let dst_addr = message.address.map(|a| SocketAddrV4::new(a.ip(), a.port()))
380 .unwrap_or(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
381 let sz = message.bytes;
382
383 let mut index = 0;
384 for cmsg in message.cmsgs()? {
385 if let socket::ControlMessageOwned::Ipv4PacketInfo(pkt_info) = cmsg {
386 index = pkt_info.ipi_ifindex as u32;
387 break;
388 }
389 }
390 Ok((&mut buf[..sz], dst_addr, index))
391 }
392
393 #[cfg(windows)]
395 pub fn recv_from_iface<'a>(&self, buf: &'a mut [u8]) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
396 let (sz, addr, iface) = self.wsa_structs.receive(buf, &self.socket)?;
397 Ok((&mut buf[..sz], addr, iface))
398 }
399 #[cfg(unix)]
400 pub fn send_to_iface(&self, buf: &[u8], addr: SocketAddrV4, iface_index: u32, _source_if_addr: IpAddr) -> Result<usize> {
401 use std::os::fd::AsRawFd;
402
403 let bufs = [IoSlice::new(buf)];
404 let mut pkt_info: nix::libc::in_pktinfo = unsafe { std::mem::zeroed() };
405 pkt_info.ipi_ifindex = iface_index as i32;
406
407 socket::sendmsg(
408 self.socket.as_raw_fd(),
409 &bufs,
410 &[socket::ControlMessage::Ipv4PacketInfo(&pkt_info)],
411 socket::MsgFlags::empty(),
412 Some(&socket::SockaddrIn::from(addr)),
413 )
414 .map_err(nix_to_io_error)
415 }
416
417 #[cfg(windows)]
418 pub fn send_to_iface(&self, buf: &[u8], addr: SocketAddrV4, iface_index: u32, source_if_addr: IpAddr) -> Result<usize> {
419 if let IpAddr::V4(source_ip_addr) = source_if_addr {
420 self.wsa_structs.send(buf, addr, iface_index, source_ip_addr, &self.socket)
421 }
422 else {
423 Err(io::Error::other("Not an IPv4 address"))
424 }
425 }
426}