multicast_discovery_socket/
lib.rs

1#![allow(clippy::int_plus_one)]
2#![allow(clippy::new_without_default)]
3
4use std::{io, thread};
5use std::borrow::Cow;
6use std::iter::once;
7use std::net::{IpAddr, SocketAddrV4};
8use std::ops::Deref;
9use std::time::{Duration, Instant};
10use if_addrs::Interface;
11use log::{debug, error, info, warn};
12use crate::config::MulticastDiscoveryConfig;
13use crate::interfaces::InterfaceTracker;
14use crate::protocol::DiscoveryMessage;
15use crate::socket::MultiInterfaceSocket;
16
17pub mod config;
18pub mod protocol;
19pub mod socket;
20pub mod interfaces;
21
/// Announcement bookkeeping kept separately for every tracked network interface.
///
/// All timestamps start out as `None` and the flag as `false` (derived `Default`).
#[derive(Default)]
pub struct PerInterfaceState {
    /// When `poll` last sent an `Announce` on this interface.
    last_announce_tm: Option<Instant>,
    /// When the most recent `ExtendAnnouncements` request was received on this interface.
    extended_announcements_request_tm: Option<Instant>,
    /// When `poll` last sent our own `ExtendAnnouncements` request on this interface.
    extend_request_send_tm: Option<Instant>,
    /// Latched by `got_extend_announce_req`; `poll` clears it (and logs the transition)
    /// once the request is no longer in effect.
    extended_announce_enabled: bool,
}
30
31impl PerInterfaceState {
32    pub fn should_announce(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
33        self.last_announce_tm.is_none_or(|tm| now - tm > cfg.announce_interval)
34    }
35    pub fn should_send_extend_request(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
36        self.extend_request_send_tm.is_none_or(|tm| now - tm > cfg.extend_request_interval)
37    }
38    pub fn should_extended_announce(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
39        self.extended_announcements_enabled(now, cfg) && self.should_announce(now, cfg)
40    }
41
42    pub fn extended_announcements_enabled(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
43        self.extended_announcements_request_tm.is_some_and(|tm| now - tm < cfg.extended_announcement_effect_dur)
44    }
45    pub fn got_extend_announce_req(&mut self, now: Instant) {
46        self.extended_announcements_request_tm = Some(now);
47        self.extended_announce_enabled = true;
48    }
49}
50
/// Payload that travels inside an announcement.
///
/// Implementors define how the payload is turned into bytes and how it is recovered from them.
pub trait AdvertisementData: Clone + Sized {
    /// Serialize the payload into an owned byte buffer.
    fn encode_to_bytes(&self) -> Vec<u8>;

    /// Try to rebuild a payload from `bytes`; `None` signals that the bytes are not a valid payload.
    fn try_decode(bytes: &[u8]) -> Option<Self>;
}
55#[cfg(not(feature="bincode"))]
56mod adv_data_impls {
57    use super::AdvertisementData;
58    impl AdvertisementData for () {
59        fn encode_to_bytes(&self) -> Vec<u8> {
60            Vec::new()
61        }
62        fn try_decode(bytes: &[u8]) -> Option<Self> {
63            if bytes.is_empty() {
64                Some(())
65            }
66            else {
67                None
68            }
69        }
70
71    }
72    impl AdvertisementData for Vec<u8> {
73        fn encode_to_bytes(&self) -> Vec<u8> {
74            self.clone()
75        }
76        fn try_decode(bytes: &[u8]) -> Option<Self> {
77            Some(bytes.to_vec())
78        }
79    }
80}
81
82#[cfg(feature="bincode")]
83use bincode::{Decode, Encode};
/// With the `bincode` feature on, every `Encode + Decode<()> + Clone` type is a payload,
/// serialized with bincode's standard configuration.
#[cfg(feature = "bincode")]
impl<T> AdvertisementData for T
where
    T: Encode + Decode<()> + Clone,
{
    /// Encode with bincode; panics (via `unwrap`) if encoding fails.
    fn encode_to_bytes(&self) -> Vec<u8> {
        let config = bincode::config::standard();
        bincode::encode_to_vec(self, config).unwrap()
    }

    /// Decode with bincode; any decode error becomes `None`. The consumed length is discarded.
    fn try_decode(bytes: &[u8]) -> Option<Self> {
        let config = bincode::config::standard();
        let (value, _consumed) = bincode::decode_from_slice(bytes, config).ok()?;
        Some(value)
    }
}
96
97pub struct MulticastDiscoverySocket<D: AdvertisementData> {
98    socket: MultiInterfaceSocket,
99    cfg: MulticastDiscoveryConfig,
100    discover_id: u32,
101    running_port: MulticastRunningPort,
102    interface_tracker: InterfaceTracker<PerInterfaceState>,
103
104    announce_enabled: bool,
105    discover_replies: bool,
106    /// Announce payload: service port. If not set, announcements are disabled
107    service_port_and_adv_data: Option<(u16, D)>,
108}
109
/// Outcome of the port selection performed when the socket is created.
#[derive(Clone, Copy, Debug)]
pub enum MulticastRunningPort {
    /// The first port that was tried could be bound; carries that port.
    Primary(u16),
    /// An earlier port failed and a later configured (non-zero) port was bound; carries that port.
    Backup(u16),
    /// Earlier ports failed and the socket was bound by requesting port `0`;
    /// the actual port number is not recorded.
    Other,
}
116
117impl Deref for MulticastRunningPort {
118    type Target = u16;
119    fn deref(&self) -> &Self::Target {
120        match self {
121            MulticastRunningPort::Primary(p) => p,
122            MulticastRunningPort::Backup(p) => p,
123            MulticastRunningPort::Other => &0,
124        }
125    }
126}
127
128
/// Event handed to the callback of `MulticastDiscoverySocket::poll` for every received `Announce`.
pub enum PollResult<'a, D> {
    /// A peer announced itself.
    DiscoveredClient {
        /// Sender IP combined with the service port carried in the announcement.
        addr: SocketAddrV4,
        /// Identifier carried in the announcement.
        discover_id: u32,
        /// Advertisement data carried in the announcement; borrowed for the duration of the callback.
        adv_data: &'a D,
    },
    /// A peer announced that it is going away.
    DisconnectedClient {
        /// Sender IP combined with the service port carried in the announcement.
        addr: SocketAddrV4,
        /// Identifier carried in the announcement.
        discover_id: u32,
    },
}
140
141impl<D: AdvertisementData> MulticastDiscoverySocket<D> {
142    /// Create new socket for multicast discovery. Announcements are disabled (running without service)
143    /// Enable feature `bincode` for passing `Encode` + `Decode` types as adv_data
144    pub fn new_discover_only(cfg: &MulticastDiscoveryConfig) -> io::Result<Self> {
145        Self::new(cfg, None)
146    }
147    /// Create new socket for multicast discovery. Announcements are enabled depending on config.
148    /// Enable feature `bincode` for passing `Encode` + `Decode` types as adv_data
149    pub fn new_with_service(cfg: &MulticastDiscoveryConfig, service_port: u16, initial_adv_data: D) -> io::Result<Self> {
150        Self::new(cfg, Some((service_port, initial_adv_data)))
151    }
152    fn new(cfg: &MulticastDiscoveryConfig, service_port_and_adv_data: Option<(u16, D)>) -> io::Result<Self> {
153        let central_discovery_enabled = cfg.central_discovery_addr.is_some();
154        let mut is_primary = true;
155        
156        let mut interface_tracker = InterfaceTracker::new();
157        // Try primary and backup ports
158        let main_port = cfg.iter_ports().next().unwrap();
159        for port in cfg.iter_ports().chain(once(0)) {
160            match MultiInterfaceSocket::bind_port(port) {
161                Ok(socket) => {
162                    // Join multicast group on all interfaces
163                    for (interface, _) in interface_tracker.iter_mut() {
164                        if let IpAddr::V4(ip) = interface.ip() {
165                            if let Err(e) = socket.join_multicast_group(cfg.multicast_group_ip, ip) {
166                                warn!("Failed to join multicast group on interface {}: {}", interface.ip(), e);
167                            }
168                            else {
169                                info!("Joined multicast group on interface {}", interface.ip());
170                            }
171                        }
172                    }
173                   
174                    // Set non-blocking
175                    socket.set_nonblocking(true)?;
176                    
177                    let running_port = if is_primary {
178                        debug!("Using primary multicast port {} for discovery", port);
179                        MulticastRunningPort::Primary(port)
180                    }
181                    else if port == 0 {
182                        let failed_ports = cfg.iter_ports().filter(|p| *p != 0);
183                        warn!("Unable to start on the main or backup ports ({:?})!", &failed_ports.collect::<Vec<_>>());
184                        if !central_discovery_enabled {
185                            warn!("You may face issues with discovering");
186                        }
187                        else {
188                            warn!("You will be able to discover clients only when your network is online!");
189                        }
190                        MulticastRunningPort::Other
191                    }
192                    else {
193                        warn!("Using backup multicast port {} for discovery (unable to start on main port {})", port, main_port);
194                        MulticastRunningPort::Backup(port)
195                    };
196                    return Ok(Self {
197                        socket,
198                        interface_tracker,
199                        cfg: cfg.clone(),
200                        discover_id: rand::random_range(0..u32::MAX),
201                        running_port,
202
203                        announce_enabled: cfg.enable_announce,
204                        discover_replies: cfg.enable_announce,
205
206                        service_port_and_adv_data,
207                    })
208                }
209                Err(e) if e.kind() == io::ErrorKind::AddrInUse => {
210                    is_primary = false;
211                    continue
212                },
213                Err(e) => {
214                    is_primary = false;
215                    warn!("Failed to bind socket to port {}: {}", port, e);
216                    continue;
217                }
218            }
219        }
220    
221        error!("Failed to bind multicast discovery socket to any port!");
222        Err(io::Error::new(io::ErrorKind::AddrInUse, "Failed to bind socket to any port"))
223    }
224    
225    pub fn discover_id(&self) -> u32 {
226        self.discover_id
227    }
228    pub fn running_port(&self) -> MulticastRunningPort {
229        self.running_port
230    }
231
232    /// Setting this to `false` will disable periodic background announcements during `poll()`
233    /// Announcements are performed by periodic sending message `Announce` (and `ExtendAnnounce`)
234    pub fn set_announce_en(&mut self, en: bool) {
235        self.announce_enabled = en;
236    }
237
238    /// Setting this to `false` will disable automatic replies to `Discovery` messages
239    pub fn set_discover_replies_en(&mut self, enable: bool) {
240        self.discover_replies = enable;
241    }
242
243    /// Guaranteed to return Some(&mut D) if was created with `Self::new`
244    pub fn adv_data(&mut self) -> Option<&mut D> {
245        self.service_port_and_adv_data.as_mut().map(|s| &mut s.1)
246    }
247
248    /// Manually discover all clients on main or backup ports (using `Discovery` message).
249    /// Results can be collected by running `poll`.
250    pub fn discover(&mut self) {
251        info!("Multicast discovery: running manual discovery...");
252        let msg = DiscoveryMessage::Discovery::<D>.gen_message();
253        for (interface, _) in self.interface_tracker.iter_mut() {
254            let Some(index) = interface.index else {
255                continue;
256            };
257            if interface.ip().is_loopback() {
258                continue;
259            }
260
261            for port in self.cfg.iter_ports() {
262                if let Err(e) = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), index, interface.addr.ip()) {
263                    warn!("Failed to send discovery message on interface [{}] - {}: {}", interface.ip(), interface.name, e);
264                } else {
265                    debug!("Sent discovery message to port {} on interface [{}] - {}", port, interface.ip(), interface.name);
266                }
267            }
268        }
269    }
270
271    /// Run `poll` periodically to handle internal discovery mechanisms:
272    /// - `Announce` messages periodic sending
273    /// - `ExtendAnnounce` messages periodic sending (if running on backup port)
274    /// - handling incoming messages (and returning discovery results via `discover_msg` callback)
275    ///
276    /// It is recommended to call this function in a loop with ~100ms sleep
277    pub fn poll(&mut self, mut discover_msg: impl FnMut(PollResult<D>)) {
278        // 0. poll interface updates
279        self.interface_tracker.poll_updates(|new_ip| {
280            if let Err(e) = self.socket.join_multicast_group(self.cfg.multicast_group_ip, new_ip) {
281                warn!("Failed to join multicast group on interface {}: {}", new_ip, e);
282            }
283            else {
284                info!("Joined multicast group on interface {}!", new_ip);
285            }
286        });
287    
288        let mut interface_cnt = 0;
289        if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_mut() {
290            if self.announce_enabled {
291                for (interface, state) in self.interface_tracker.iter_mut() {
292                    let Some(interface_index) = interface.index else {
293                        continue;
294                    };
295                    // Skip for now
296                    if interface.ip().is_loopback() {
297                        continue;
298                    }
299
300                    // 1. Handle announcements
301                    let now = Instant::now();
302                    let should_announce = state.should_announce(now, &self.cfg);
303                    let should_extended_announce = state.should_extended_announce(now, &self.cfg);
304                    let should_send_extend_request = state.should_send_extend_request(now, &self.cfg);
305                    if should_announce {
306                        state.last_announce_tm = Some(now);
307
308                        let msg = DiscoveryMessage::Announce {
309                            service_port: *service_port,
310                            discover_id: self.discover_id,
311                            disconnected: false,
312                            adv_data: Cow::Borrowed(&*adv_data)
313                        }.gen_message();
314                        if should_extended_announce {
315                            for port in self.cfg.iter_ports() {
316                                let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), interface_index, interface.addr.ip());
317                                handle_err(res, "send extended announce", interface);
318                            }
319                        }
320                        else {
321                            let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, self.cfg.multicast_port), interface_index, interface.addr.ip());
322                            handle_err(res, "send normal announce", interface);
323                        }
324                        if state.extended_announce_enabled && !state.extended_announcements_enabled(now, &self.cfg) {
325                            state.extended_announce_enabled = false;
326                            info!("No longer sending extended announce on interface [{}] - {}", interface.ip(), interface.name);
327                        }
328                    }
329                    // 2. Sending extend requests
330                    if matches!(self.running_port, MulticastRunningPort::Backup(_)) && should_send_extend_request {
331                        state.extend_request_send_tm = Some(now);
332                        let msg = DiscoveryMessage::ExtendAnnouncements::<D>.gen_message();
333                        for port in self.cfg.iter_ports() {
334                            let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), interface_index, interface.addr.ip());
335                            handle_err(res, "send extended announce request", interface);
336                        }
337                    }
338                    interface_cnt += 1;
339                }
340
341                if interface_cnt == 0 {
342                    warn!("No available interface found!");
343                    thread::sleep(Duration::from_millis(500));
344                }
345            }
346        }
347
348        // 3. Handle incoming packets
349        let mut buf = [0u8;256];
350        while let Ok((data, addr, index)) = self.socket.recv_from_iface(&mut buf) {
351
352            // Shut up messages from ourselves on all interfaces
353            if self.interface_tracker.iter_mut().any(|(i, _)| i.ip() == IpAddr::V4(*addr.ip())) && addr.port() == *self.running_port {
354                continue;
355            }
356
357            match DiscoveryMessage::<D>::try_parse(data) {
358                Some(DiscoveryMessage::Discovery) => {
359                    if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_mut() {
360                        if self.discover_replies {
361                            let announce = DiscoveryMessage::Announce {
362                                disconnected: false,
363                                discover_id: self.discover_id,
364                                service_port: *service_port,
365                                adv_data: Cow::Borrowed(&*adv_data)
366                            }.gen_message();
367                            let source_addr = self.interface_tracker.iter_mapping().find(|(i, _)| *i == index);
368                            if let Some((_, a)) = source_addr {
369                                if let Err(e) = self.socket.send_to_iface(&announce, addr, index, a.into()) {
370                                    warn!("Failed to answer to discovery packet: {:?}", e);
371                                }
372                            }
373                            else {
374                                warn!("Failed to answer discovery packet: interface address not found for index!");
375                            }
376                        }
377                    }
378                }
379                Some(DiscoveryMessage::Announce { service_port, discover_id, disconnected, adv_data}) => {
380                    if disconnected {
381                        discover_msg(PollResult::DisconnectedClient {
382                            addr: SocketAddrV4::new(
383                                *addr.ip(),
384                                service_port,
385                            ),
386                            discover_id
387                        })
388                    }
389                    else {
390                        discover_msg(PollResult::DiscoveredClient {
391                            addr: SocketAddrV4::new(
392                                *addr.ip(),
393                                service_port,
394                            ),
395                            discover_id,
396                            adv_data: adv_data.as_ref(),
397                        })
398                    }
399                }
400                Some(DiscoveryMessage::ExtendAnnouncements) => {
401                    for (interface, state) in self.interface_tracker.iter_mut() {
402                        if interface.index == Some(index) {
403                            let now = Instant::now();
404                            if !state.extended_announcements_enabled(now, &self.cfg) {
405                                info!("Enabling extended announcements on interface [{}] - {}", interface.ip(), interface.name);
406                            }
407                            state.got_extend_announce_req(now);
408                        }
409                    }
410                }
411                None => {
412                    warn!("Received unknown message from {}: {:?}", addr, data);
413                }
414            }
415        }
416    }
417}
418impl<D: AdvertisementData> Drop for MulticastDiscoverySocket<D> {
419    fn drop(&mut self) {
420        // Announce disconnection
421        if !self.announce_enabled {
422            return;
423        }
424        if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_ref() {
425            for (interface, _) in self.interface_tracker.iter_mut() {
426                let Some(index) = interface.index else {
427                    continue;
428                };
429                // Skip for now
430                if interface.ip().is_loopback() {
431                    continue;
432                }
433
434                let msg = DiscoveryMessage::Announce {
435                    discover_id: self.discover_id,
436                    service_port: *service_port,
437                    disconnected: true,
438                    adv_data: Cow::Borrowed(adv_data)
439                }.gen_message();
440                for port in self.cfg.iter_ports() {
441                    let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port),index, interface.addr.ip());
442                    handle_err(res, "announce disconnected message", interface);
443                }
444            }
445        }
446    }
447}
448
449fn handle_err(result: io::Result<usize>, msg: &'static str, interface: &Interface) {
450    if let Err(e) = result {
451        warn!("Failed to {} on interface [{:?}] - {}: {}", msg, interface.ip(), interface.name, e);
452    }
453}