tosca_controller/
discovery.rs

1use std::borrow::Cow;
2use std::net::IpAddr;
3use std::time::Duration;
4
5use tosca::device::DeviceData;
6
7use flume::RecvTimeoutError;
8
9use mdns_sd::{IfKind, Receiver, ResolvedService, ServiceDaemon, ServiceEvent};
10
11use tokio::time::sleep;
12
13use tracing::{info, warn};
14
15use crate::device::{Description, Device, Devices, NetworkInformation, build_device_address};
16use crate::error::Error;
17use crate::events::Events;
18use crate::request::create_requests;
19
/// Default top-level domain of the searched mDNS service type.
///
/// It is the last label of a service type such as `_tosca._tcp.local.`.
const TOP_LEVEL_DOMAIN: &str = "local";
24
/// The discovery service transport protocol.
///
/// It is the protocol label of the searched mDNS service type, e.g. the
/// `tcp` in `_tosca._tcp.local.`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportProtocol {
    /// TCP-based service.
    TCP,
    /// UDP-based service.
    UDP,
}

impl std::fmt::Display for TransportProtocol {
    /// Writes the lowercase protocol name.
    ///
    /// `Formatter::pad` is what `str`'s own `Display` uses, so width,
    /// alignment and precision flags are honored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}

impl TransportProtocol {
    /// Returns the [`TransportProtocol`] name, as it appears in an mDNS
    /// service type (`"tcp"` or `"udp"`).
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::TCP => "tcp",
            Self::UDP => "udp",
        }
    }
}
50
/// Device discovery service.
///
/// A service for identifying and registering all `tosca` devices within
/// a network.
///
/// It is configured through its builder-style methods and searches the
/// network through an mDNS service daemon.
#[derive(Debug, PartialEq)]
pub struct Discovery {
    /// Searched service domain (e.g. the `tosca` in `_tosca._tcp.local.`).
    domain: Cow<'static, str>,
    /// Transport protocol label of the searched service type.
    transport_protocol: TransportProtocol,
    /// Top-level domain of the searched service type.
    top_level_domain: Cow<'static, str>,
    /// Timeout applied while waiting for mDNS events during browsing.
    timeout: Duration,
    /// Whether `IPv6` interfaces are disabled on the mDNS daemon.
    disable_ipv6: bool,
    /// Optional IP address passed to the mDNS daemon's `disable_interface`.
    disable_ip: Option<IpAddr>,
    /// Optional interface name passed to the mDNS daemon's `disable_interface`.
    disable_network_interface: Option<&'static str>,
}
65
66impl Discovery {
67    /// Creates [`Discovery`].
68    #[must_use]
69    #[inline]
70    pub fn new(domain: impl Into<Cow<'static, str>>) -> Self {
71        Self {
72            domain: domain.into(),
73            transport_protocol: TransportProtocol::TCP,
74            top_level_domain: Cow::Borrowed(TOP_LEVEL_DOMAIN),
75            timeout: Duration::from_secs(2), // Default timeout of 2s.
76            disable_ipv6: false,
77            disable_ip: None,
78            disable_network_interface: None,
79        }
80    }
81
82    /// Sets the service timeout.
83    ///
84    /// The entire discovery process will last for the given timeout value.
85    #[must_use]
86    pub const fn timeout(mut self, timeout: Duration) -> Self {
87        self.timeout = timeout;
88        self
89    }
90
91    /// Sets the service transport protocol.
92    #[must_use]
93    pub const fn transport_protocol(mut self, transport_protocol: TransportProtocol) -> Self {
94        self.transport_protocol = transport_protocol;
95        self
96    }
97
98    /// Sets the service domain.
99    ///
100    /// The domain searched by the service. i.e. tosca
101    #[must_use]
102    #[inline]
103    pub fn domain(mut self, domain: impl Into<Cow<'static, str>>) -> Self {
104        self.domain = domain.into();
105        self
106    }
107
108    /// Sets the service top-level domain.
109    ///
110    /// A common top-level domain is `.local`.
111    #[must_use]
112    #[inline]
113    pub fn top_level_domain(mut self, top_level_domain: impl Into<Cow<'static, str>>) -> Self {
114        self.top_level_domain = top_level_domain.into();
115        self
116    }
117
118    /// Excludes devices with `IPv6` interfaces from the discovery service.
119    #[must_use]
120    pub const fn disable_ipv6(mut self) -> Self {
121        self.disable_ipv6 = true;
122        self
123    }
124
125    /// Excludes the device with the given `IP` from the discovery service.
126    #[must_use]
127    #[inline]
128    pub fn disable_ip(mut self, ip: impl Into<IpAddr>) -> Self {
129        self.disable_ip = Some(ip.into());
130        self
131    }
132
133    /// Disables the given network interface from the discovery service.
134    #[must_use]
135    pub const fn disable_network_interface(mut self, network_interface: &'static str) -> Self {
136        self.disable_network_interface = Some(network_interface);
137        self
138    }
139
140    pub(crate) async fn discover(&self) -> Result<Devices, Error> {
141        // Discover devices.
142        let discovery_info = self.discover_devices().await?;
143
144        Self::obtain_devices_data(discovery_info).await
145    }
146
147    async fn discover_devices(&self) -> Result<Vec<ResolvedService>, Error> {
148        // Create a mdns daemon
149        let mdns = ServiceDaemon::new()?;
150
151        // Disable IPv6 interface.
152        if self.disable_ipv6 {
153            mdns.disable_interface(IfKind::IPv6)?;
154        }
155
156        // Disable IP.
157        if let Some(ip) = self.disable_ip {
158            mdns.disable_interface(ip)?;
159        }
160
161        // Disable network interface.
162        if let Some(network_interface) = self.disable_network_interface {
163            mdns.disable_interface(network_interface)?;
164        }
165
166        // Service type.
167        let service_type = format!(
168            "_{}._{}.{}.",
169            self.domain,
170            self.transport_protocol.name(),
171            self.top_level_domain
172        );
173
174        // Detects devices.
175        let receiver = mdns.browse(&service_type)?;
176
177        // Discovery service.
178        let mut discovery_service = Vec::new();
179
180        // Run for n-seconds in search of devices and saves their information
181        // in memory.
182        while let Ok(event) = self.with_timeout(&receiver).await {
183            if let ServiceEvent::ServiceResolved(info) = event {
184                // Check whether there are device addresses.
185                //
186                // If no address has been found, prints a warning and
187                // continue the loop.
188                if info.get_addresses().is_empty() {
189                    warn!("No device address available for {:?}", info);
190                    continue;
191                }
192
193                // If two devices are equal, skip to the next one.
194                if Self::check_device_duplicates(&discovery_service, &info) {
195                    continue;
196                }
197
198                discovery_service.push(*info);
199            }
200        }
201
202        // Stop detection.
203        mdns.stop_browse(&service_type)?;
204
205        Ok(discovery_service)
206    }
207
208    #[inline]
209    async fn with_timeout<T>(&self, receiver: &Receiver<T>) -> Result<T, RecvTimeoutError> {
210        let timeout_future = sleep(self.timeout);
211
212        tokio::select! {
213            () = timeout_future => {
214                // This is the same error returned by the `recv_timeout`
215                // function in case of a timeout.
216                Err(RecvTimeoutError::Timeout)
217            }
218            result = receiver.recv_async() => {
219                result.map_err(|_| RecvTimeoutError::Disconnected)
220            }
221        }
222    }
223
224    async fn obtain_devices_data(
225        discovery_service: Vec<ResolvedService>,
226    ) -> Result<Devices, Error> {
227        // Devices collection.
228        let mut devices = Devices::new();
229
230        // Iterate over discovered metadata
231        for service in discovery_service {
232            // Try to contact each available address for a device
233            // to retrieve data.
234            for address in &service.addresses {
235                let complete_address = build_device_address(
236                    service
237                        .txt_properties
238                        .get_property_val_str("scheme")
239                        // If the scheme is not specified as a property,
240                        // fall back to `http` as default.
241                        .unwrap_or("http"),
242                    &address.to_ip_addr(),
243                    service.port,
244                );
245                info!("Complete address: {complete_address}");
246
247                // Contact devices to retrieve their data
248                match reqwest::Client::new()
249                    .get(&complete_address)
250                    .header("Connection", "close")
251                    .send()
252                    .await
253                {
254                    Ok(response) => {
255                        let device_data: DeviceData = response.json().await?;
256
257                        if device_data.wifi_mac.is_none() && device_data.ethernet_mac.is_none() {
258                            warn!(
259                                "Ignoring device {complete_address} because no valid MAC addresses have been found"
260                            );
261                            continue;
262                        }
263
264                        let requests = create_requests(
265                            device_data.route_configs,
266                            &complete_address,
267                            &device_data.main_route,
268                            device_data.environment,
269                        );
270
271                        let description = Description::new(
272                            device_data.kind,
273                            device_data.environment,
274                            device_data.main_route.into_owned(),
275                        );
276
277                        let mut network_info = NetworkInformation::new(
278                            service.fullname,
279                            service
280                                .addresses
281                                .into_iter()
282                                .map(|address| address.to_ip_addr())
283                                .collect(),
284                            service.port,
285                            service.txt_properties.into_property_map_str(),
286                            complete_address,
287                        );
288
289                        if let Some(mac) = device_data.wifi_mac {
290                            network_info = network_info.wifi_mac(mac);
291                        }
292
293                        if let Some(mac) = device_data.ethernet_mac {
294                            network_info = network_info.ethernet_mac(mac);
295                        }
296
297                        let events = device_data.events_description.map(Events::new);
298
299                        devices.add(Device::init(network_info, description, requests, events));
300
301                        // Only a single address is necessary.
302                        break;
303                    }
304                    Err(e) => {
305                        warn!("Impossible to contact address {complete_address}: {e}");
306                    }
307                }
308            }
309        }
310
311        Ok(devices)
312    }
313
314    // A discovered device is equal to another device when:
315    //
316    // - It has an address with IP and port identical to the ones of
317    //   another device address.
318    //   Devices belonging to the same local network CANNOT HAVE any IP
319    //   and port in common.
320    //
321    //   OR
322    //
323    // - It has the same full name of another device belonging to the same
324    //   network. A full name, in this case, represents the device ID.
325    //   Two devices belonging to the same network CANNOT HAVE the same ID.
326    fn check_device_duplicates(
327        discovery_service: &[ResolvedService],
328        info: &ResolvedService,
329    ) -> bool {
330        for disco_service in discovery_service {
331            // When the addresses have distinct ports, they are always
332            // different, so they are not considered.
333            if disco_service.port != info.get_port() {
334                continue;
335            }
336
337            for address in &disco_service.addresses {
338                if info.get_addresses().contains(address) {
339                    return true;
340                }
341            }
342
343            if disco_service.fullname == info.get_fullname() {
344                return true;
345            }
346        }
347        false
348    }
349}
350
#[cfg(test)]
pub(crate) mod tests {
    use std::time::Duration;

    use tracing::warn;

    use serial_test::serial;

    use crate::tests::{
        DOMAIN, check_function_with_device, check_function_with_two_devices, compare_device_data,
    };

    use super::Discovery;

    /// Builds the [`Discovery`] configuration shared by the discovery tests:
    /// a 1-second timeout, `IPv6` disabled, and the `docker0` interface
    /// disabled.
    pub(crate) fn configure_discovery() -> Discovery {
        let timeout = Duration::from_secs(1);
        Discovery::new(DOMAIN)
            .timeout(timeout)
            .disable_ipv6()
            .disable_network_interface("docker0")
    }

    /// Runs a discovery and checks that exactly `devices_len` devices were
    /// found, validating the data of each of them.
    async fn discovery_comparison(devices_len: usize) {
        let discovered = configure_discovery().discover().await.unwrap();

        assert_eq!(discovered.len(), devices_len);

        for found in discovered {
            compare_device_data(&found);
        }
    }

    /// Runs `function` unless the crate was built with the `CI` environment
    /// variable set, in which case the test named `name` is skipped.
    async fn run_discovery_function<F, Fut>(name: &str, function: F)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = ()>,
    {
        if option_env!("CI").is_none() {
            function().await;
            return;
        }

        warn!(
            "Skipping test on CI: {} can run only on systems that expose physical MAC addresses.",
            name
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_single_device_discovery() {
        let scenario = || async {
            check_function_with_device(|| async { discovery_comparison(1).await }).await;
        };
        run_discovery_function("discovery_with_single_device", scenario).await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
    #[serial]
    async fn test_more_devices_discovery() {
        let scenario = || async {
            check_function_with_two_devices(|| async { discovery_comparison(2).await }).await;
        };
        run_discovery_function("discovery_with_more_devices", scenario).await;
    }
}