multicast_socket/
unix.rs

1use std::collections::HashMap;
2use std::io;
3use std::mem;
4use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
5use std::os::unix::io::AsRawFd;
6
7use socket2::{Domain, Protocol, Socket, Type};
8
9use nix::sys::socket as sock;
10use nix::sys::uio::IoVec;
11
12fn create_on_interfaces(
13    options: crate::MulticastOptions,
14    interfaces: Vec<Ipv4Addr>,
15    multicast_address: SocketAddrV4,
16) -> io::Result<MulticastSocket> {
17    let socket = Socket::new(Domain::ipv4(), Type::dgram(), Some(Protocol::udp()))?;
18    socket.set_read_timeout(options.read_timeout)?;
19    socket.set_multicast_loop_v4(options.loopback)?;
20    socket.set_reuse_address(true)?;
21    socket.set_reuse_port(true)?;
22
23    // Ipv4PacketInfo translates to `IP_PKTINFO`. Checkout the [ip
24    // manpage](https://man7.org/linux/man-pages/man7/ip.7.html) for more details. In summary
25    // setting this option allows for determining on which interface a packet was received.
26    sock::setsockopt(socket.as_raw_fd(), sock::sockopt::Ipv4PacketInfo, &true)
27        .map_err(nix_to_io_error)?;
28
29    for interface in &interfaces {
30        socket.join_multicast_v4(multicast_address.ip(), &interface)?;
31    }
32
33    socket.bind(&SocketAddr::new(options.bind_address.into(), multicast_address.port()).into())?;
34
35    Ok(MulticastSocket {
36        socket,
37        interfaces,
38        multicast_address,
39        buffer_size: options.buffer_size,
40    })
41}
42
43pub struct MulticastSocket {
44    socket: socket2::Socket,
45    interfaces: Vec<Ipv4Addr>,
46    multicast_address: SocketAddrV4,
47    buffer_size: usize,
48}
49
/// Selects the network interface used for sending, or reports the interface a
/// received message arrived on.
///
/// All payloads are small `Copy` values, so the enum is `Copy` too. It also
/// derives `PartialEq`/`Eq`/`Hash`, so values can be compared (for example
/// against `Message::interface`) and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Interface {
    /// No specific interface. When sending, nothing is requested. A received
    /// message reports this when no interface information was delivered.
    Default,
    /// The interface owning the given local IPv4 address.
    Ip(Ipv4Addr),
    /// The interface with the given OS interface index.
    Index(i32),
}
56
57#[derive(Debug, Clone)]
58pub struct Message {
59    pub data: Vec<u8>,
60    pub origin_address: SocketAddrV4,
61    pub interface: Interface,
62}
63
64pub fn all_ipv4_interfaces() -> io::Result<Vec<Ipv4Addr>> {
65    let interfaces = if_addrs::get_if_addrs()?.into_iter();
66
67    // We have to filter the same interface if it has multiple ips
68    // https://stackoverflow.com/questions/49819010/ip-add-membership-fails-when-set-both-on-interface-and-its-subinterface-is-that
69    let mut collected_interfaces = HashMap::with_capacity(interfaces.len());
70    for interface in interfaces {
71        if !collected_interfaces.contains_key(&interface.name) {
72            match interface.ip() {
73                std::net::IpAddr::V4(v4) if !interface.is_loopback() => {
74                    collected_interfaces.insert(interface.name, v4);
75                }
76                _ => {}
77            }
78        }
79    }
80    Ok(collected_interfaces.into_iter().map(|(_, ip)| ip).collect())
81}
82
83impl MulticastSocket {
84    pub fn all_interfaces(multicast_address: SocketAddrV4) -> io::Result<Self> {
85        let interfaces = all_ipv4_interfaces()?;
86        create_on_interfaces(Default::default(), interfaces, multicast_address)
87    }
88
89    pub fn with_options(
90        multicast_address: SocketAddrV4,
91        interfaces: Vec<Ipv4Addr>,
92        options: crate::MulticastOptions,
93    ) -> io::Result<Self> {
94        create_on_interfaces(options, interfaces, multicast_address)
95    }
96}
97
98fn nix_to_io_error(e: nix::Error) -> io::Error {
99    io::Error::new(io::ErrorKind::Other, e)
100}
101
102impl MulticastSocket {
103    pub fn receive(&self) -> io::Result<Message> {
104        let mut data_buffer = vec![0; self.buffer_size];
105        let mut control_buffer = nix::cmsg_space!(libc::in_pktinfo);
106
107        let message = sock::recvmsg(
108            self.socket.as_raw_fd(),
109            &[IoVec::from_mut_slice(&mut data_buffer)],
110            Some(&mut control_buffer),
111            sock::MsgFlags::empty(),
112        )
113        .map_err(nix_to_io_error)?;
114
115        let origin_address = match message.address {
116            Some(sock::SockAddr::Inet(v4)) => Some(v4.to_std()),
117            _ => None,
118        };
119        let origin_address = match origin_address {
120            Some(SocketAddr::V4(v4)) => v4,
121            _ => SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0),
122        };
123
124        let mut interface = Interface::Default;
125
126        for cmsg in message.cmsgs() {
127            if let sock::ControlMessageOwned::Ipv4PacketInfo(pktinfo) = cmsg {
128                interface = Interface::Index(pktinfo.ipi_ifindex as _);
129            }
130        }
131
132        Ok(Message {
133            data: data_buffer[0..message.bytes].to_vec(),
134            origin_address,
135            interface,
136        })
137    }
138
139    pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result<usize> {
140        let mut pkt_info: libc::in_pktinfo = unsafe { mem::zeroed() };
141
142        match interface {
143            Interface::Default => {}
144            Interface::Ip(address) => pkt_info.ipi_spec_dst = sock::Ipv4Addr::from_std(address).0,
145            Interface::Index(index) => pkt_info.ipi_ifindex = *index as _,
146        };
147
148        let destination = sock::InetAddr::from_std(&self.multicast_address.into());
149
150        sock::sendmsg(
151            self.socket.as_raw_fd(),
152            &[IoVec::from_slice(&buf)],
153            &[sock::ControlMessage::Ipv4PacketInfo(&pkt_info)],
154            sock::MsgFlags::empty(),
155            Some(&sock::SockAddr::new_inet(destination)),
156        )
157        .map_err(nix_to_io_error)
158    }
159
160    pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> {
161        for interface in &self.interfaces {
162            self.send(buf, &Interface::Ip(*interface))?;
163        }
164        Ok(())
165    }
166}