multicast_discovery_socket/
lib.rs

1#![allow(clippy::int_plus_one)]
2#![allow(clippy::new_without_default)]
3
4use std::{io, thread};
5use std::borrow::Cow;
6use std::iter::once;
7use std::net::{IpAddr, SocketAddrV4};
8use std::ops::Deref;
9use std::time::{Duration, Instant};
10use log::{debug, error, info, warn};
11use crate::config::MulticastDiscoveryConfig;
12use crate::protocol::DiscoveryMessage;
13use crate::socket::MultiInterfaceSocket;
14
15pub mod config;
16pub mod protocol;
17pub mod socket;
18#[cfg(depend_if_addrs)]
19pub mod interfaces;
20#[cfg(not(depend_if_addrs))]
21#[path="interfaces_fallback.rs"]
22pub mod interfaces;
23
24pub use interfaces::Interface;
25use crate::interfaces::InterfaceTracker;
26
27
/// Announcement bookkeeping kept for a single network interface.
///
/// Used as the per-interface payload of `InterfaceTracker`. With `Default`, every
/// timestamp starts as `None` ("never happened") and the flag starts as `false`.
#[derive(Default)]
pub struct PerInterfaceState {
    /// When `poll` last sent an `Announce` on this interface.
    last_announce_tm: Option<Instant>,
    /// When an `ExtendAnnouncements` request was last received on this interface.
    extended_announcements_request_tm: Option<Instant>,
    /// When `poll` last sent our own `ExtendAnnouncements` request from this interface.
    extend_request_send_tm: Option<Instant>,

    /// Latched to `true` when an extend request arrives; `poll` clears it (and logs once)
    /// after the request has expired.
    extended_announce_enabled: bool,
}
36
37impl PerInterfaceState {
38    pub fn should_announce(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
39        self.last_announce_tm.is_none_or(|tm| now - tm > cfg.announce_interval)
40    }
41    pub fn should_send_extend_request(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
42        self.extend_request_send_tm.is_none_or(|tm| now - tm > cfg.extend_request_interval)
43    }
44    pub fn should_extended_announce(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
45        self.extended_announcements_enabled(now, cfg) && self.should_announce(now, cfg)
46    }
47
48    pub fn extended_announcements_enabled(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
49        self.extended_announcements_request_tm.is_some_and(|tm| now - tm < cfg.extended_announcement_effect_dur)
50    }
51    pub fn got_extend_announce_req(&mut self, now: Instant) {
52        self.extended_announcements_request_tm = Some(now);
53        self.extended_announce_enabled = true;
54    }
55}
56
/// Application-defined payload attached to `Announce` messages.
///
/// Implementors must be able to turn themselves into bytes and back again.
pub trait AdvertisementData: Sized + Clone {
    /// Serializes the value into a freshly allocated byte vector.
    fn encode_to_bytes(&self) -> Vec<u8>;
    /// Attempts to rebuild a value from `bytes`; returns `None` if the bytes cannot be decoded.
    fn try_decode(bytes: &[u8]) -> Option<Self>;
}
61#[cfg(not(feature="bincode"))]
62mod adv_data_impls {
63    use super::AdvertisementData;
64    impl AdvertisementData for () {
65        fn encode_to_bytes(&self) -> Vec<u8> {
66            Vec::new()
67        }
68        fn try_decode(bytes: &[u8]) -> Option<Self> {
69            if bytes.is_empty() {
70                Some(())
71            }
72            else {
73                None
74            }
75        }
76
77    }
78    impl AdvertisementData for Vec<u8> {
79        fn encode_to_bytes(&self) -> Vec<u8> {
80            self.clone()
81        }
82        fn try_decode(bytes: &[u8]) -> Option<Self> {
83            Some(bytes.to_vec())
84        }
85    }
86}
87
88#[cfg(feature="bincode")]
89use bincode::{Decode, Encode};
#[cfg(feature="bincode")]
impl<T> AdvertisementData for T
where
    T: Encode + Decode<()> + Clone,
{
    /// Encodes `self` with bincode's standard configuration.
    ///
    /// # Panics
    /// Panics if bincode reports an encoding error.
    fn encode_to_bytes(&self) -> Vec<u8> {
        let wire_cfg = bincode::config::standard();
        bincode::encode_to_vec(self, wire_cfg).unwrap()
    }

    /// Decodes a value with bincode's standard configuration; any decode error yields `None`.
    /// The number of consumed bytes is ignored.
    fn try_decode(bytes: &[u8]) -> Option<Self> {
        let wire_cfg = bincode::config::standard();
        match bincode::decode_from_slice(bytes, wire_cfg) {
            Ok((value, _consumed)) => Some(value),
            Err(_) => None,
        }
    }
}
102
/// Multicast discovery endpoint: announces an optional local service and reports
/// peers seen on the network through `poll`.
pub struct MulticastDiscoverySocket<D: AdvertisementData> {
    /// Underlying socket used to send and receive on specific interfaces.
    socket: MultiInterfaceSocket,
    /// Private copy of the configuration passed at construction.
    cfg: MulticastDiscoveryConfig,
    /// Identifier chosen at construction and sent in every `Announce`.
    discover_id: u32,
    /// Which port (primary / backup / other) the socket ended up bound to.
    running_port: MulticastRunningPort,
    /// Known interfaces together with their per-interface announcement state.
    interface_tracker: InterfaceTracker<PerInterfaceState>,

    /// Whether `poll` sends periodic announcements (and whether `Drop` sends a goodbye).
    announce_enabled: bool,
    /// Whether incoming `Discovery` messages are answered with an `Announce`.
    discover_replies: bool,
    /// Announce payload: service port and advertisement data. If not set, announcements are disabled
    service_port_and_adv_data: Option<(u16, D)>,
}
115
/// Describes which UDP port the discovery socket was bound to.
///
/// Derives `PartialEq`/`Eq` so callers can compare the value returned by
/// `MulticastDiscoverySocket::running_port` directly.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MulticastRunningPort {
    /// Bound to the first port yielded by the configuration.
    Primary(u16),
    /// Bound to a later configured port after earlier candidates could not be bound.
    Backup(u16),
    /// Bound through a port-0 request after every configured port failed; the actual
    /// port number is not recorded.
    Other,
}
122
123impl Deref for MulticastRunningPort {
124    type Target = u16;
125    fn deref(&self) -> &Self::Target {
126        match self {
127            MulticastRunningPort::Primary(p) => p,
128            MulticastRunningPort::Backup(p) => p,
129            MulticastRunningPort::Other => &0,
130        }
131    }
132}
133
134
/// Event passed to the `poll` callback for every `Announce` message received from a peer.
pub enum PollResult<'a, D> {
    /// The peer announced itself (the message's `disconnected` flag was `false`).
    DiscoveredClient {
        /// Sender's IPv4 address combined with the service port carried in the message.
        addr: SocketAddrV4,
        /// Discover ID carried in the message.
        discover_id: u32,
        /// Advertisement data decoded from the message, borrowed for the callback.
        adv_data: &'a D,
    },
    /// The peer announced that it is going away (the message's `disconnected` flag was `true`).
    DisconnectedClient {
        /// Sender's IPv4 address combined with the service port carried in the message.
        addr: SocketAddrV4,
        /// Discover ID carried in the message.
        discover_id: u32
    }
}
146
147impl<D: AdvertisementData> MulticastDiscoverySocket<D> {
    /// Create new socket for multicast discovery without a local service.
    /// No service port / advertisement data is stored, so `poll` never sends announcements.
    /// Enable feature `bincode` for passing `Encode` + `Decode` types as adv_data
    ///
    /// # Errors
    /// Returns the error produced by the internal constructor (binding or socket setup failed).
    pub fn new_discover_only(cfg: &MulticastDiscoveryConfig) -> io::Result<Self> {
        Self::new(cfg, None)
    }
    /// Create new socket for multicast discovery that advertises a local service.
    /// Announcements are enabled depending on config (`cfg.enable_announce`).
    /// Enable feature `bincode` for passing `Encode` + `Decode` types as adv_data
    ///
    /// * `service_port` - port number sent in `Announce` messages
    /// * `initial_adv_data` - advertisement data sent in `Announce` messages; can be changed later via `adv_data()`
    ///
    /// # Errors
    /// Returns the error produced by the internal constructor (binding or socket setup failed).
    pub fn new_with_service(cfg: &MulticastDiscoveryConfig, service_port: u16, initial_adv_data: D) -> io::Result<Self> {
        Self::new(cfg, Some((service_port, initial_adv_data)))
    }
158    fn new(cfg: &MulticastDiscoveryConfig, service_port_and_adv_data: Option<(u16, D)>) -> io::Result<Self> {
159        let central_discovery_enabled = cfg.central_discovery_addr.is_some();
160        let mut is_primary = true;
161        
162        let mut interface_tracker = InterfaceTracker::new();
163        // Try primary and backup ports
164        let main_port = cfg.iter_ports().next().unwrap();
165        for port in cfg.iter_ports().chain(once(0)) {
166            match MultiInterfaceSocket::bind_port(port) {
167                Ok(socket) => {
168                    // Join multicast group on all interfaces
169                    for (interface, _) in interface_tracker.iter_mut() {
170                        if let IpAddr::V4(ip) = interface.ip() {
171                            if let Err(e) = socket.join_multicast_group(cfg.multicast_group_ip, ip) {
172                                warn!("Failed to join multicast group on interface {}: {:?}", interface.ip(), e);
173                            }
174                            else {
175                                info!("Joined multicast group on interface {}", interface.ip());
176                            }
177                        }
178                    }
179                   
180                    // Set non-blocking
181                    socket.set_nonblocking(true)?;
182                    
183                    let running_port = if is_primary {
184                        debug!("Using primary multicast port {port} for discovery");
185                        MulticastRunningPort::Primary(port)
186                    }
187                    else if port == 0 {
188                        let failed_ports = cfg.iter_ports().filter(|p| *p != 0);
189                        warn!("Unable to start on the main or backup ports ({:?})!", &failed_ports.collect::<Vec<_>>());
190                        if !central_discovery_enabled {
191                            warn!("You may face issues with discovering");
192                        }
193                        else {
194                            warn!("You will be able to discover clients only when your network is online!");
195                        }
196                        MulticastRunningPort::Other
197                    }
198                    else {
199                        warn!("Using backup multicast port {port} for discovery (unable to start on main port {main_port})");
200                        MulticastRunningPort::Backup(port)
201                    };
202                    #[cfg(feature="rand")]
203                    let discover_id = rand::random_range(1..u32::MAX);
204                    #[cfg(not(feature="rand"))]
205                    let discover_id = std::time::UNIX_EPOCH.elapsed().unwrap_or_default().as_nanos() as u32;
206
207                    return Ok(Self {
208                        socket,
209                        interface_tracker,
210                        cfg: cfg.clone(),
211                        discover_id,
212                        running_port,
213
214                        announce_enabled: cfg.enable_announce,
215                        discover_replies: cfg.discover_replies,
216
217                        service_port_and_adv_data,
218                    })
219                }
220                Err(e) if e.kind() == io::ErrorKind::AddrInUse => {
221                    is_primary = false;
222                    continue
223                },
224                Err(e) => {
225                    is_primary = false;
226                    warn!("Failed to bind socket to port {port}: {e:?}");
227                    continue;
228                }
229            }
230        }
231    
232        error!("Failed to bind multicast discovery socket to any port!");
233        Err(io::Error::new(io::ErrorKind::AddrInUse, "Failed to bind socket to any port"))
234    }
235
    /// Returns discover ID. It is chosen once at construction (randomly with feature `rand`,
    /// otherwise derived from the system clock) and can be used to detect client application restart.
    pub fn discover_id(&self) -> u32 {
        self.discover_id
    }
    /// Returns which port (primary, backup or other) the socket was bound to at construction.
    pub fn running_port(&self) -> MulticastRunningPort {
        self.running_port
    }
243
    /// Setting this to `false` will disable periodic background announcements during `poll()`
    /// Announcements are performed by periodic sending message `Announce` (and `ExtendAnnounce`)
    ///
    /// While disabled, dropping the socket also sends no "disconnected" announcement.
    pub fn set_announce_en(&mut self, en: bool) {
        self.announce_enabled = en;
    }
249
    /// Setting this to `false` will disable automatic replies to `Discovery` messages
    /// (the replies are `Announce` messages sent from `poll()` back to the requester).
    pub fn set_discover_replies_en(&mut self, enable: bool) {
        self.discover_replies = enable;
    }
254
255    /// Guaranteed to return Some(&mut D) if was created with `Self::new`
256    pub fn adv_data(&mut self) -> Option<&mut D> {
257        self.service_port_and_adv_data.as_mut().map(|s| &mut s.1)
258    }
259
260    /// Manually discover all clients on main or backup ports (using `Discovery` message).
261    /// Results can be collected by running `poll`.
262    pub fn discover(&mut self) {
263        info!("Multicast discovery: running manual discovery...");
264        let msg = DiscoveryMessage::Discovery::<D>.gen_message();
265        for (interface, _) in self.interface_tracker.iter_mut() {
266            let Some(index) = interface.index else {
267                continue;
268            };
269            if interface.ip().is_loopback() {
270                continue;
271            }
272
273            for port in self.cfg.iter_ports() {
274                if let Err(e) = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), index, interface.ip()) {
275                    warn!("Failed to send discovery message on interface [{}] - {}: {:?}", interface.ip(), interface.name, e);
276                } else {
277                    debug!("Sent discovery message to port {} on interface [{}] - {}", port, interface.ip(), interface.name);
278                }
279            }
280        }
281    }
282
    /// Run `poll` periodically to handle internal discovery mechanisms:
    /// - `Announce` messages periodic sending
    /// - `ExtendAnnounce` messages periodic sending (if running on backup port)
    /// - handling incoming messages (and returning discovery results via `discover_msg` callback)
    ///
    /// `discover_msg` is invoked once per received `Announce` message, with either
    /// `PollResult::DiscoveredClient` or `PollResult::DisconnectedClient`.
    ///
    /// It is recommended to call this function in a loop with ~100ms sleep
    ///
    /// Note: when announcing is active and no usable interface is found, this call
    /// blocks the current thread for 500 ms before processing incoming packets.
    pub fn poll(&mut self, mut discover_msg: impl FnMut(PollResult<D>)) {
        // 0. poll interface updates
        // Each address reported by the tracker is joined to the multicast group; failures are only logged
        self.interface_tracker.poll_updates(|new_ip| {
            if let Err(e) = self.socket.join_multicast_group(self.cfg.multicast_group_ip, new_ip) {
                warn!("Failed to join multicast group on interface {new_ip}: {e:?}");
            }
            else {
                info!("Joined multicast group on interface {new_ip}!");
            }
        });

        // Counts interfaces that passed the index/loopback filters below
        let mut interface_cnt = 0;
        // Everything in this block requires a configured service AND enabled announcements
        if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_mut() {
            if self.announce_enabled {
                for (interface, state) in self.interface_tracker.iter_mut() {
                    let Some(interface_index) = interface.index else {
                        continue;
                    };
                    // Skip for now
                    if interface.ip().is_loopback() {
                        continue;
                    }

                    // 1. Handle announcements
                    // All three decisions are taken against the same `now`, before any state is updated
                    let now = Instant::now();
                    let should_announce = state.should_announce(now, &self.cfg);
                    let should_extended_announce = state.should_extended_announce(now, &self.cfg);
                    let should_send_extend_request = state.should_send_extend_request(now, &self.cfg);
                    if should_announce {
                        state.last_announce_tm = Some(now);

                        let msg = DiscoveryMessage::Announce {
                            service_port: *service_port,
                            discover_id: self.discover_id,
                            disconnected: false,
                            adv_data: Cow::Borrowed(&*adv_data)
                        }.gen_message();
                        if should_extended_announce {
                            // Extended mode: announce on every configured port
                            for port in self.cfg.iter_ports() {
                                let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), interface_index, interface.ip());
                                handle_err(res, "send extended announce", interface);
                            }
                        }
                        else {
                            // Normal mode: announce on `cfg.multicast_port` only
                            let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, self.cfg.multicast_port), interface_index, interface.ip());
                            handle_err(res, "send normal announce", interface);
                        }
                        // Clear the latched flag once the extend request has expired, so the transition is logged once
                        if state.extended_announce_enabled && !state.extended_announcements_enabled(now, &self.cfg) {
                            state.extended_announce_enabled = false;
                            info!("No longer sending extended announce on interface [{}] - {}", interface.ip(), interface.name);
                        }
                    }
                    // 2. Sending extend requests
                    // NOTE(review): this sits inside the service/announce guards above, so a discover-only
                    // socket (or one with announcements disabled) running on a backup port never sends
                    // extend requests — confirm this is intended
                    if matches!(self.running_port, MulticastRunningPort::Backup(_)) && should_send_extend_request {
                        state.extend_request_send_tm = Some(now);
                        let msg = DiscoveryMessage::ExtendAnnouncements::<D>.gen_message();
                        for port in self.cfg.iter_ports() {
                            let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), interface_index, interface.ip());
                            handle_err(res, "send extended announce request", interface);
                        }
                    }
                    interface_cnt += 1;
                }

                if interface_cnt == 0 {
                    warn!("No available interface found!");
                    // Blocks the calling thread; incoming packets are handled only after the sleep
                    thread::sleep(Duration::from_millis(500));
                }
            }
        }

        // 3. Handle incoming packets
        // NOTE(review): the receive buffer holds 256 bytes; how `recv_from_iface` treats longer
        // datagrams (e.g. large advertisement data) is not visible here — verify
        let mut buf = [0u8;256];
        // Drain until the first `Err`; presumably that is `WouldBlock` on the non-blocking socket,
        // but any other receive error ends the loop silently as well
        while let Ok((data, addr, index)) = self.socket.recv_from_iface(&mut buf) {

            // Shut up messages from ourselves on all interfaces
            // NOTE(review): for `MulticastRunningPort::Other` the deref yields 0, so this only
            // matches a source port of 0 — verify own packets are filtered in that mode
            if self.interface_tracker.iter_mut().any(|(i, _)| i.ip() == IpAddr::V4(*addr.ip())) && addr.port() == *self.running_port {
                continue;
            }

            match DiscoveryMessage::<D>::try_parse(data) {
                Some(DiscoveryMessage::Discovery) => {
                    // Answer with an `Announce` sent directly to the requester, if we have a service and replies are enabled
                    if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_mut() {
                        if self.discover_replies {
                            let announce = DiscoveryMessage::Announce {
                                disconnected: false,
                                discover_id: self.discover_id,
                                service_port: *service_port,
                                adv_data: Cow::Borrowed(&*adv_data)
                            }.gen_message();
                            // Look up the local address for the interface index the packet arrived on
                            let source_addr = self.interface_tracker.iter_mapping().find(|(i, _)| *i == index);
                            if let Some((_, a)) = source_addr {
                                if let Err(e) = self.socket.send_to_iface(&announce, addr, index, a.into()) {
                                    warn!("Failed to answer to discovery packet: {e:?}");
                                }
                            }
                            else {
                                warn!("Failed to answer discovery packet: interface address not found for index!");
                            }
                        }
                    }
                }
                Some(DiscoveryMessage::Announce { service_port, discover_id, disconnected, adv_data}) => {
                    // The reported address is the sender's IP plus the service port from the message,
                    // not the UDP source port
                    if disconnected {
                        discover_msg(PollResult::DisconnectedClient {
                            addr: SocketAddrV4::new(
                                *addr.ip(),
                                service_port,
                            ),
                            discover_id
                        })
                    }
                    else {
                        discover_msg(PollResult::DiscoveredClient {
                            addr: SocketAddrV4::new(
                                *addr.ip(),
                                service_port,
                            ),
                            discover_id,
                            adv_data: adv_data.as_ref(),
                        })
                    }
                }
                Some(DiscoveryMessage::ExtendAnnouncements) => {
                    // (Re)start the extended-announcement window on the interface the request arrived on
                    for (interface, state) in self.interface_tracker.iter_mut() {
                        if interface.index == Some(index) {
                            let now = Instant::now();
                            // Log only on the disabled -> enabled transition
                            if !state.extended_announcements_enabled(now, &self.cfg) {
                                info!("Enabling extended announcements on interface [{}] - {}", interface.ip(), interface.name);
                            }
                            state.got_extend_announce_req(now);
                        }
                    }
                }
                None => {
                    warn!("Received unknown message from {addr}: {data:?}");
                }
            }
        }
    }
429}
430impl<D: AdvertisementData> Drop for MulticastDiscoverySocket<D> {
431    fn drop(&mut self) {
432        // Announce disconnection
433        if !self.announce_enabled {
434            return;
435        }
436        if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_ref() {
437            for (interface, _) in self.interface_tracker.iter_mut() {
438                let Some(index) = interface.index else {
439                    continue;
440                };
441                // Skip for now
442                if interface.ip().is_loopback() {
443                    continue;
444                }
445
446                let msg = DiscoveryMessage::Announce {
447                    discover_id: self.discover_id,
448                    service_port: *service_port,
449                    disconnected: true,
450                    adv_data: Cow::Borrowed(adv_data)
451                }.gen_message();
452                for port in self.cfg.iter_ports() {
453                    let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port),index, interface.ip());
454                    handle_err(res, "announce disconnected message", interface);
455                }
456            }
457        }
458    }
459}
460
461fn handle_err(result: io::Result<usize>, msg: &'static str, interface: &Interface) {
462    if let Err(e) = result {
463        warn!("Failed to {} on interface [{:?}] - {}: {:?}", msg, interface.ip(), interface.name, e);
464    }
465}