multicast_discovery_socket/
socket.rs1use std::io;
2use std::io::{IoSlice, IoSliceMut, Result};
3use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
4use socket2::{Domain, Protocol, SockAddr, Socket, Type};
5#[cfg(depend_nix)]
6use nix::sys::socket;
7
8pub struct MultiInterfaceSocket {
10 socket: Socket,
11 #[cfg(windows)]
12 wsa_structs: win_helper::WSAStructs
13}
/// Converts a `nix` error into an [`io::Error`], keeping the raw OS error code.
///
/// This uses nix's `From<Errno> for io::Error` conversion (the same one `?` applies
/// to nix results). Because the errno is preserved, `io::Error::kind()` stays
/// meaningful for callers. For example, an `EAGAIN` from a receive that hit the read
/// timeout, or from a non-blocking socket, surfaces as `ErrorKind::WouldBlock`.
/// Wrapping with `io::Error::other` would instead report every failure as
/// `ErrorKind::Other`.
#[cfg(depend_nix)]
fn nix_to_io_error(e: nix::Error) -> io::Error {
    io::Error::from(e)
}
18
#[cfg(windows)]
// Windows-only helper module, loaded from `win_specific.rs` in the same directory as
// this file. It provides `win_init` and the `WSAStructs` type that the Windows
// receive/send implementation delegates to.
#[path = "./win_specific.rs"]
mod win_helper;
22
23impl MultiInterfaceSocket {
24 pub fn bind_any() -> Result<Self> {
25 Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
26 }
27
28 pub fn bind_port(port: u16) -> Result<Self> {
29 Self::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port))
30 }
31 pub fn bind(addr: SocketAddrV4) -> Result<Self> {
32 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
33 socket.bind(&addr.into())?;
34
35 #[cfg(depend_nix)]
36 use std::os::fd::AsFd;
37 #[cfg(depend_nix)]
38 socket::setsockopt(&socket.as_fd(), socket::sockopt::Ipv4PacketInfo, &true)
39 .map_err(nix_to_io_error)?;
40
41 #[cfg(windows)]
42 let wsa_structs = win_helper::win_init(&socket)?;
43
44 Ok(Self {
45 socket,
46 #[cfg(windows)]
47 wsa_structs
48 })
49 }
50
51 pub fn get_bind_addr(&self) -> Result<SocketAddrV4> {
52 let addr = self.socket.local_addr()?;
53 if let Some(addr) = addr.as_socket_ipv4() {
54 Ok(addr)
55 } else {
56 Err(io::Error::other("Not an IPv4 address"))
57 }
58 }
59
60 pub fn join_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
62 self.socket.join_multicast_v4(&addr, &interface)
63 }
64
65 pub fn leave_multicast_group(&self, addr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
67 self.socket.leave_multicast_v4(&addr, &interface)
68 }
69
70 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
72 self.socket.set_nonblocking(nonblocking)
73 }
74 pub fn set_read_timeout(&self, timeout: std::time::Duration) -> Result<()> {
76 self.socket.set_read_timeout(Some(timeout))
77 }
78 pub fn set_read_timeout_inf(&self) -> Result<()> {
80 self.socket.set_read_timeout(None)
81 }
82
83}
84
#[cfg(depend_nix)]
impl MultiInterfaceSocket {
    /// Receives one datagram together with the index of the interface it arrived on.
    ///
    /// Returns:
    /// - the filled prefix of `buf`;
    /// - the sender's address, or `0.0.0.0:0` if none was reported;
    /// - the interface index from the first `IP_PKTINFO` control message, or `0`
    ///   if no such message was delivered.
    ///
    /// # Errors
    /// Propagates `recvmsg` failures and control-message iteration errors.
    pub fn recv_from_iface<'a>(
        &self,
        buf: &'a mut [u8],
    ) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
        use std::os::fd::AsRawFd;

        let fd = self.socket.as_raw_fd();
        // Control-message buffer sized for one `in_pktinfo`.
        let mut cmsg_space = nix::cmsg_space!(nix::libc::in_pktinfo);
        let mut iov = [IoSliceMut::new(buf)];

        let received: socket::RecvMsg<socket::SockaddrIn> =
            socket::recvmsg(fd, &mut iov, Some(&mut cmsg_space), socket::MsgFlags::empty())
                .map_err(nix_to_io_error)?;

        let len = received.bytes;
        let sender = match received.address {
            Some(sin) => SocketAddrV4::new(sin.ip(), sin.port()),
            None => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
        };
        let iface_index = received
            .cmsgs()?
            .find_map(|cmsg| match cmsg {
                socket::ControlMessageOwned::Ipv4PacketInfo(info) => {
                    Some(info.ipi_ifindex as u32)
                }
                _ => None,
            })
            .unwrap_or(0);

        Ok((&mut buf[..len], sender, iface_index))
    }

    /// Sends `buf` to `addr` with an `IP_PKTINFO` control message whose
    /// `ipi_ifindex` is `iface_index`.
    ///
    /// `_source_if_addr` is accepted for parity with the Windows implementation but
    /// is not used here. The address fields of the `in_pktinfo` are left zeroed.
    ///
    /// # Errors
    /// Propagates `sendmsg` failures.
    pub fn send_to_iface(
        &self,
        buf: &[u8],
        addr: SocketAddrV4,
        iface_index: u32,
        _source_if_addr: IpAddr,
    ) -> Result<usize> {
        use std::os::fd::AsRawFd;

        // SAFETY: `in_pktinfo` is a plain C struct of integer fields (`in_addr`
        // included), for which the all-zero bit pattern is a valid value.
        let mut pkt_info: nix::libc::in_pktinfo = unsafe { std::mem::zeroed() };
        pkt_info.ipi_ifindex = iface_index as i32;

        let payload = [IoSlice::new(buf)];
        let control = [socket::ControlMessage::Ipv4PacketInfo(&pkt_info)];
        let dest = socket::SockaddrIn::from(addr);

        socket::sendmsg(
            self.socket.as_raw_fd(),
            &payload,
            &control,
            socket::MsgFlags::empty(),
            Some(&dest),
        )
        .map_err(nix_to_io_error)
    }
}
#[cfg(windows)]
impl MultiInterfaceSocket {
    /// Receives one datagram via the Windows helper state.
    ///
    /// Returns the filled prefix of `buf`, the address and the interface index
    /// reported by `WSAStructs::receive`.
    ///
    /// # Errors
    /// Propagates errors from the helper's receive call.
    pub fn recv_from_iface<'a>(
        &self,
        buf: &'a mut [u8],
    ) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
        let (len, sender, iface_index) = self.wsa_structs.receive(buf, &self.socket)?;
        Ok((&mut buf[..len], sender, iface_index))
    }

    /// Sends `buf` to `addr` via the Windows helper, using interface `iface_index`
    /// and source address `source_if_addr`.
    ///
    /// # Errors
    /// Fails if `source_if_addr` is not IPv4, or if the helper's send call fails.
    pub fn send_to_iface(
        &self,
        buf: &[u8],
        addr: SocketAddrV4,
        iface_index: u32,
        source_if_addr: IpAddr,
    ) -> Result<usize> {
        match source_if_addr {
            IpAddr::V4(source_ip) => {
                self.wsa_structs
                    .send(buf, addr, iface_index, source_ip, &self.socket)
            }
            IpAddr::V6(_) => Err(io::Error::other("Not an IPv4 address")),
        }
    }
}
/// Reinterprets an initialised byte buffer as a `MaybeUninit<u8>` buffer, so it can
/// be passed to `socket2::Socket::recv_from`, which only accepts the latter.
#[cfg(use_fallback_impl)]
fn convert_buf(buf: &mut [u8]) -> &mut [std::mem::MaybeUninit<u8>] {
    let len = buf.len();
    let ptr = buf.as_mut_ptr().cast::<std::mem::MaybeUninit<u8>>();
    // SAFETY:
    // - `MaybeUninit<u8>` has the same size and alignment as `u8`.
    // - The pointer and length come from a live `&mut [u8]`, and its lifetime
    //   carries over to the result.
    // - Callers must never write uninitialised values through the returned slice,
    //   because `buf` is still treated as initialised `u8` afterwards. socket2
    //   documents that `recv_from` only writes initialised bytes.
    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
#[cfg(use_fallback_impl)]
impl MultiInterfaceSocket {
    /// Receives one datagram with a plain `recv_from`, without interface information.
    ///
    /// Returns the filled prefix of `buf`, the sender's IPv4 address, and a
    /// placeholder interface index of `1`, because this implementation cannot
    /// determine the receiving interface.
    ///
    /// # Errors
    /// Propagates socket errors, and fails if the sender address is not IPv4.
    pub fn recv_from_iface<'a>(
        &self,
        buf: &'a mut [u8],
    ) -> Result<(&'a mut [u8], SocketAddrV4, u32)> {
        let (sz, addr) = self.socket.recv_from(convert_buf(buf))?;
        let addr = addr
            .as_socket_ipv4()
            .ok_or_else(|| io::Error::other("Not an IPv4 address"))?;
        // Placeholder index: no per-packet interface info is available on this path.
        Ok((&mut buf[..sz], addr, 1))
    }

    /// Sends `buf` to `addr` with a plain `send_to`.
    ///
    /// `_iface_index` and `_source_if_addr` are accepted for API parity with the
    /// platform-specific implementations but are ignored here. The outgoing
    /// interface is whatever the socket/OS defaults select. The underscore prefixes
    /// silence `unused_variables`; they do not affect callers.
    ///
    /// # Errors
    /// Propagates socket send errors.
    pub fn send_to_iface(
        &self,
        buf: &[u8],
        addr: SocketAddrV4,
        _iface_index: u32,
        _source_if_addr: IpAddr,
    ) -> Result<usize> {
        self.socket.send_to(buf, &SockAddr::from(addr))
    }
}