1use super::types::{SshHost, SshHostSource};
7use mdns_sd::{ServiceDaemon, ServiceEvent};
8use std::sync::mpsc;
9use std::time::Duration;
10
11pub struct MdnsDiscovery {
13 discovered: Vec<SshHost>,
15 scanning: bool,
17 receiver: Option<mpsc::Receiver<SshHost>>,
19}
20
21impl Default for MdnsDiscovery {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl MdnsDiscovery {
28 pub fn new() -> Self {
29 Self {
30 discovered: Vec::new(),
31 scanning: false,
32 receiver: None,
33 }
34 }
35
36 pub fn start_scan(&mut self, timeout_secs: u32) {
38 if self.scanning {
39 return;
40 }
41
42 self.scanning = true;
43 self.discovered.clear();
44
45 let (tx, rx) = mpsc::channel();
46 self.receiver = Some(rx);
47
48 let timeout = Duration::from_secs(u64::from(timeout_secs));
49
50 std::thread::spawn(move || {
51 run_mdns_scan(tx, timeout);
52 });
53 }
54
55 pub fn poll(&mut self) -> bool {
57 let receiver = match &self.receiver {
58 Some(r) => r,
59 None => return false,
60 };
61
62 let mut found_new = false;
63
64 loop {
66 match receiver.try_recv() {
67 Ok(host) => {
68 let duplicate = self
69 .discovered
70 .iter()
71 .any(|h| h.hostname == host.hostname && h.port == host.port);
72 if !duplicate {
73 self.discovered.push(host);
74 found_new = true;
75 }
76 }
77 Err(mpsc::TryRecvError::Empty) => break,
78 Err(mpsc::TryRecvError::Disconnected) => {
79 self.scanning = false;
81 self.receiver = None;
82 break;
83 }
84 }
85 }
86
87 found_new
88 }
89
90 pub fn hosts(&self) -> &[SshHost] {
92 &self.discovered
93 }
94
95 pub fn is_scanning(&self) -> bool {
97 self.scanning
98 }
99
100 pub fn clear(&mut self) {
102 self.discovered.clear();
103 self.scanning = false;
104 self.receiver = None;
105 }
106}
107
108fn run_mdns_scan(tx: mpsc::Sender<SshHost>, timeout: Duration) {
111 let daemon = match ServiceDaemon::new() {
112 Ok(d) => d,
113 Err(e) => {
114 log::warn!("Failed to start mDNS daemon: {}", e);
115 return;
116 }
117 };
118
119 let receiver = match daemon.browse("_ssh._tcp.local.") {
120 Ok(r) => r,
121 Err(e) => {
122 log::warn!("Failed to browse mDNS: {}", e);
123 let _ = daemon.shutdown();
124 return;
125 }
126 };
127
128 let deadline = std::time::Instant::now() + timeout;
129
130 loop {
131 if std::time::Instant::now() >= deadline {
132 break;
133 }
134
135 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
136 match receiver.recv_timeout(remaining.min(Duration::from_millis(500))) {
137 Ok(ServiceEvent::ServiceResolved(info)) => {
138 let hostname = info.get_hostname().trim_end_matches('.').to_string();
139 let port = info.get_port();
140 let service_name = info
141 .get_fullname()
142 .split("._ssh._tcp")
143 .next()
144 .unwrap_or(&hostname)
145 .to_string();
146
147 let host = SshHost {
148 alias: service_name,
149 hostname: Some(hostname),
150 user: None,
151 port: if port == 22 { None } else { Some(port) },
152 identity_file: None,
153 proxy_jump: None,
154 source: SshHostSource::Mdns,
155 };
156
157 if tx.send(host).is_err() {
158 break;
159 }
160 }
161 Ok(_) => {
162 }
164 Err(_) if receiver.is_disconnected() => break,
165 Err(_) => continue, }
167 }
168
169 let _ = daemon.shutdown();
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Builds an mDNS-sourced host with the given alias and hostname and the
    /// default (implicit) port.
    fn mdns_host(alias: &str, hostname: &str) -> SshHost {
        SshHost {
            alias: alias.to_string(),
            hostname: Some(hostname.to_string()),
            user: None,
            port: None,
            identity_file: None,
            proxy_jump: None,
            source: SshHostSource::Mdns,
        }
    }

    /// Attaches a fresh channel to `discovery` as if a scan were running and
    /// returns its sending half.
    fn attach_channel(discovery: &mut MdnsDiscovery) -> mpsc::Sender<SshHost> {
        let (tx, rx) = mpsc::channel();
        discovery.receiver = Some(rx);
        discovery.scanning = true;
        tx
    }

    #[test]
    fn test_mdns_discovery_new() {
        let discovery = MdnsDiscovery::new();
        assert!(!discovery.is_scanning());
        assert!(discovery.hosts().is_empty());
    }

    #[test]
    fn test_mdns_discovery_default() {
        let discovery = MdnsDiscovery::default();
        assert!(!discovery.is_scanning());
        assert!(discovery.hosts().is_empty());
    }

    #[test]
    fn test_mdns_discovery_clear() {
        let mut discovery = MdnsDiscovery::new();
        discovery.discovered.push(mdns_host("test", "test.local"));
        assert_eq!(discovery.hosts().len(), 1);

        discovery.clear();
        assert!(discovery.hosts().is_empty());
        assert!(!discovery.is_scanning());
    }

    #[test]
    fn test_clear_detaches_running_scan() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);

        discovery.clear();
        assert!(!discovery.is_scanning());
        // The receiver was dropped, so the producer can no longer deliver hosts.
        assert!(tx.send(mdns_host("late", "late.local")).is_err());
        assert!(!discovery.poll());
        assert!(discovery.hosts().is_empty());
    }

    #[test]
    fn test_poll_without_scan() {
        let mut discovery = MdnsDiscovery::new();
        assert!(!discovery.poll());
    }

    #[test]
    fn test_poll_with_completed_channel() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);
        tx.send(mdns_host("myhost", "myhost.local")).unwrap();
        drop(tx);

        assert!(discovery.poll());
        assert_eq!(discovery.hosts().len(), 1);
        assert_eq!(discovery.hosts()[0].alias, "myhost");
        assert_eq!(
            discovery.hosts()[0].hostname.as_deref(),
            Some("myhost.local")
        );
        // The sender is gone, so the scan must be reported as finished.
        assert!(!discovery.is_scanning());
    }

    #[test]
    fn test_poll_open_channel_keeps_scanning() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);

        assert!(!discovery.poll());
        assert!(discovery.is_scanning());

        tx.send(mdns_host("a", "a.local")).unwrap();
        assert!(discovery.poll());
        assert!(!discovery.poll());
        assert!(discovery.is_scanning());

        drop(tx);
        assert!(!discovery.poll());
        assert!(!discovery.is_scanning());
        assert_eq!(discovery.hosts().len(), 1);
    }

    #[test]
    fn test_poll_deduplicates() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);
        for _ in 0..2 {
            tx.send(mdns_host("dup", "dup.local")).unwrap();
        }
        drop(tx);

        discovery.poll();
        assert_eq!(discovery.hosts().len(), 1);
    }

    #[test]
    fn test_poll_keeps_same_host_on_different_ports() {
        let mut discovery = MdnsDiscovery::new();
        let tx = attach_channel(&mut discovery);

        let mut alt = mdns_host("alt", "box.local");
        alt.port = Some(2222);
        tx.send(mdns_host("ssh", "box.local")).unwrap();
        tx.send(alt).unwrap();
        drop(tx);

        assert!(discovery.poll());
        assert_eq!(discovery.hosts().len(), 2);
    }

    #[test]
    fn test_scan_marks_scanning() {
        let mut discovery = MdnsDiscovery::new();
        assert!(!discovery.is_scanning());

        discovery.start_scan(1);
        assert!(discovery.is_scanning());

        // The worker stops at its 1 s deadline (or sooner if the daemon cannot
        // start). Poll until it reports completion, with generous slack.
        let give_up = Instant::now() + Duration::from_secs(10);
        while discovery.is_scanning() && Instant::now() < give_up {
            discovery.poll();
            std::thread::sleep(Duration::from_millis(50));
        }
        assert!(!discovery.is_scanning());
    }
}