multicast_discovery_socket/
lib.rs

1#![allow(clippy::int_plus_one)]
2#![allow(clippy::new_without_default)]
3
4use std::{io, thread};
5use std::borrow::Cow;
6use std::iter::once;
7use std::net::{IpAddr, SocketAddrV4};
8use std::ops::Deref;
9use std::time::{Duration, Instant};
10use if_addrs::Interface;
11use log::{debug, error, info, warn};
12use crate::config::MulticastDiscoveryConfig;
13use crate::interfaces::InterfaceTracker;
14use crate::protocol::DiscoveryMessage;
15use crate::socket::MultiInterfaceSocket;
16
17pub mod config;
18pub mod protocol;
19pub mod socket;
20pub mod interfaces;
21
/// Per-interface bookkeeping used by the discovery socket's `poll` loop to schedule
/// announcements and extend requests on that interface.
#[derive(Default)]
pub struct PerInterfaceState {
    /// Set by `got_extend_announce_req`; the poll loop clears it (and logs) once the
    /// extended-announcement window has expired.
    extended_announce_enabled: bool,

    /// When the last regular/extended `Announce` was sent on this interface.
    last_announce_tm: Option<Instant>,
    /// When the last `ExtendAnnouncements` request was sent on this interface.
    extend_request_send_tm: Option<Instant>,
    /// When a peer last asked us (via `ExtendAnnouncements`) to announce on all ports.
    extended_announcements_request_tm: Option<Instant>,
}
30
31impl PerInterfaceState {
32    pub fn should_announce(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
33        self.last_announce_tm.is_none_or(|tm| now - tm > cfg.announce_interval)
34    }
35    pub fn should_send_extend_request(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
36        self.extend_request_send_tm.is_none_or(|tm| now - tm > cfg.extend_request_interval)
37    }
38    pub fn should_extended_announce(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
39        self.extended_announcements_enabled(now, cfg) && self.should_announce(now, cfg)
40    }
41
42    pub fn extended_announcements_enabled(&self, now: Instant, cfg: &MulticastDiscoveryConfig) -> bool {
43        self.extended_announcements_request_tm.is_some_and(|tm| now - tm < cfg.extended_announcement_effect_dur)
44    }
45    pub fn got_extend_announce_req(&mut self, now: Instant) {
46        self.extended_announcements_request_tm = Some(now);
47        self.extended_announce_enabled = true;
48    }
49}
50
/// Payload carried inside `Announce` messages next to the service port.
///
/// Implementors convert themselves to bytes for sending (`encode_to_bytes`) and back from
/// received bytes (`try_decode`), where `None` signals bytes that are not a valid payload.
pub trait AdvertisementData
where
    Self: Sized + Clone,
{
    /// Serializes the payload into the bytes placed on the wire.
    fn encode_to_bytes(&self) -> Vec<u8>;
    /// Parses a payload from received bytes; `None` if they are not a valid encoding.
    fn try_decode(bytes: &[u8]) -> Option<Self>;
}
55#[cfg(not(feature="bincode"))]
56mod adv_data_impls {
57    use super::AdvertisementData;
58    impl AdvertisementData for () {
59        fn encode_to_bytes(&self) -> Vec<u8> {
60            Vec::new()
61        }
62        fn try_decode(bytes: &[u8]) -> Option<Self> {
63            if bytes.is_empty() {
64                Some(())
65            }
66            else {
67                None
68            }
69        }
70
71    }
72    impl AdvertisementData for Vec<u8> {
73        fn encode_to_bytes(&self) -> Vec<u8> {
74            self.clone()
75        }
76        fn try_decode(bytes: &[u8]) -> Option<Self> {
77            Some(bytes.to_vec())
78        }
79    }
80}
81
82#[cfg(feature="bincode")]
83use bincode::{Decode, Encode};
/// With the `bincode` feature, any `Encode + Decode + Clone` type can be used as payload,
/// serialized with bincode's standard configuration.
#[cfg(feature="bincode")]
impl<T: Encode + Decode<()> + Clone> AdvertisementData for T {
    /// Panics if bincode fails to encode the value.
    fn encode_to_bytes(&self) -> Vec<u8> {
        let config = bincode::config::standard();
        bincode::encode_to_vec(self, config).unwrap()
    }

    /// Decodes a value from the start of `bytes`; any trailing bytes are ignored.
    fn try_decode(bytes: &[u8]) -> Option<Self> {
        let config = bincode::config::standard();
        match bincode::decode_from_slice(bytes, config) {
            Ok((value, _consumed)) => Some(value),
            Err(_) => None,
        }
    }
}
96
/// Multicast discovery endpoint: a UDP socket bound to the primary, a backup, or an
/// OS-assigned port, plus per-interface announcement state. Drive it by calling `poll`
/// periodically.
pub struct MulticastDiscoverySocket<D: AdvertisementData> {
    socket: MultiInterfaceSocket,
    cfg: MulticastDiscoveryConfig,
    /// Identifier included in our `Announce` messages (random when the `rand` feature is on).
    discover_id: u32,
    /// Which kind of port the constructor managed to bind.
    running_port: MulticastRunningPort,
    interface_tracker: InterfaceTracker<PerInterfaceState>,

    /// Periodic announcements in `poll` (also gates the disconnect message sent on drop).
    announce_enabled: bool,
    /// Whether incoming `Discovery` messages are answered with an `Announce`.
    discover_replies: bool,
    /// Service port and advertisement payload sent in `Announce` messages.
    /// If not set, announcements, discovery replies and the disconnect message are disabled.
    service_port_and_adv_data: Option<(u16, D)>,
}
109
/// Port the discovery socket ended up bound to.
///
/// Dereferences to the port number; `Other` (OS-assigned fallback port) dereferences to `0`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum MulticastRunningPort {
    /// Bound to the main (first configured) port.
    Primary(u16),
    /// Main port was unavailable; bound to one of the configured backup ports.
    Backup(u16),
    /// None of the configured ports could be bound; running on an OS-chosen port.
    Other
}

impl Deref for MulticastRunningPort {
    type Target = u16;
    fn deref(&self) -> &Self::Target {
        match self {
            MulticastRunningPort::Primary(port) | MulticastRunningPort::Backup(port) => port,
            MulticastRunningPort::Other => &0,
        }
    }
}
127
128
/// Event handed to the `poll` callback for every received `Announce` message.
#[derive(Debug, Clone)]
pub enum PollResult<'a, D> {
    /// A peer announced (or re-announced) its service.
    DiscoveredClient {
        /// Sender IP combined with the service port from the announcement.
        addr: SocketAddrV4,
        /// Peer's discover ID; a changed value indicates the peer application restarted.
        discover_id: u32,
        /// Decoded advertisement payload, borrowed for the duration of the callback.
        adv_data: &'a D,
    },
    /// A peer announced that it is going away (`disconnected: true`).
    DisconnectedClient {
        /// Sender IP combined with the service port from the announcement.
        addr: SocketAddrV4,
        /// Peer's discover ID.
        discover_id: u32
    }
}
140
141impl<D: AdvertisementData> MulticastDiscoverySocket<D> {
142    /// Create new socket for multicast discovery. Announcements are disabled (running without service)
143    /// Enable feature `bincode` for passing `Encode` + `Decode` types as adv_data
144    pub fn new_discover_only(cfg: &MulticastDiscoveryConfig) -> io::Result<Self> {
145        Self::new(cfg, None)
146    }
147    /// Create new socket for multicast discovery. Announcements are enabled depending on config.
148    /// Enable feature `bincode` for passing `Encode` + `Decode` types as adv_data
149    pub fn new_with_service(cfg: &MulticastDiscoveryConfig, service_port: u16, initial_adv_data: D) -> io::Result<Self> {
150        Self::new(cfg, Some((service_port, initial_adv_data)))
151    }
152    fn new(cfg: &MulticastDiscoveryConfig, service_port_and_adv_data: Option<(u16, D)>) -> io::Result<Self> {
153        let central_discovery_enabled = cfg.central_discovery_addr.is_some();
154        let mut is_primary = true;
155        
156        let mut interface_tracker = InterfaceTracker::new();
157        // Try primary and backup ports
158        let main_port = cfg.iter_ports().next().unwrap();
159        for port in cfg.iter_ports().chain(once(0)) {
160            match MultiInterfaceSocket::bind_port(port) {
161                Ok(socket) => {
162                    // Join multicast group on all interfaces
163                    for (interface, _) in interface_tracker.iter_mut() {
164                        if let IpAddr::V4(ip) = interface.ip() {
165                            if let Err(e) = socket.join_multicast_group(cfg.multicast_group_ip, ip) {
166                                warn!("Failed to join multicast group on interface {}: {}", interface.ip(), e);
167                            }
168                            else {
169                                info!("Joined multicast group on interface {}", interface.ip());
170                            }
171                        }
172                    }
173                   
174                    // Set non-blocking
175                    socket.set_nonblocking(true)?;
176                    
177                    let running_port = if is_primary {
178                        debug!("Using primary multicast port {port} for discovery");
179                        MulticastRunningPort::Primary(port)
180                    }
181                    else if port == 0 {
182                        let failed_ports = cfg.iter_ports().filter(|p| *p != 0);
183                        warn!("Unable to start on the main or backup ports ({:?})!", &failed_ports.collect::<Vec<_>>());
184                        if !central_discovery_enabled {
185                            warn!("You may face issues with discovering");
186                        }
187                        else {
188                            warn!("You will be able to discover clients only when your network is online!");
189                        }
190                        MulticastRunningPort::Other
191                    }
192                    else {
193                        warn!("Using backup multicast port {port} for discovery (unable to start on main port {main_port})");
194                        MulticastRunningPort::Backup(port)
195                    };
196                    #[cfg(feature="rand")]
197                    let discover_id = rand::random_range(1..u32::MAX);
198                    #[cfg(not(feature="rand"))]
199                    let discover_id = std::time::UNIX_EPOCH.elapsed().unwrap_or_default().as_nanos() as u32;
200
201                    return Ok(Self {
202                        socket,
203                        interface_tracker,
204                        cfg: cfg.clone(),
205                        discover_id,
206                        running_port,
207
208                        announce_enabled: cfg.enable_announce,
209                        discover_replies: cfg.enable_announce,
210
211                        service_port_and_adv_data,
212                    })
213                }
214                Err(e) if e.kind() == io::ErrorKind::AddrInUse => {
215                    is_primary = false;
216                    continue
217                },
218                Err(e) => {
219                    is_primary = false;
220                    warn!("Failed to bind socket to port {port}: {e}");
221                    continue;
222                }
223            }
224        }
225    
226        error!("Failed to bind multicast discovery socket to any port!");
227        Err(io::Error::new(io::ErrorKind::AddrInUse, "Failed to bind socket to any port"))
228    }
229
230    /// Returns discover ID. It is generated randomly, and can be used to detect client application restart.
231    pub fn discover_id(&self) -> u32 {
232        self.discover_id
233    }
234    pub fn running_port(&self) -> MulticastRunningPort {
235        self.running_port
236    }
237
238    /// Setting this to `false` will disable periodic background announcements during `poll()`
239    /// Announcements are performed by periodic sending message `Announce` (and `ExtendAnnounce`)
240    pub fn set_announce_en(&mut self, en: bool) {
241        self.announce_enabled = en;
242    }
243
244    /// Setting this to `false` will disable automatic replies to `Discovery` messages
245    pub fn set_discover_replies_en(&mut self, enable: bool) {
246        self.discover_replies = enable;
247    }
248
249    /// Guaranteed to return Some(&mut D) if was created with `Self::new`
250    pub fn adv_data(&mut self) -> Option<&mut D> {
251        self.service_port_and_adv_data.as_mut().map(|s| &mut s.1)
252    }
253
254    /// Manually discover all clients on main or backup ports (using `Discovery` message).
255    /// Results can be collected by running `poll`.
256    pub fn discover(&mut self) {
257        info!("Multicast discovery: running manual discovery...");
258        let msg = DiscoveryMessage::Discovery::<D>.gen_message();
259        for (interface, _) in self.interface_tracker.iter_mut() {
260            let Some(index) = interface.index else {
261                continue;
262            };
263            if interface.ip().is_loopback() {
264                continue;
265            }
266
267            for port in self.cfg.iter_ports() {
268                if let Err(e) = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), index, interface.addr.ip()) {
269                    warn!("Failed to send discovery message on interface [{}] - {}: {}", interface.ip(), interface.name, e);
270                } else {
271                    debug!("Sent discovery message to port {} on interface [{}] - {}", port, interface.ip(), interface.name);
272                }
273            }
274        }
275    }
276
277    /// Run `poll` periodically to handle internal discovery mechanisms:
278    /// - `Announce` messages periodic sending
279    /// - `ExtendAnnounce` messages periodic sending (if running on backup port)
280    /// - handling incoming messages (and returning discovery results via `discover_msg` callback)
281    ///
282    /// It is recommended to call this function in a loop with ~100ms sleep
283    pub fn poll(&mut self, mut discover_msg: impl FnMut(PollResult<D>)) {
284        // 0. poll interface updates
285        self.interface_tracker.poll_updates(|new_ip| {
286            if let Err(e) = self.socket.join_multicast_group(self.cfg.multicast_group_ip, new_ip) {
287                warn!("Failed to join multicast group on interface {new_ip}: {e}");
288            }
289            else {
290                info!("Joined multicast group on interface {new_ip}!");
291            }
292        });
293    
294        let mut interface_cnt = 0;
295        if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_mut() {
296            if self.announce_enabled {
297                for (interface, state) in self.interface_tracker.iter_mut() {
298                    let Some(interface_index) = interface.index else {
299                        continue;
300                    };
301                    // Skip for now
302                    if interface.ip().is_loopback() {
303                        continue;
304                    }
305
306                    // 1. Handle announcements
307                    let now = Instant::now();
308                    let should_announce = state.should_announce(now, &self.cfg);
309                    let should_extended_announce = state.should_extended_announce(now, &self.cfg);
310                    let should_send_extend_request = state.should_send_extend_request(now, &self.cfg);
311                    if should_announce {
312                        state.last_announce_tm = Some(now);
313
314                        let msg = DiscoveryMessage::Announce {
315                            service_port: *service_port,
316                            discover_id: self.discover_id,
317                            disconnected: false,
318                            adv_data: Cow::Borrowed(&*adv_data)
319                        }.gen_message();
320                        if should_extended_announce {
321                            for port in self.cfg.iter_ports() {
322                                let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), interface_index, interface.addr.ip());
323                                handle_err(res, "send extended announce", interface);
324                            }
325                        }
326                        else {
327                            let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, self.cfg.multicast_port), interface_index, interface.addr.ip());
328                            handle_err(res, "send normal announce", interface);
329                        }
330                        if state.extended_announce_enabled && !state.extended_announcements_enabled(now, &self.cfg) {
331                            state.extended_announce_enabled = false;
332                            info!("No longer sending extended announce on interface [{}] - {}", interface.ip(), interface.name);
333                        }
334                    }
335                    // 2. Sending extend requests
336                    if matches!(self.running_port, MulticastRunningPort::Backup(_)) && should_send_extend_request {
337                        state.extend_request_send_tm = Some(now);
338                        let msg = DiscoveryMessage::ExtendAnnouncements::<D>.gen_message();
339                        for port in self.cfg.iter_ports() {
340                            let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port), interface_index, interface.addr.ip());
341                            handle_err(res, "send extended announce request", interface);
342                        }
343                    }
344                    interface_cnt += 1;
345                }
346
347                if interface_cnt == 0 {
348                    warn!("No available interface found!");
349                    thread::sleep(Duration::from_millis(500));
350                }
351            }
352        }
353
354        // 3. Handle incoming packets
355        let mut buf = [0u8;256];
356        while let Ok((data, addr, index)) = self.socket.recv_from_iface(&mut buf) {
357
358            // Shut up messages from ourselves on all interfaces
359            if self.interface_tracker.iter_mut().any(|(i, _)| i.ip() == IpAddr::V4(*addr.ip())) && addr.port() == *self.running_port {
360                continue;
361            }
362
363            match DiscoveryMessage::<D>::try_parse(data) {
364                Some(DiscoveryMessage::Discovery) => {
365                    if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_mut() {
366                        if self.discover_replies {
367                            let announce = DiscoveryMessage::Announce {
368                                disconnected: false,
369                                discover_id: self.discover_id,
370                                service_port: *service_port,
371                                adv_data: Cow::Borrowed(&*adv_data)
372                            }.gen_message();
373                            let source_addr = self.interface_tracker.iter_mapping().find(|(i, _)| *i == index);
374                            if let Some((_, a)) = source_addr {
375                                if let Err(e) = self.socket.send_to_iface(&announce, addr, index, a.into()) {
376                                    warn!("Failed to answer to discovery packet: {e:?}");
377                                }
378                            }
379                            else {
380                                warn!("Failed to answer discovery packet: interface address not found for index!");
381                            }
382                        }
383                    }
384                }
385                Some(DiscoveryMessage::Announce { service_port, discover_id, disconnected, adv_data}) => {
386                    if disconnected {
387                        discover_msg(PollResult::DisconnectedClient {
388                            addr: SocketAddrV4::new(
389                                *addr.ip(),
390                                service_port,
391                            ),
392                            discover_id
393                        })
394                    }
395                    else {
396                        discover_msg(PollResult::DiscoveredClient {
397                            addr: SocketAddrV4::new(
398                                *addr.ip(),
399                                service_port,
400                            ),
401                            discover_id,
402                            adv_data: adv_data.as_ref(),
403                        })
404                    }
405                }
406                Some(DiscoveryMessage::ExtendAnnouncements) => {
407                    for (interface, state) in self.interface_tracker.iter_mut() {
408                        if interface.index == Some(index) {
409                            let now = Instant::now();
410                            if !state.extended_announcements_enabled(now, &self.cfg) {
411                                info!("Enabling extended announcements on interface [{}] - {}", interface.ip(), interface.name);
412                            }
413                            state.got_extend_announce_req(now);
414                        }
415                    }
416                }
417                None => {
418                    warn!("Received unknown message from {addr}: {data:?}");
419                }
420            }
421        }
422    }
423}
424impl<D: AdvertisementData> Drop for MulticastDiscoverySocket<D> {
425    fn drop(&mut self) {
426        // Announce disconnection
427        if !self.announce_enabled {
428            return;
429        }
430        if let Some((service_port, adv_data)) = self.service_port_and_adv_data.as_ref() {
431            for (interface, _) in self.interface_tracker.iter_mut() {
432                let Some(index) = interface.index else {
433                    continue;
434                };
435                // Skip for now
436                if interface.ip().is_loopback() {
437                    continue;
438                }
439
440                let msg = DiscoveryMessage::Announce {
441                    discover_id: self.discover_id,
442                    service_port: *service_port,
443                    disconnected: true,
444                    adv_data: Cow::Borrowed(adv_data)
445                }.gen_message();
446                for port in self.cfg.iter_ports() {
447                    let res = self.socket.send_to_iface(&msg, SocketAddrV4::new(self.cfg.multicast_group_ip, port),index, interface.addr.ip());
448                    handle_err(res, "announce disconnected message", interface);
449                }
450            }
451        }
452    }
453}
454
455fn handle_err(result: io::Result<usize>, msg: &'static str, interface: &Interface) {
456    if let Err(e) = result {
457        warn!("Failed to {} on interface [{:?}] - {}: {}", msg, interface.ip(), interface.name, e);
458    }
459}