Skip to main content

rch_common/
discovery.rs

1//! Worker discovery from SSH config and shell aliases.
2//!
3//! This module provides functionality to automatically discover potential
4//! worker machines from the user's existing SSH configuration and shell aliases.
5
6use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9
10/// A host discovered from SSH config or shell aliases.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DiscoveredHost {
13    /// The alias or short name (e.g., "css", "worker1")
14    pub alias: String,
15    /// The actual hostname or IP address
16    pub hostname: String,
17    /// SSH username
18    pub user: String,
19    /// Path to SSH identity file (private key)
20    pub identity_file: Option<String>,
21    /// SSH port (default 22)
22    pub port: u16,
23    /// Where this host was discovered from
24    pub source: DiscoverySource,
25}
26
27/// Source of a discovered host.
28#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
29#[serde(rename_all = "snake_case")]
30pub enum DiscoverySource {
31    /// From ~/.ssh/config
32    SshConfig,
33    /// From ~/.bashrc
34    Bashrc,
35    /// From ~/.zshrc
36    Zshrc,
37    /// From ~/.bash_aliases
38    BashAliases,
39    /// From ~/.zsh_aliases
40    ZshAliases,
41}
42
43impl std::fmt::Display for DiscoverySource {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            Self::SshConfig => write!(f, "~/.ssh/config"),
47            Self::Bashrc => write!(f, "~/.bashrc"),
48            Self::Zshrc => write!(f, "~/.zshrc"),
49            Self::BashAliases => write!(f, "~/.bash_aliases"),
50            Self::ZshAliases => write!(f, "~/.zsh_aliases"),
51        }
52    }
53}
54
55/// Parse ~/.ssh/config and extract potential worker hosts.
56///
57/// SSH config format:
58/// ```text
59/// Host fmd
60///     HostName 51.222.245.56
61///     User ubuntu
62///     IdentityFile ~/.ssh/my_key.pem
63///
64/// Host yto
65///     HostName 37.187.75.150
66///     User ubuntu
67///     IdentityFile ~/.ssh/my_key.pem
68/// ```
69///
70/// # Returns
71/// A list of discovered hosts. Returns empty vec if config doesn't exist.
72pub fn parse_ssh_config() -> Result<Vec<DiscoveredHost>> {
73    let home = dirs::home_dir().context("Could not determine home directory")?;
74    let ssh_config_path = home.join(".ssh").join("config");
75
76    if !ssh_config_path.exists() {
77        return Ok(vec![]);
78    }
79
80    parse_ssh_config_file(&ssh_config_path)
81}
82
83/// Parse an SSH config file at the given path.
84pub fn parse_ssh_config_file(path: &PathBuf) -> Result<Vec<DiscoveredHost>> {
85    let content = std::fs::read_to_string(path)
86        .with_context(|| format!("Failed to read SSH config: {}", path.display()))?;
87
88    parse_ssh_config_content(&content)
89}
90
91/// Parse SSH config content and extract hosts.
92pub fn parse_ssh_config_content(content: &str) -> Result<Vec<DiscoveredHost>> {
93    let mut hosts = Vec::new();
94    let mut current_host: Option<SshConfigHost> = None;
95
96    for line in content.lines() {
97        let line = line.trim();
98
99        // Skip comments and empty lines
100        if line.is_empty() || line.starts_with('#') {
101            continue;
102        }
103
104        // Parse the line into key-value
105        let (key, value) = match parse_ssh_config_line(line) {
106            Some(kv) => kv,
107            None => continue,
108        };
109
110        match key.to_lowercase().as_str() {
111            "host" => {
112                // Save previous host if valid
113                if let Some(host) = current_host.take()
114                    && let Some(discovered) = host.into_discovered()
115                {
116                    hosts.push(discovered);
117                }
118
119                // Start new host block
120                // Handle multiple aliases on one line: "Host foo bar baz"
121                let aliases: Vec<&str> = value.split_whitespace().collect();
122                if let Some(first_alias) = aliases.first() {
123                    // Skip wildcards and special patterns
124                    if !first_alias.contains('*') && !first_alias.contains('?') {
125                        current_host = Some(SshConfigHost::new(first_alias.to_string()));
126                    }
127                }
128            }
129            "hostname" => {
130                if let Some(ref mut host) = current_host {
131                    host.hostname = Some(value.to_string());
132                }
133            }
134            "user" => {
135                if let Some(ref mut host) = current_host {
136                    host.user = Some(value.to_string());
137                }
138            }
139            "identityfile" => {
140                if let Some(ref mut host) = current_host {
141                    host.identity_file = Some(expand_tilde(value));
142                }
143            }
144            "port" => {
145                if let Some(ref mut host) = current_host {
146                    host.port = value.parse().ok();
147                }
148            }
149            _ => {
150                // Ignore other SSH config options
151            }
152        }
153    }
154
155    // Don't forget the last host
156    if let Some(host) = current_host
157        && let Some(discovered) = host.into_discovered()
158    {
159        hosts.push(discovered);
160    }
161
162    // Filter out hosts that are clearly not workers
163    let hosts = hosts
164        .into_iter()
165        .filter(|h| is_potential_worker(&h.alias, &h.hostname))
166        .collect();
167
168    Ok(hosts)
169}
170
/// Accumulates the options of a single `Host` block while SSH config is parsed.
///
/// Apart from `alias`, each field is `None` until the corresponding option is
/// encountered.
struct SshConfigHost {
    /// First pattern from the `Host` line.
    alias: String,
    /// Argument of `HostName`, when present.
    hostname: Option<String>,
    /// Argument of `User`, when present.
    user: Option<String>,
    /// Argument of `IdentityFile` after tilde expansion, when present.
    identity_file: Option<String>,
    /// Argument of `Port`, when present and a valid `u16`.
    port: Option<u16>,
}
179
180impl SshConfigHost {
181    fn new(alias: String) -> Self {
182        Self {
183            alias,
184            hostname: None,
185            user: None,
186            identity_file: None,
187            port: None,
188        }
189    }
190
191    fn into_discovered(self) -> Option<DiscoveredHost> {
192        // Must have at least a hostname to be useful
193        // If no hostname, use alias as hostname (common for simple configs)
194        let hostname = self.hostname.unwrap_or_else(|| self.alias.clone());
195
196        // Get current username as default
197        let default_user = std::env::var("USER")
198            .or_else(|_| std::env::var("USERNAME"))
199            .unwrap_or_else(|_| "ubuntu".to_string());
200
201        Some(DiscoveredHost {
202            alias: self.alias,
203            hostname,
204            user: self.user.unwrap_or(default_user),
205            identity_file: self.identity_file,
206            port: self.port.unwrap_or(22),
207            source: DiscoverySource::SshConfig,
208        })
209    }
210}
211
/// Split one SSH config line into its keyword and argument.
///
/// OpenSSH separates the keyword from its argument with whitespace and/or a
/// single optional `=`: `Host foo`, `HostName=192.168.1.1`, `User = ubuntu`.
/// Only that first separator matters. An `=` later in the line is part of the
/// value (e.g. `IdentityFile ~/.ssh/key=1.pem`).
///
/// An argument wrapped in one pair of double quotes has the quotes removed
/// (`IdentityFile "~/.ssh/my key"` -> `~/.ssh/my key`).
///
/// Returns `None` for blank lines, comments, and lines with a keyword but no
/// separator.
fn parse_ssh_config_line(line: &str) -> Option<(&str, &str)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }

    // The keyword runs up to the first whitespace or `=`.
    let key_end = line.find(|c: char| c.is_whitespace() || c == '=')?;
    let (key, rest) = line.split_at(key_end);
    if key.is_empty() {
        return None;
    }

    // Separator: optional whitespace, at most one `=`, optional whitespace.
    let rest = rest.trim_start();
    let value = rest.strip_prefix('=').unwrap_or(rest).trim();

    // Strip one pair of surrounding quotes, but leave multi-quoted values alone.
    let value = value
        .strip_prefix('"')
        .and_then(|v| v.strip_suffix('"'))
        .filter(|v| !v.contains('"'))
        .unwrap_or(value);

    Some((key, value))
}
237
238/// Expand ~ to home directory in paths.
239fn expand_tilde(path: &str) -> String {
240    if let Some(rest) = path.strip_prefix("~/")
241        && let Some(home) = dirs::home_dir()
242    {
243        return home.join(rest).display().to_string();
244    }
245    path.to_string()
246}
247
248/// Parse shell RC files for SSH aliases.
249///
250/// Looks for patterns like:
251/// - `alias css='ssh -i ~/.ssh/key.pem ubuntu@192.168.1.100'`
252/// - `alias csd="ssh user@host"`
253/// - `alias foo='ssh host'`
254pub fn parse_shell_aliases() -> Result<Vec<DiscoveredHost>> {
255    let home = dirs::home_dir().context("Could not determine home directory")?;
256    let mut all_hosts = Vec::new();
257
258    // List of shell RC files to check
259    let rc_files = [
260        (home.join(".bashrc"), DiscoverySource::Bashrc),
261        (home.join(".zshrc"), DiscoverySource::Zshrc),
262        (home.join(".bash_aliases"), DiscoverySource::BashAliases),
263        (home.join(".zsh_aliases"), DiscoverySource::ZshAliases),
264    ];
265
266    for (path, source) in &rc_files {
267        if path.exists() {
268            match parse_shell_aliases_file(path, source.clone()) {
269                Ok(hosts) => all_hosts.extend(hosts),
270                Err(_) => continue, // Ignore parse errors in individual files
271            }
272        }
273    }
274
275    Ok(all_hosts)
276}
277
278/// Parse a shell RC file for SSH aliases.
279pub fn parse_shell_aliases_file(
280    path: &PathBuf,
281    source: DiscoverySource,
282) -> Result<Vec<DiscoveredHost>> {
283    let content = std::fs::read_to_string(path)
284        .with_context(|| format!("Failed to read shell RC file: {}", path.display()))?;
285    parse_shell_aliases_content(&content, source)
286}
287
288/// Parse shell alias content for SSH commands.
289pub fn parse_shell_aliases_content(
290    content: &str,
291    source: DiscoverySource,
292) -> Result<Vec<DiscoveredHost>> {
293    use regex::Regex;
294
295    let mut hosts = Vec::new();
296
297    // Match alias definitions with ssh commands
298    // Handles: alias NAME='ssh ...' or alias NAME="ssh ..."
299    let alias_re = Regex::new(r#"(?m)^\s*alias\s+(\w+)\s*=\s*['"]ssh\s+(.*)['"]"#)
300        .context("Failed to compile alias regex")?;
301
302    // Extract -i identity file
303    let identity_re = Regex::new(r"-i\s+(\S+)").context("Failed to compile identity regex")?;
304
305    // Extract -p port
306    let port_re = Regex::new(r"-p\s+(\d+)").context("Failed to compile port regex")?;
307
308    for caps in alias_re.captures_iter(content) {
309        let alias_name = match caps.get(1) {
310            Some(m) => m.as_str().to_string(),
311            None => continue,
312        };
313        let ssh_args = match caps.get(2) {
314            Some(m) => m.as_str(),
315            None => continue,
316        };
317
318        // Extract identity file if present
319        let identity_file = identity_re
320            .captures(ssh_args)
321            .and_then(|c| c.get(1))
322            .map(|m| expand_tilde(m.as_str()));
323
324        // Extract port if present
325        let port = port_re
326            .captures(ssh_args)
327            .and_then(|c| c.get(1))
328            .and_then(|m| m.as_str().parse::<u16>().ok())
329            .unwrap_or(22);
330
331        // Extract user@host from the end of the command
332        // Strip all options first (anything starting with -)
333        let args_without_options: Vec<&str> = ssh_args
334            .split_whitespace()
335            .filter(|s| !s.starts_with('-'))
336            .filter(|s| {
337                // Also filter out values that follow -i or -p
338                if let Some(prev_idx) = ssh_args.find(s)
339                    && prev_idx > 0
340                {
341                    let before = &ssh_args[..prev_idx].trim_end();
342                    if before.ends_with("-i") || before.ends_with("-p") {
343                        return false;
344                    }
345                }
346                true
347            })
348            .collect();
349
350        // The host specification is typically the last non-option argument
351        let host_spec = match args_without_options.last() {
352            Some(s) => *s,
353            None => continue,
354        };
355
356        // Parse user@host or just host
357        let (user, hostname) = if let Some((u, h)) = host_spec.split_once('@') {
358            (u.to_string(), h.to_string())
359        } else {
360            // No user specified, use current user
361            let default_user = std::env::var("USER")
362                .or_else(|_| std::env::var("USERNAME"))
363                .unwrap_or_else(|_| "ubuntu".to_string());
364            (default_user, host_spec.to_string())
365        };
366
367        // Skip if we couldn't extract a valid host
368        if hostname.is_empty() {
369            continue;
370        }
371
372        // Filter out non-workers
373        if !is_potential_worker(&alias_name, &hostname) {
374            continue;
375        }
376
377        hosts.push(DiscoveredHost {
378            alias: alias_name,
379            hostname,
380            user,
381            identity_file,
382            port,
383            source: source.clone(),
384        });
385    }
386
387    Ok(hosts)
388}
389
390/// Discover all potential workers from all sources.
391pub fn discover_all() -> Result<Vec<DiscoveredHost>> {
392    let mut all_hosts = Vec::new();
393
394    // Parse SSH config
395    if let Ok(hosts) = parse_ssh_config() {
396        all_hosts.extend(hosts);
397    }
398
399    // Parse shell aliases
400    if let Ok(hosts) = parse_shell_aliases() {
401        all_hosts.extend(hosts);
402    }
403
404    // Deduplicate by hostname (keep first occurrence, which is typically SSH config)
405    let mut seen_hostnames = std::collections::HashSet::new();
406    all_hosts.retain(|h| seen_hostnames.insert(h.hostname.clone()));
407
408    Ok(all_hosts)
409}
410
/// Decide whether a host could be a build worker rather than a well-known
/// non-worker (a code-hosting service or the local machine).
///
/// A host is rejected when any of these holds:
/// - `hostname` is `github.com`, `gitlab.com` or `bitbucket.org`, or a
///   subdomain of one of them (e.g. `ssh.github.com`);
/// - `hostname` is a loopback IP address (`127.0.0.0/8` or `::1`, also in
///   IPv4-mapped form or bracketed as `[::1]`);
/// - any dot-separated label of `hostname` is `localhost` (`localhost`,
///   `localhost.localdomain`, `foo.localhost`);
/// - `alias` is `github`, `gitlab`, `bitbucket` or `local`.
///
/// Comparisons are ASCII case-insensitive and match whole names or labels,
/// never substrings. Ordinary IPv6 addresses that merely contain `::1`
/// (e.g. `fe80::1ff:fe23:4567:890a`) are therefore accepted.
fn is_potential_worker(alias: &str, hostname: &str) -> bool {
    const SKIP_DOMAINS: [&str; 3] = ["github.com", "gitlab.com", "bitbucket.org"];
    const SKIP_ALIASES: [&str; 4] = ["github", "gitlab", "bitbucket", "local"];

    if SKIP_ALIASES.iter().any(|skip| alias.eq_ignore_ascii_case(skip)) {
        return false;
    }

    let host = hostname.trim().to_ascii_lowercase();
    // Accept bracketed IPv6 literals and a trailing FQDN dot.
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host.as_str())
        .trim_end_matches('.');

    if host
        .parse::<std::net::IpAddr>()
        .is_ok_and(|ip| ip.to_canonical().is_loopback())
    {
        return false;
    }

    if host.split('.').any(|label| label == "localhost") {
        return false;
    }

    // Reject the domain itself or any subdomain of it, but not e.g. "notgithub.com".
    !SKIP_DOMAINS.iter().any(|domain| {
        host.strip_suffix(*domain)
            .is_some_and(|prefix| prefix.is_empty() || prefix.ends_with('.'))
    })
}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse SSH config text, failing the test on error.
    fn ssh_hosts(content: &str) -> Vec<DiscoveredHost> {
        parse_ssh_config_content(content).unwrap()
    }

    /// Parse shell alias text attributed to `source`, failing the test on error.
    fn alias_hosts(content: &str, source: DiscoverySource) -> Vec<DiscoveredHost> {
        parse_shell_aliases_content(content, source).unwrap()
    }

    /// Aliases of `hosts`, in discovery order.
    fn aliases(hosts: &[DiscoveredHost]) -> Vec<&str> {
        hosts.iter().map(|h| h.alias.as_str()).collect()
    }

    /// The single host in `hosts`; fails the test unless there is exactly one.
    fn only(hosts: &[DiscoveredHost]) -> &DiscoveredHost {
        assert_eq!(hosts.len(), 1, "expected exactly one host");
        &hosts[0]
    }

    /// The host with the given alias; fails the test if it is missing.
    fn by_alias<'a>(hosts: &'a [DiscoveredHost], alias: &str) -> &'a DiscoveredHost {
        hosts
            .iter()
            .find(|h| h.alias == alias)
            .unwrap_or_else(|| panic!("no host with alias {alias:?}"))
    }

    // --- SSH config parsing ---

    #[test]
    fn test_parse_basic_ssh_config() {
        let hosts = ssh_hosts(
            r#"
Host fmd
    HostName 51.222.245.56
    User ubuntu
    IdentityFile ~/.ssh/my_key.pem

Host yto
    HostName 37.187.75.150
    User root
    IdentityFile ~/.ssh/other_key.pem
    Port 2222
"#,
        );
        let [fmd, yto] = hosts.as_slice() else {
            panic!("expected two hosts, got {}", hosts.len());
        };

        assert_eq!(fmd.alias, "fmd");
        assert_eq!(fmd.hostname, "51.222.245.56");
        assert_eq!(fmd.user, "ubuntu");
        assert!(
            fmd.identity_file
                .as_deref()
                .is_some_and(|path| path.contains("my_key.pem"))
        );
        assert_eq!(fmd.port, 22);
        assert_eq!(fmd.source, DiscoverySource::SshConfig);

        assert_eq!(yto.alias, "yto");
        assert_eq!(yto.hostname, "37.187.75.150");
        assert_eq!(yto.user, "root");
        assert_eq!(yto.port, 2222);
    }

    #[test]
    fn test_skip_wildcard_hosts() {
        let hosts = ssh_hosts(
            r#"
Host *
    ServerAliveInterval 60

Host worker1
    HostName 192.168.1.10
    User ubuntu
"#,
        );
        assert_eq!(aliases(&hosts), ["worker1"]);
    }

    #[test]
    fn test_skip_github() {
        let hosts = ssh_hosts(
            r#"
Host github.com
    HostName github.com
    User git
    IdentityFile ~/.ssh/github_key

Host worker1
    HostName 192.168.1.10
    User ubuntu
"#,
        );
        assert_eq!(aliases(&hosts), ["worker1"]);
    }

    #[test]
    fn test_handle_multiple_aliases() {
        let hosts = ssh_hosts(
            r#"
Host foo bar baz
    HostName 192.168.1.10
    User ubuntu
"#,
        );
        // Only the first alias on the Host line is used.
        assert_eq!(aliases(&hosts), ["foo"]);
    }

    #[test]
    fn test_handle_equals_separator() {
        let hosts = ssh_hosts(
            r#"
Host worker
    HostName=192.168.1.10
    User=ubuntu
"#,
        );
        let host = only(&hosts);
        assert_eq!(host.hostname, "192.168.1.10");
        assert_eq!(host.user, "ubuntu");
    }

    #[test]
    fn test_handle_comments() {
        let hosts = ssh_hosts(
            r#"
# This is a comment
Host worker1
    # Another comment
    HostName 192.168.1.10
    User ubuntu
"#,
        );
        assert_eq!(aliases(&hosts), ["worker1"]);
    }

    #[test]
    fn test_empty_config() {
        assert!(ssh_hosts("").is_empty());
    }

    #[test]
    fn test_host_without_hostname_uses_alias() {
        let hosts = ssh_hosts(
            r#"
Host myserver
    User ubuntu
    IdentityFile ~/.ssh/key.pem
"#,
        );
        // Without HostName the alias doubles as the hostname.
        assert_eq!(only(&hosts).hostname, "myserver");
    }

    #[test]
    fn test_expand_tilde() {
        let expanded = expand_tilde("~/.ssh/key.pem");
        assert!(!expanded.starts_with('~'));
        assert!(expanded.contains(".ssh/key.pem"));
    }

    #[test]
    fn test_expand_tilde_no_tilde() {
        let absolute = "/absolute/path/key.pem";
        assert_eq!(expand_tilde(absolute), absolute);
    }

    #[test]
    fn test_is_potential_worker() {
        for (alias, host) in [("worker1", "192.168.1.10"), ("css", "209.145.54.164")] {
            assert!(is_potential_worker(alias, host), "{alias} / {host} should be a worker");
        }
        for (alias, host) in [
            ("github", "github.com"),
            ("local", "localhost"),
            ("home", "127.0.0.1"),
        ] {
            assert!(!is_potential_worker(alias, host), "{alias} / {host} should be skipped");
        }
    }

    // --- Shell alias parsing ---

    #[test]
    fn test_parse_shell_aliases_basic() {
        let hosts = alias_hosts(
            r#"
# Some other config
export PATH="/usr/local/bin:$PATH"

alias ll='ls -la'
alias css='ssh -i ~/.ssh/key.pem ubuntu@192.168.1.100'
alias csd='ssh root@10.0.0.5'
"#,
            DiscoverySource::Bashrc,
        );
        assert_eq!(hosts.len(), 2);

        let css = by_alias(&hosts, "css");
        assert_eq!(css.hostname, "192.168.1.100");
        assert_eq!(css.user, "ubuntu");
        assert!(css.identity_file.is_some());
        assert_eq!(css.source, DiscoverySource::Bashrc);

        let csd = by_alias(&hosts, "csd");
        assert_eq!(csd.hostname, "10.0.0.5");
        assert_eq!(csd.user, "root");
    }

    #[test]
    fn test_parse_shell_aliases_double_quotes() {
        let hosts = alias_hosts(
            r#"
alias server="ssh -i ~/.ssh/id_rsa admin@example.com"
"#,
            DiscoverySource::Zshrc,
        );
        let server = only(&hosts);
        assert_eq!(server.alias, "server");
        assert_eq!(server.hostname, "example.com");
        assert_eq!(server.user, "admin");
    }

    #[test]
    fn test_parse_shell_aliases_with_port() {
        let hosts = alias_hosts(
            r#"
alias custom='ssh -p 2222 user@192.168.1.50'
"#,
            DiscoverySource::Bashrc,
        );
        assert_eq!(only(&hosts).port, 2222);
    }

    #[test]
    fn test_parse_shell_aliases_simple_host() {
        let hosts = alias_hosts(
            r#"
alias myserver='ssh myserver.example.com'
"#,
            DiscoverySource::Bashrc,
        );
        let host = only(&hosts);
        assert_eq!(host.hostname, "myserver.example.com");
        // With no user@ prefix the user comes from the environment (or "ubuntu").
        assert!(!host.user.is_empty());
    }

    #[test]
    fn test_parse_shell_aliases_skips_localhost() {
        let hosts = alias_hosts(
            r#"
alias local='ssh localhost'
alias loopback='ssh 127.0.0.1'
alias remote='ssh 192.168.1.1'
"#,
            DiscoverySource::Bashrc,
        );
        assert_eq!(aliases(&hosts), ["remote"]);
    }

    #[test]
    fn test_parse_shell_aliases_skips_non_ssh() {
        let hosts = alias_hosts(
            r#"
alias ll='ls -la'
alias grep='grep --color=auto'
alias ssh_host='ssh worker@192.168.1.10'
"#,
            DiscoverySource::Bashrc,
        );
        assert_eq!(aliases(&hosts), ["ssh_host"]);
    }

    #[test]
    fn test_parse_shell_aliases_empty() {
        assert!(alias_hosts("", DiscoverySource::Bashrc).is_empty());
    }
}