Skip to main content

par_term_ssh/
mdns.rs

1//! mDNS/Bonjour discovery for SSH hosts on the local network.
2//!
3//! Uses the `mdns-sd` crate to browse for `_ssh._tcp.local.` services.
//! Discovery runs on a background thread and delivers results through a
//! `std::sync::mpsc` channel.
5
6use super::types::{SshHost, SshHostSource};
7use mdns_sd::{ServiceDaemon, ServiceEvent};
8use std::sync::mpsc;
9use std::time::Duration;
10
11/// mDNS discovery state.
12pub struct MdnsDiscovery {
13    /// Discovered hosts from mDNS
14    discovered: Vec<SshHost>,
15    /// Whether a scan is currently running
16    scanning: bool,
17    /// Receiver for hosts from background scan
18    receiver: Option<mpsc::Receiver<SshHost>>,
19}
20
21impl Default for MdnsDiscovery {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl MdnsDiscovery {
28    pub fn new() -> Self {
29        Self {
30            discovered: Vec::new(),
31            scanning: false,
32            receiver: None,
33        }
34    }
35
36    /// Start an mDNS scan for SSH services.
37    pub fn start_scan(&mut self, timeout_secs: u32) {
38        if self.scanning {
39            return;
40        }
41
42        self.scanning = true;
43        self.discovered.clear();
44
45        let (tx, rx) = mpsc::channel();
46        self.receiver = Some(rx);
47
48        let timeout = Duration::from_secs(u64::from(timeout_secs));
49
50        std::thread::spawn(move || {
51            run_mdns_scan(tx, timeout);
52        });
53    }
54
55    /// Poll for newly discovered hosts. Returns true if new hosts were found.
56    pub fn poll(&mut self) -> bool {
57        let receiver = match &self.receiver {
58            Some(r) => r,
59            None => return false,
60        };
61
62        let mut found_new = false;
63
64        // Drain all available hosts from the channel
65        loop {
66            match receiver.try_recv() {
67                Ok(host) => {
68                    let duplicate = self
69                        .discovered
70                        .iter()
71                        .any(|h| h.hostname == host.hostname && h.port == host.port);
72                    if !duplicate {
73                        self.discovered.push(host);
74                        found_new = true;
75                    }
76                }
77                Err(mpsc::TryRecvError::Empty) => break,
78                Err(mpsc::TryRecvError::Disconnected) => {
79                    // Scan thread has finished
80                    self.scanning = false;
81                    self.receiver = None;
82                    break;
83                }
84            }
85        }
86
87        found_new
88    }
89
90    /// Returns the list of discovered hosts.
91    pub fn hosts(&self) -> &[SshHost] {
92        &self.discovered
93    }
94
95    /// Returns whether a scan is currently in progress.
96    pub fn is_scanning(&self) -> bool {
97        self.scanning
98    }
99
100    /// Clear all discovered hosts and stop any in-progress scan.
101    pub fn clear(&mut self) {
102        self.discovered.clear();
103        self.scanning = false;
104        self.receiver = None;
105    }
106}
107
108/// Run an mDNS scan in a background thread, sending discovered SSH hosts
109/// through the provided channel.
110fn run_mdns_scan(tx: mpsc::Sender<SshHost>, timeout: Duration) {
111    let daemon = match ServiceDaemon::new() {
112        Ok(d) => d,
113        Err(e) => {
114            log::warn!("Failed to start mDNS daemon: {}", e);
115            return;
116        }
117    };
118
119    let receiver = match daemon.browse("_ssh._tcp.local.") {
120        Ok(r) => r,
121        Err(e) => {
122            log::warn!("Failed to browse mDNS: {}", e);
123            let _ = daemon.shutdown();
124            return;
125        }
126    };
127
128    let deadline = std::time::Instant::now() + timeout;
129
130    loop {
131        if std::time::Instant::now() >= deadline {
132            break;
133        }
134
135        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
136        match receiver.recv_timeout(remaining.min(Duration::from_millis(500))) {
137            Ok(ServiceEvent::ServiceResolved(info)) => {
138                let hostname = info.get_hostname().trim_end_matches('.').to_string();
139                let port = info.get_port();
140                let service_name = info
141                    .get_fullname()
142                    .split("._ssh._tcp")
143                    .next()
144                    .unwrap_or(&hostname)
145                    .to_string();
146
147                let host = SshHost {
148                    alias: service_name,
149                    hostname: Some(hostname),
150                    user: None,
151                    port: if port == 22 { None } else { Some(port) },
152                    identity_file: None,
153                    proxy_jump: None,
154                    source: SshHostSource::Mdns,
155                };
156
157                if tx.send(host).is_err() {
158                    break;
159                }
160            }
161            Ok(_) => {
162                // Ignore other events (SearchStarted, ServiceFound, etc.)
163            }
164            Err(_) if receiver.is_disconnected() => break,
165            Err(_) => continue, // Timeout — keep waiting
166        }
167    }
168
169    let _ = daemon.shutdown();
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an mDNS-sourced [`SshHost`] with only the identifying fields set.
    fn mdns_host(alias: &str, hostname: &str, port: Option<u16>) -> SshHost {
        SshHost {
            alias: alias.to_string(),
            hostname: Some(hostname.to_string()),
            user: None,
            port,
            identity_file: None,
            proxy_jump: None,
            source: SshHostSource::Mdns,
        }
    }

    /// Attach a fresh channel to `discovery` as if a scan were running, and
    /// return the sending half so the test can play the scan thread.
    fn attach_channel(discovery: &mut MdnsDiscovery) -> mpsc::Sender<SshHost> {
        let (tx, rx) = mpsc::channel();
        discovery.receiver = Some(rx);
        discovery.scanning = true;
        tx
    }

    #[test]
    fn test_mdns_discovery_new() {
        let discovery = MdnsDiscovery::new();
        assert!(!discovery.is_scanning());
        assert!(discovery.hosts().is_empty());
    }

    #[test]
    fn test_mdns_discovery_default() {
        let discovery = MdnsDiscovery::default();
        assert!(!discovery.is_scanning());
        assert!(discovery.hosts().is_empty());
    }

    #[test]
    fn test_mdns_discovery_clear() {
        let mut discovery = MdnsDiscovery::new();
        discovery.discovered.push(mdns_host("test", "test.local", None));
        assert_eq!(discovery.hosts().len(), 1);

        discovery.clear();
        assert!(discovery.hosts().is_empty());
        assert!(!discovery.is_scanning());
    }

    #[test]
    fn test_clear_abandons_pending_scan() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);

        discovery.clear();
        assert!(!discovery.is_scanning());
        // The receiver was dropped, so the scan thread's next send fails.
        assert!(tx.send(mdns_host("late", "late.local", None)).is_err());
        assert!(!discovery.poll());
    }

    #[test]
    fn test_poll_without_scan() {
        let mut discovery = MdnsDiscovery::new();
        // Should return false when no scan is running
        assert!(!discovery.poll());
    }

    #[test]
    fn test_poll_with_pending_channel() {
        let mut discovery = MdnsDiscovery::new();
        let _tx = attach_channel(&mut discovery);

        // Sender still alive and nothing sent: no hosts, scan still running.
        assert!(!discovery.poll());
        assert!(discovery.is_scanning());
        assert!(discovery.hosts().is_empty());
    }

    #[test]
    fn test_poll_with_completed_channel() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);

        // Send a host then drop the sender to simulate scan completion
        tx.send(mdns_host("myhost", "myhost.local", None)).unwrap();
        drop(tx);

        assert!(discovery.poll());
        assert_eq!(discovery.hosts().len(), 1);
        assert_eq!(discovery.hosts()[0].alias, "myhost");
        assert_eq!(
            discovery.hosts()[0].hostname.as_deref(),
            Some("myhost.local")
        );
        // The disconnect is observed in the same poll that drained the host.
        assert!(!discovery.is_scanning());
        assert!(!discovery.poll());
    }

    #[test]
    fn test_poll_deduplicates() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);

        // Send two hosts with the same hostname and port
        for _ in 0..2 {
            tx.send(mdns_host("dup", "dup.local", None)).unwrap();
        }
        drop(tx);

        assert!(discovery.poll());
        assert_eq!(discovery.hosts().len(), 1);
    }

    #[test]
    fn test_poll_keeps_distinct_ports() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);

        // Same hostname on different ports counts as two hosts.
        tx.send(mdns_host("a", "multi.local", None)).unwrap();
        tx.send(mdns_host("b", "multi.local", Some(2222))).unwrap();
        drop(tx);

        assert!(discovery.poll());
        assert_eq!(discovery.hosts().len(), 2);
    }

    #[test]
    fn test_scan_marks_scanning() {
        let mut discovery = MdnsDiscovery::new();
        assert!(!discovery.is_scanning());

        // Starting a scan sets the scanning flag
        discovery.start_scan(1);
        assert!(discovery.is_scanning());

        // The scan thread exits after its 1 s timeout (or immediately if the
        // mDNS daemon cannot start), so poll with generous slack until done.
        let deadline = std::time::Instant::now() + Duration::from_secs(10);
        while discovery.is_scanning() && std::time::Instant::now() < deadline {
            discovery.poll();
            std::thread::sleep(Duration::from_millis(50));
        }
        assert!(
            !discovery.is_scanning(),
            "scan did not finish before the deadline"
        );
    }
}