1use std::borrow::Cow;
2use std::net::IpAddr;
3use std::time::Duration;
4
5use tosca::device::DeviceData;
6
7use flume::RecvTimeoutError;
8
9use mdns_sd::{IfKind, Receiver, ResolvedService, ServiceDaemon, ServiceEvent};
10
11use tokio::time::sleep;
12
13use tracing::{info, warn};
14
15use crate::device::{Description, Device, Devices, NetworkInformation, build_device_address};
16use crate::error::Error;
17use crate::events::Events;
18use crate::request::create_requests;
19
/// Default top-level domain used as the last label of the browsed mDNS
/// service type (`_{domain}._{protocol}.{top_level_domain}.`).
const TOP_LEVEL_DOMAIN: &str = "local";
24
/// Transport protocol label used in the browsed mDNS service type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportProtocol {
    /// Services reachable over TCP (`_tcp`).
    TCP,
    /// Services reachable over UDP (`_udp`).
    UDP,
}

impl std::fmt::Display for TransportProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `str`'s `Display` explicitly so width/alignment flags
        // are honoured and no trait needs to be in scope for method lookup.
        std::fmt::Display::fmt(self.name(), f)
    }
}

impl TransportProtocol {
    /// Returns the lowercase protocol label (`"tcp"` or `"udp"`) used when
    /// building the mDNS service type.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::TCP => "tcp",
            Self::UDP => "udp",
        }
    }
}
50
/// Configuration for discovering devices through mDNS.
///
/// Values are set through the builder-style methods of this type. The
/// browsed service type is built as
/// `_{domain}._{transport_protocol}.{top_level_domain}.`.
#[derive(Debug, PartialEq)]
pub struct Discovery {
    /// Service name browsed for (first label of the service type).
    domain: Cow<'static, str>,
    /// Transport protocol label of the service type.
    transport_protocol: TransportProtocol,
    /// Top-level domain of the service type.
    top_level_domain: Cow<'static, str>,
    /// How long browsing waits for each mDNS event before stopping.
    timeout: Duration,
    /// When `true`, IPv6 interfaces are disabled on the mDNS daemon.
    disable_ipv6: bool,
    /// Optional IP address passed to the daemon's `disable_interface`.
    disable_ip: Option<IpAddr>,
    /// Optional network interface name passed to the daemon's
    /// `disable_interface` (e.g. `docker0`).
    disable_network_interface: Option<&'static str>,
}
65
66impl Discovery {
67 #[must_use]
69 #[inline]
70 pub fn new(domain: impl Into<Cow<'static, str>>) -> Self {
71 Self {
72 domain: domain.into(),
73 transport_protocol: TransportProtocol::TCP,
74 top_level_domain: Cow::Borrowed(TOP_LEVEL_DOMAIN),
75 timeout: Duration::from_secs(2), disable_ipv6: false,
77 disable_ip: None,
78 disable_network_interface: None,
79 }
80 }
81
82 #[must_use]
86 pub const fn timeout(mut self, timeout: Duration) -> Self {
87 self.timeout = timeout;
88 self
89 }
90
91 #[must_use]
93 pub const fn transport_protocol(mut self, transport_protocol: TransportProtocol) -> Self {
94 self.transport_protocol = transport_protocol;
95 self
96 }
97
98 #[must_use]
102 #[inline]
103 pub fn domain(mut self, domain: impl Into<Cow<'static, str>>) -> Self {
104 self.domain = domain.into();
105 self
106 }
107
108 #[must_use]
112 #[inline]
113 pub fn top_level_domain(mut self, top_level_domain: impl Into<Cow<'static, str>>) -> Self {
114 self.top_level_domain = top_level_domain.into();
115 self
116 }
117
118 #[must_use]
120 pub const fn disable_ipv6(mut self) -> Self {
121 self.disable_ipv6 = true;
122 self
123 }
124
125 #[must_use]
127 #[inline]
128 pub fn disable_ip(mut self, ip: impl Into<IpAddr>) -> Self {
129 self.disable_ip = Some(ip.into());
130 self
131 }
132
133 #[must_use]
135 pub const fn disable_network_interface(mut self, network_interface: &'static str) -> Self {
136 self.disable_network_interface = Some(network_interface);
137 self
138 }
139
140 pub(crate) async fn discover(&self) -> Result<Devices, Error> {
141 let discovery_info = self.discover_devices().await?;
143
144 Self::obtain_devices_data(discovery_info).await
145 }
146
147 async fn discover_devices(&self) -> Result<Vec<ResolvedService>, Error> {
148 let mdns = ServiceDaemon::new()?;
150
151 if self.disable_ipv6 {
153 mdns.disable_interface(IfKind::IPv6)?;
154 }
155
156 if let Some(ip) = self.disable_ip {
158 mdns.disable_interface(ip)?;
159 }
160
161 if let Some(network_interface) = self.disable_network_interface {
163 mdns.disable_interface(network_interface)?;
164 }
165
166 let service_type = format!(
168 "_{}._{}.{}.",
169 self.domain,
170 self.transport_protocol.name(),
171 self.top_level_domain
172 );
173
174 let receiver = mdns.browse(&service_type)?;
176
177 let mut discovery_service = Vec::new();
179
180 while let Ok(event) = self.with_timeout(&receiver).await {
183 if let ServiceEvent::ServiceResolved(info) = event {
184 if info.get_addresses().is_empty() {
189 warn!("No device address available for {:?}", info);
190 continue;
191 }
192
193 if Self::check_device_duplicates(&discovery_service, &info) {
195 continue;
196 }
197
198 discovery_service.push(*info);
199 }
200 }
201
202 mdns.stop_browse(&service_type)?;
204
205 Ok(discovery_service)
206 }
207
208 #[inline]
209 async fn with_timeout<T>(&self, receiver: &Receiver<T>) -> Result<T, RecvTimeoutError> {
210 let timeout_future = sleep(self.timeout);
211
212 tokio::select! {
213 () = timeout_future => {
214 Err(RecvTimeoutError::Timeout)
217 }
218 result = receiver.recv_async() => {
219 result.map_err(|_| RecvTimeoutError::Disconnected)
220 }
221 }
222 }
223
224 async fn obtain_devices_data(
225 discovery_service: Vec<ResolvedService>,
226 ) -> Result<Devices, Error> {
227 let mut devices = Devices::new();
229
230 for service in discovery_service {
232 for address in &service.addresses {
235 let complete_address = build_device_address(
236 service
237 .txt_properties
238 .get_property_val_str("scheme")
239 .unwrap_or("http"),
242 &address.to_ip_addr(),
243 service.port,
244 );
245 info!("Complete address: {complete_address}");
246
247 match reqwest::Client::new()
249 .get(&complete_address)
250 .header("Connection", "close")
251 .send()
252 .await
253 {
254 Ok(response) => {
255 let device_data: DeviceData = response.json().await?;
256
257 if device_data.wifi_mac.is_none() && device_data.ethernet_mac.is_none() {
258 warn!(
259 "Ignoring device {complete_address} because no valid MAC addresses have been found"
260 );
261 continue;
262 }
263
264 let requests = create_requests(
265 device_data.route_configs,
266 &complete_address,
267 &device_data.main_route,
268 device_data.environment,
269 );
270
271 let description = Description::new(
272 device_data.kind,
273 device_data.environment,
274 device_data.main_route.into_owned(),
275 );
276
277 let mut network_info = NetworkInformation::new(
278 service.fullname,
279 service
280 .addresses
281 .into_iter()
282 .map(|address| address.to_ip_addr())
283 .collect(),
284 service.port,
285 service.txt_properties.into_property_map_str(),
286 complete_address,
287 );
288
289 if let Some(mac) = device_data.wifi_mac {
290 network_info = network_info.wifi_mac(mac);
291 }
292
293 if let Some(mac) = device_data.ethernet_mac {
294 network_info = network_info.ethernet_mac(mac);
295 }
296
297 let events = device_data.events_description.map(Events::new);
298
299 devices.add(Device::init(network_info, description, requests, events));
300
301 break;
303 }
304 Err(e) => {
305 warn!("Impossible to contact address {complete_address}: {e}");
306 }
307 }
308 }
309 }
310
311 Ok(devices)
312 }
313
314 fn check_device_duplicates(
327 discovery_service: &[ResolvedService],
328 info: &ResolvedService,
329 ) -> bool {
330 for disco_service in discovery_service {
331 if disco_service.port != info.get_port() {
334 continue;
335 }
336
337 for address in &disco_service.addresses {
338 if info.get_addresses().contains(address) {
339 return true;
340 }
341 }
342
343 if disco_service.fullname == info.get_fullname() {
344 return true;
345 }
346 }
347 false
348 }
349}
350
#[cfg(test)]
pub(crate) mod tests {
    use std::time::Duration;

    use tracing::warn;

    use serial_test::serial;

    use crate::tests::{
        DOMAIN, check_function_with_device, check_function_with_two_devices, compare_device_data,
    };

    use super::Discovery;

    /// Builds the discovery configuration shared by the tests: a 1 second
    /// timeout, IPv6 disabled and the `docker0` interface disabled.
    pub(crate) fn configure_discovery() -> Discovery {
        let timeout = Duration::from_secs(1);
        let discovery = Discovery::new(DOMAIN).timeout(timeout);
        discovery
            .disable_ipv6()
            .disable_network_interface("docker0")
    }

    /// Runs a discovery, asserts that exactly `expected` devices were found
    /// and checks the data of each of them.
    async fn discovery_comparison(expected: usize) {
        let found = configure_discovery().discover().await.unwrap();

        assert_eq!(found.len(), expected);

        for device in found {
            compare_device_data(&device);
        }
    }

    /// Runs `function` unless the crate was compiled on CI, where physical
    /// MAC addresses are not available and the test is skipped with a warning.
    async fn run_discovery_function<F, Fut>(name: &str, function: F)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = ()>,
    {
        if option_env!("CI").is_none() {
            function().await;
            return;
        }

        warn!(
            "Skipping test on CI: {} can run only on systems that expose physical MAC addresses.",
            name
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_single_device_discovery() {
        let scenario = || async {
            check_function_with_device(|| async {
                discovery_comparison(1).await;
            })
            .await;
        };

        run_discovery_function("discovery_with_single_device", scenario).await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
    #[serial]
    async fn test_more_devices_discovery() {
        let scenario = || async {
            check_function_with_two_devices(|| async {
                discovery_comparison(2).await;
            })
            .await;
        };

        run_discovery_function("discovery_with_more_devices", scenario).await;
    }
}