Skip to main content

oko_multicast_socket/
unix.rs

1use std::collections::HashMap;
2use std::io::IoSliceMut;
3use std::io::{self, IoSlice};
4use std::mem;
5use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
6use std::os::unix::io::AsRawFd;
7
8use network_interface::{NetworkInterface, NetworkInterfaceConfig, V4IfAddr};
9use socket2::{Domain, Protocol, Socket, Type};
10
11#[cfg(feature = "tokio")]
12use tokio::io::Interest;
13
14use nix::sys::socket::{self as sock, RecvMsg};
15
16fn create_on_interfaces(
17    options: crate::MulticastOptions,
18    interfaces: Vec<Ipv4Addr>,
19    multicast_address: SocketAddrV4,
20) -> io::Result<MulticastSocket> {
21    let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
22    socket.set_nonblocking(options.nonblocking)?;
23    if !options.nonblocking {
24        socket.set_read_timeout(options.read_timeout)?;
25    }
26    socket.set_multicast_loop_v4(options.loopback)?;
27    socket.set_reuse_address(true)?;
28    socket.set_reuse_port(true)?;
29
30    // Ipv4PacketInfo translates to `IP_PKTINFO`. Checkout the [ip
31    // manpage](https://man7.org/linux/man-pages/man7/ip.7.html) for more details. In summary
32    // setting this option allows for determining on which interface a packet was received.
33    sock::setsockopt(socket.as_raw_fd(), sock::sockopt::Ipv4PacketInfo, &true)
34        .map_err(nix_to_io_error)?;
35
36    for interface in &interfaces {
37        socket.join_multicast_v4(multicast_address.ip(), &interface)?;
38    }
39
40    socket.bind(&SocketAddr::new(options.bind_address.into(), multicast_address.port()).into())?;
41
42    Ok(MulticastSocket {
43        socket,
44        inner: MulticastSocketInner {
45            interfaces,
46            multicast_address,
47            buffer_size: options.buffer_size,
48        },
49    })
50}
51
/// Socket-independent state shared by the blocking and the async socket wrappers.
struct MulticastSocketInner {
    /// Local IPv4 addresses of the interfaces the group was joined on; the broadcast
    /// methods send one datagram per entry.
    interfaces: Vec<Ipv4Addr>,
    /// Group address and port the socket was created for; the default destination of
    /// `send` and `broadcast`.
    multicast_address: SocketAddrV4,
    /// Number of bytes allocated for the payload buffer of each receive.
    buffer_size: usize,
}
57
58pub struct MulticastSocket {
59    socket: socket2::Socket,
60    inner: MulticastSocketInner,
61}
62
/// Identifies a network interface: used to steer an outgoing datagram, and to report the
/// interface a received [`Message`] arrived on.
///
/// The derives beyond `Debug`/`Clone` let callers compare, hash and freely copy interface
/// identifiers (e.g. to check which interface a message came in on).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Interface {
    /// No interface preference: the kernel chooses when sending. Also used for received
    /// messages that carried no `IP_PKTINFO` information.
    Default,
    /// Select by a local IPv4 address (sent as `ipi_spec_dst` of `IP_PKTINFO`).
    Ip(Ipv4Addr),
    /// Select by kernel interface index (`ipi_ifindex` of `IP_PKTINFO`).
    Index(i32),
}
69
70#[derive(Debug, Clone)]
71pub struct Message {
72    pub data: Vec<u8>,
73    pub origin_address: SocketAddrV4,
74    pub interface: Interface,
75}
76
77pub fn all_ipv4_interfaces() -> io::Result<Vec<Ipv4Addr>> {
78    let interfaces =
79        NetworkInterface::show().map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
80    // We have to filter the same interface if it has multiple ips
81    // https://stackoverflow.com/questions/49819010/ip-add-membership-fails-when-set-both-on-interface-and-its-subinterface-is-that
82    let mut collected_interfaces = HashMap::with_capacity(interfaces.len());
83    for iface in interfaces.iter() {
84        for (name, ipv4addr) in iface.addr.iter().filter_map(|addr| {
85            if let network_interface::Addr::V4(V4IfAddr { ip, .. }) = addr {
86                if !ip.is_loopback() {
87                    Some((iface.name.clone(), ip.clone()))
88                } else {
89                    None
90                }
91            } else {
92                None
93            }
94        }) {
95            if !collected_interfaces.contains_key(&name) {
96                collected_interfaces.insert(name, ipv4addr);
97            }
98        }
99    }
100    Ok(collected_interfaces.into_iter().map(|(_, ip)| ip).collect())
101}
102
103impl MulticastSocket {
104    pub fn all_interfaces(multicast_address: SocketAddrV4) -> io::Result<Self> {
105        let interfaces = all_ipv4_interfaces()?;
106        create_on_interfaces(Default::default(), interfaces, multicast_address)
107    }
108
109    pub fn with_options(
110        multicast_address: SocketAddrV4,
111        interfaces: Vec<Ipv4Addr>,
112        options: crate::MulticastOptions,
113    ) -> io::Result<Self> {
114        create_on_interfaces(options, interfaces, multicast_address)
115    }
116}
117
118fn nix_to_io_error(e: nix::Error) -> io::Error {
119    match e {
120        nix::errno::Errno::EAGAIN => io::ErrorKind::WouldBlock.into(),
121        _ => io::Error::new(io::ErrorKind::Other, e),
122    }
123}
124
125impl MulticastSocket {
126    pub fn receive(&self) -> io::Result<Message> {
127        let mut data_buffer = vec![0; self.inner.buffer_size];
128        let mut control_buffer = nix::cmsg_space!(libc::in_pktinfo);
129        let io_slice = &mut [IoSliceMut::new(&mut data_buffer)];
130
131        let message: RecvMsg<sock::SockaddrIn> = sock::recvmsg(
132            self.socket.as_raw_fd(),
133            io_slice,
134            Some(&mut control_buffer),
135            sock::MsgFlags::empty(),
136        )
137        .map_err(nix_to_io_error)?;
138
139        let origin_address = match message.address {
140            Some(sockaddr) => SocketAddrV4::new(
141                Ipv4Addr::from(sockaddr.ip().to_le()),
142                sockaddr.port().to_le(),
143            ),
144            _ => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
145        };
146
147        let mut interface = Interface::Default;
148
149        for cmsg in message.cmsgs() {
150            if let sock::ControlMessageOwned::Ipv4PacketInfo(pktinfo) = cmsg {
151                interface = Interface::Index(pktinfo.ipi_ifindex as _);
152            }
153        }
154
155        // Weird borrow interaction here because of the mutable borrow that
156        // goes in to the IoSlice, so it's time to bust out the good old fashioned
157        // for loop.
158        let mut data = Vec::with_capacity(message.bytes);
159        for i in 0..message.bytes {
160            data.push(data_buffer[i]);
161        }
162        Ok(Message {
163            data,
164            origin_address,
165            interface,
166        })
167    }
168
169    pub fn send_to(
170        &self,
171        buf: &[u8],
172        interface: &Interface,
173        addr: SocketAddrV4,
174    ) -> io::Result<usize> {
175        let mut pkt_info: libc::in_pktinfo = unsafe { mem::zeroed() };
176
177        match interface {
178            Interface::Default => {}
179            Interface::Ip(address) => {
180                pkt_info.ipi_spec_dst = libc::in_addr {
181                    s_addr: u32::from_ne_bytes(address.octets()),
182                }
183            }
184            Interface::Index(index) => pkt_info.ipi_ifindex = *index as _,
185        };
186
187        sock::sendmsg(
188            self.socket.as_raw_fd(),
189            &[IoSlice::new(&buf)],
190            &[sock::ControlMessage::Ipv4PacketInfo(&pkt_info)],
191            sock::MsgFlags::empty(),
192            Some(&sock::SockaddrIn::from(SocketAddrV4::from(addr))),
193        )
194        .map_err(nix_to_io_error)
195    }
196
197    pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result<usize> {
198        self.send_to(buf, interface, self.inner.multicast_address)
199    }
200
201    pub fn broadcast_to(&self, buf: &[u8], addr: SocketAddrV4) -> io::Result<()> {
202        for interface in &self.inner.interfaces {
203            self.send_to(buf, &Interface::Ip(*interface), addr)?;
204        }
205        Ok(())
206    }
207
208    pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> {
209        self.broadcast_to(buf, self.inner.multicast_address)
210    }
211}
212
/// Tokio-driven counterpart of [`MulticastSocket`], obtained through
/// `AsyncMulticastSocket::try_from(multicast_socket)`.
#[cfg(feature = "tokio")]
pub struct AsyncMulticastSocket {
    /// The same UDP socket, registered with the tokio reactor.
    socket: tokio::net::UdpSocket,
    /// Interfaces, group address and buffer size carried over from the blocking socket.
    inner: MulticastSocketInner,
}
218
/// Converts a [`MulticastSocket`] into an [`AsyncMulticastSocket`] with an `async` API.
///
/// The socket is switched to non-blocking mode and handed to tokio. No new `bind` call is
/// made: the socket keeps the bind address, group memberships and options it was created
/// with.
///
/// # Errors
///
/// Fails if the socket cannot be made non-blocking or cannot be registered with tokio.
///
/// # Panics
///
/// [`tokio::net::UdpSocket::from_std`] panics when called outside a tokio runtime with the
/// I/O driver enabled, so the conversion must run inside such a runtime.
#[cfg(feature = "tokio")]
impl TryFrom<MulticastSocket> for AsyncMulticastSocket {
    type Error = io::Error;

    fn try_from(other: MulticastSocket) -> Result<Self, Self::Error> {
        // tokio does not change the blocking mode of a std socket it adopts; it must
        // already be non-blocking or reads would stall the executor.
        other.socket.set_nonblocking(true)?;
        let socket = tokio::net::UdpSocket::from_std(other.socket.into())?;
        Ok(Self {
            socket,
            inner: other.inner,
        })
    }
}
235
#[cfg(feature = "tokio")]
impl AsyncMulticastSocket {
    /// Asynchronously receives a single datagram; the async counterpart of
    /// `MulticastSocket::receive`.
    ///
    /// The returned [`Message`] holds:
    /// - at most `buffer_size` payload bytes;
    /// - the sender's address, or `0.0.0.0:0` if none was reported;
    /// - the receiving interface's index when `IP_PKTINFO` data accompanies the datagram,
    ///   otherwise [`Interface::Default`].
    ///
    /// # Errors
    ///
    /// Propagates any `recvmsg` failure other than "would block", on which tokio waits for
    /// readiness and retries.
    pub async fn receive(&self) -> io::Result<Message> {
        let mut data_buffer = vec![0; self.inner.buffer_size];

        // tokio has no async wrapper for `recvmsg`, and the interface information arrives
        // as ancillary data only `recvmsg` delivers, so readiness is driven manually: the
        // closure is retried whenever it reports `WouldBlock`.
        let (received, origin_address, interface) = self
            .socket
            .async_io(Interest::READABLE, || {
                let io_slice = &mut [IoSliceMut::new(&mut data_buffer)];
                let mut control_buffer = nix::cmsg_space!(libc::in_pktinfo);
                let message: RecvMsg<sock::SockaddrIn> = sock::recvmsg(
                    self.socket.as_raw_fd(),
                    io_slice,
                    Some(&mut control_buffer),
                    sock::MsgFlags::empty(),
                )
                .map_err(nix_to_io_error)?;

                // `SockaddrIn::ip()`/`port()` already return host byte order; converting
                // again (e.g. `.to_le()`) would byte-swap them on big-endian hosts.
                let origin_address = match message.address {
                    Some(sockaddr) => {
                        SocketAddrV4::new(Ipv4Addr::from(sockaddr.ip()), sockaddr.port())
                    }
                    None => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
                };

                // The last `IP_PKTINFO` control message wins.
                let mut interface = Interface::Default;
                for cmsg in message.cmsgs() {
                    if let sock::ControlMessageOwned::Ipv4PacketInfo(pktinfo) = cmsg {
                        interface = Interface::Index(pktinfo.ipi_ifindex as _);
                    }
                }

                Ok((message.bytes, origin_address, interface))
            })
            .await?;

        // The closure's borrow of `data_buffer` ended with the `.await`, so the buffer can
        // be trimmed in place and moved into the message without copying the payload.
        data_buffer.truncate(received);

        Ok(Message {
            data: data_buffer,
            origin_address,
            interface,
        })
    }

    /// Asynchronously sends `buf` as one datagram to `addr`, steering it through
    /// `interface` via an `IP_PKTINFO` control message.
    ///
    /// - [`Interface::Ip`] sets the local source address (`ipi_spec_dst`);
    /// - [`Interface::Index`] sets the outgoing interface index (`ipi_ifindex`);
    /// - [`Interface::Default`] leaves both zeroed.
    ///
    /// Returns the number of bytes sent.
    pub async fn send_to(
        &self,
        buf: &[u8],
        interface: &Interface,
        addr: SocketAddrV4,
    ) -> io::Result<usize> {
        // SAFETY: `in_pktinfo` is a plain C struct of integer fields, for which the all-zero
        // bit pattern is a valid value (meaning "not specified" for each field).
        let mut pkt_info: libc::in_pktinfo = unsafe { mem::zeroed() };

        match interface {
            Interface::Default => {}
            Interface::Ip(address) => {
                // `octets()` is already in network byte order; keep the bytes as they are.
                pkt_info.ipi_spec_dst = libc::in_addr {
                    s_addr: u32::from_ne_bytes(address.octets()),
                }
            }
            Interface::Index(index) => pkt_info.ipi_ifindex = *index as _,
        }

        // tokio has no async wrapper for `sendmsg` with ancillary data, so readiness is
        // driven manually: the closure is retried whenever it reports `WouldBlock`.
        self.socket
            .async_io(Interest::WRITABLE, || {
                sock::sendmsg(
                    self.socket.as_raw_fd(),
                    &[IoSlice::new(buf)],
                    &[sock::ControlMessage::Ipv4PacketInfo(&pkt_info)],
                    sock::MsgFlags::empty(),
                    Some(&sock::SockaddrIn::from(addr)),
                )
                .map_err(nix_to_io_error)
            })
            .await
    }

    /// Asynchronously sends `buf` to the socket's multicast group through `interface`.
    pub async fn send(&self, buf: &[u8], interface: &Interface) -> io::Result<usize> {
        self.send_to(buf, interface, self.inner.multicast_address)
            .await
    }

    /// Asynchronously sends `buf` to `addr` once through each joined interface, selecting
    /// each by its local address.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first send error.
    pub async fn broadcast_to(&self, buf: &[u8], addr: SocketAddrV4) -> io::Result<()> {
        for interface in &self.inner.interfaces {
            self.send_to(buf, &Interface::Ip(*interface), addr).await?;
        }
        Ok(())
    }

    /// Asynchronously sends `buf` to the socket's multicast group on every joined interface.
    pub async fn broadcast(&self, buf: &[u8]) -> io::Result<()> {
        self.broadcast_to(buf, self.inner.multicast_address).await
    }
}