Skip to main content

prt_core/core/
ssh_config.rs

1//! Lightweight `~/.ssh/config` parser + merge with prt's own host config.
2//!
//! Only the directives needed to identify a destination are recognised:
//! `Host`, `HostName`, `User`, `Port`, `IdentityFile`, plus `Include` (with
//! basename globs) for pulling in other config files. Everything else is
//! silently ignored. Wildcard and negated host patterns (`Host *`,
//! `Host foo?bar`, `Host !bastion`) are skipped — they match patterns rather
//! than name a concrete target.
7
8use std::fs;
9use std::path::{Path, PathBuf};
10
11use crate::config::SshHostConfig;
12
/// Origin of a host entry: parsed from an OpenSSH config file, or declared
/// in prt's own configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SshHostSource {
    /// Parsed from `~/.ssh/config` (or a file it includes).
    SshConfig,
    /// Converted from a prt-config host entry.
    PrtConfig,
}

impl SshHostSource {
    /// Short, fixed tag naming this source: `"ssh_config"` or `"prt"`.
    pub fn label(self) -> &'static str {
        match self {
            SshHostSource::PrtConfig => "prt",
            SshHostSource::SshConfig => "ssh_config",
        }
    }
}
28
29/// One concrete SSH destination (no wildcards).
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct SshHost {
32    pub alias: String,
33    pub hostname: Option<String>,
34    pub user: Option<String>,
35    pub port: Option<u16>,
36    pub identity_file: Option<PathBuf>,
37    pub source: SshHostSource,
38}
39
40impl SshHost {
41    /// Display string: `user@hostname:port` (with sensible fallbacks).
42    pub fn target(&self) -> String {
43        let host = self.hostname.as_deref().unwrap_or(&self.alias);
44        let mut s = String::new();
45        if let Some(u) = &self.user {
46            s.push_str(u);
47            s.push('@');
48        }
49        s.push_str(host);
50        if let Some(p) = self.port {
51            s.push(':');
52            s.push_str(&p.to_string());
53        }
54        s
55    }
56}
57
58/// Default `~/.ssh/config` path.
59pub fn default_ssh_config_path() -> Option<PathBuf> {
60    dirs::home_dir().map(|h| h.join(".ssh").join("config"))
61}
62
63/// Parse `~/.ssh/config` (or any file with that grammar). Resolves
64/// `Include` directives relative to the SSH config root (the parent
65/// directory of the top-level file, e.g. `~/.ssh/`), matching OpenSSH
66/// semantics. Recursion is capped at [`MAX_INCLUDE_DEPTH`] so circular
67/// includes don't loop. On failure, returns an empty list — this is
68/// best-effort enrichment.
69pub fn parse_ssh_config(path: &Path) -> Vec<SshHost> {
70    let content = match fs::read_to_string(path) {
71        Ok(c) => c,
72        Err(_) => return Vec::new(),
73    };
74    let mut result: Vec<SshHost> = Vec::new();
75    let mut current: Vec<usize> = Vec::new();
76    parse_ssh_config_inner(&content, path.parent(), 0, &mut result, &mut current);
77    result
78}
79
/// Test-only entry point: parse in-memory config text.
///
/// No config root is supplied, so relative `Include` paths are resolved
/// against the process working directory.
#[cfg(test)]
fn parse_ssh_config_str(content: &str) -> Vec<SshHost> {
    let (mut hosts, mut active) = (Vec::new(), Vec::new());
    parse_ssh_config_inner(content, None, 0, &mut hosts, &mut active);
    hosts
}
87
/// Maximum `Include` nesting depth followed by the parser. `Include` lines
/// encountered at this depth or deeper are ignored, which also bounds
/// circular includes.
const MAX_INCLUDE_DEPTH: u32 = 16;
89
90/// Recursive parser body.
91///
92/// `root_dir` is the SSH config root (the directory of the *top-level*
93/// file) and stays the same across nested `Include` calls. Relative
94/// include paths are resolved against it.
95///
96/// `result` and `current` are shared across recursion. When an `Include`
97/// is encountered the outer `current` is snapshotted and a fresh copy is
98/// passed into the nested parse — that way directives without their own
99/// `Host` block in the included file still apply to the outer block,
100/// while the outer `current` is restored unchanged after the include
101/// returns (matching OpenSSH).
102fn parse_ssh_config_inner(
103    content: &str,
104    root_dir: Option<&Path>,
105    depth: u32,
106    result: &mut Vec<SshHost>,
107    current: &mut Vec<usize>,
108) {
109    for raw_line in content.lines() {
110        let trimmed = strip_inline_comment(raw_line.trim());
111        if trimmed.is_empty() {
112            continue;
113        }
114
115        let (key, value) = match split_kv(trimmed) {
116            Some(kv) => kv,
117            None => continue,
118        };
119        let key_lc = key.to_ascii_lowercase();
120
121        if key_lc == "include" {
122            if depth >= MAX_INCLUDE_DEPTH {
123                continue;
124            }
125            for token in value.split_whitespace() {
126                let raw = strip_quotes(token);
127                for include_path in resolve_include(raw, root_dir) {
128                    if let Ok(included) = fs::read_to_string(&include_path) {
129                        // Snapshot outer scope; nested file inherits it
130                        // but cannot mutate the caller's selection.
131                        let mut nested = current.clone();
132                        parse_ssh_config_inner(&included, root_dir, depth + 1, result, &mut nested);
133                    }
134                }
135            }
136            // `current` deliberately left intact: directives following
137            // an Include inside the same Host block still apply.
138            continue;
139        }
140
141        if key_lc == "host" {
142            current.clear();
143            for token in value.split_whitespace() {
144                let alias = strip_quotes(token);
145                if alias.is_empty()
146                    || alias.starts_with('!')
147                    || alias.contains('*')
148                    || alias.contains('?')
149                {
150                    continue;
151                }
152                result.push(SshHost {
153                    alias: alias.to_string(),
154                    hostname: None,
155                    user: None,
156                    port: None,
157                    identity_file: None,
158                    source: SshHostSource::SshConfig,
159                });
160                current.push(result.len() - 1);
161            }
162            continue;
163        }
164
165        if current.is_empty() {
166            continue;
167        }
168        let value = strip_quotes(value).to_string();
169        for &idx in current.iter() {
170            let host = &mut result[idx];
171            match key_lc.as_str() {
172                "hostname" => host.hostname = Some(value.clone()),
173                "user" => host.user = Some(value.clone()),
174                "port" => {
175                    if let Ok(p) = value.parse() {
176                        host.port = Some(p);
177                    }
178                }
179                "identityfile" => host.identity_file = Some(expand_tilde(&value)),
180                _ => {}
181            }
182        }
183    }
184}
185
186/// Resolve one `Include` token into a list of concrete file paths.
187///
188/// - `~/...` is expanded via `dirs::home_dir`.
189/// - Relative paths are resolved against `root_dir` (the SSH config root,
190///   typically `~/.ssh/` or `/etc/ssh/`), matching OpenSSH semantics —
191///   including for nested includes.
192/// - If the final path contains a `*` or `?` glob in its basename,
193///   the parent directory is listed and entries matching the basename
194///   pattern are returned. Globs in directory components are not supported
195///   (rare in real configs).
196fn resolve_include(raw: &str, root_dir: Option<&Path>) -> Vec<PathBuf> {
197    if raw.is_empty() {
198        return Vec::new();
199    }
200    let expanded = if let Some(rest) = raw.strip_prefix("~/") {
201        match dirs::home_dir() {
202            Some(h) => h.join(rest),
203            None => return Vec::new(),
204        }
205    } else {
206        let p = PathBuf::from(raw);
207        if p.is_absolute() {
208            p
209        } else {
210            match root_dir {
211                Some(b) => b.join(p),
212                None => p,
213            }
214        }
215    };
216
217    let basename = match expanded.file_name().and_then(|s| s.to_str()) {
218        Some(s) => s.to_string(),
219        None => return Vec::new(),
220    };
221
222    if !basename.contains('*') && !basename.contains('?') {
223        return vec![expanded];
224    }
225
226    let parent = match expanded.parent() {
227        Some(p) => p,
228        None => return Vec::new(),
229    };
230    let read = match fs::read_dir(parent) {
231        Ok(r) => r,
232        Err(_) => return Vec::new(),
233    };
234    let mut out = Vec::new();
235    for entry in read.flatten() {
236        let name = entry.file_name();
237        let name_str = match name.to_str() {
238            Some(s) => s,
239            None => continue,
240        };
241        if match_glob(&basename, name_str) {
242            out.push(entry.path());
243        }
244    }
245    out.sort();
246    out
247}
248
/// Minimal fnmatch-style matcher: `*` matches any sequence (including empty),
/// `?` matches exactly one character. No bracket classes, no escaping.
///
/// Uses the classic greedy scan that, on mismatch, backtracks only to the
/// most recent `*`. That keeps the runtime at O(pattern × name) with no
/// recursion. A naive recursive matcher is exponential for patterns with
/// many `*`.
fn match_glob(pattern: &str, name: &str) -> bool {
    let p: Vec<char> = pattern.chars().collect();
    let n: Vec<char> = name.chars().collect();
    let (mut pi, mut ni) = (0usize, 0usize);
    // (index of the last `*` seen in `p`, index in `n` it currently extends to)
    let mut star: Option<(usize, usize)> = None;

    while ni < n.len() {
        match p.get(pi) {
            Some('*') => {
                // Tentatively let `*` match nothing; widen it on mismatch.
                star = Some((pi, ni));
                pi += 1;
            }
            Some('?') => {
                pi += 1;
                ni += 1;
            }
            Some(&c) if c == n[ni] => {
                pi += 1;
                ni += 1;
            }
            _ => match star {
                Some((star_pi, star_ni)) => {
                    // Let the last `*` absorb one more character and retry.
                    pi = star_pi + 1;
                    ni = star_ni + 1;
                    star = Some((star_pi, star_ni + 1));
                }
                None => return false,
            },
        }
    }
    // Name consumed: whatever is left of the pattern must be only `*`s.
    p[pi..].iter().all(|&c| c == '*')
}
282
/// Split a directive line into `(key, value)`.
///
/// Per ssh_config(5) the key and value are separated by whitespace and/or
/// `=`. Returns `None` when the line starts with a separator (no key) or
/// when nothing but separators follows the key (no value). The value has
/// surrounding whitespace trimmed.
fn split_kv(line: &str) -> Option<(&str, &str)> {
    let is_sep = |c: char| c.is_ascii_whitespace() || c == '=';
    let key_end = line.find(is_sep).unwrap_or(line.len());
    if key_end == 0 {
        return None;
    }
    let (key, rest) = line.split_at(key_end);
    let value = rest.trim_start_matches(is_sep);
    if value.is_empty() {
        return None;
    }
    Some((key, value.trim()))
}
303
/// Drop a trailing comment and the whitespace before it.
///
/// Mirrors OpenSSH's config tokenizer. A `#` starts a comment only at the
/// beginning of a token — at the start of the line or right after
/// whitespace — and never inside double quotes. So `Port 22 # ssh` becomes
/// `Port 22`, while a `#` embedded in a word (e.g. `HostName h#1`) is kept
/// literally.
fn strip_inline_comment(s: &str) -> &str {
    let mut in_quotes = false;
    // The start of the line counts as a token boundary.
    let mut at_token_start = true;
    for (i, b) in s.bytes().enumerate() {
        match b {
            b'"' => in_quotes = !in_quotes,
            // `#` is ASCII, so `i` is always a char boundary.
            b'#' if !in_quotes && at_token_start => return s[..i].trim_end(),
            _ => {}
        }
        at_token_start = b.is_ascii_whitespace();
    }
    s
}
321
/// Trim whitespace, then remove one pair of surrounding double quotes if
/// both are present. A lone `"` or an unbalanced quote is left as-is.
fn strip_quotes(s: &str) -> &str {
    let s = s.trim();
    s.strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(s)
}
330
331fn expand_tilde(s: &str) -> PathBuf {
332    if let Some(rest) = s.strip_prefix("~/") {
333        if let Some(home) = dirs::home_dir() {
334            return home.join(rest);
335        }
336    }
337    PathBuf::from(s)
338}
339
340/// Convert a prt-config host entry to an `SshHost`.
341pub fn from_prt_config(cfg: &SshHostConfig) -> Option<SshHost> {
342    if cfg.alias.trim().is_empty() {
343        return None;
344    }
345    Some(SshHost {
346        alias: cfg.alias.clone(),
347        hostname: cfg.hostname.clone(),
348        user: cfg.user.clone(),
349        port: cfg.port,
350        identity_file: cfg.identity_file.as_ref().map(|p| expand_tilde(p)),
351        source: SshHostSource::PrtConfig,
352    })
353}
354
355/// Load known hosts: parse `~/.ssh/config` and merge with prt-config hosts.
356/// Aliases from prt-config win on collision.
357pub fn load_known_hosts(extra: &[SshHostConfig]) -> Vec<SshHost> {
358    let mut hosts: Vec<SshHost> = match default_ssh_config_path() {
359        Some(p) => parse_ssh_config(&p),
360        None => Vec::new(),
361    };
362
363    for cfg in extra {
364        if let Some(host) = from_prt_config(cfg) {
365            if let Some(pos) = hosts.iter().position(|h| h.alias == host.alias) {
366                hosts[pos] = host;
367            } else {
368                hosts.push(host);
369            }
370        }
371    }
372
373    hosts.sort_by(|a, b| a.alias.cmp(&b.alias));
374    hosts
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_basic_host() {
        let cfg = "Host prod\n  HostName 10.0.0.5\n  User deploy\n  Port 2222\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        let h = &hosts[0];
        assert_eq!(h.alias, "prod");
        assert_eq!(h.hostname.as_deref(), Some("10.0.0.5"));
        assert_eq!(h.user.as_deref(), Some("deploy"));
        assert_eq!(h.port, Some(2222));
    }

    #[test]
    fn parse_skips_wildcards() {
        let cfg = "Host *\n  User everyone\nHost prod\n  HostName p\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].alias, "prod");
    }

    #[test]
    fn parse_skips_negated_aliases() {
        let cfg = "Host !bastion good\n  HostName ok\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].alias, "good");
    }

    #[test]
    fn parse_resolves_include_directive() {
        // Build a small fixture: parent file with an `Include` pointing at
        // a sibling fragment.
        let dir = tmpdir();
        let frag = dir.join("frag.conf");
        std::fs::write(&frag, "Host included-alias\n  HostName included.example\n").unwrap();
        let main = dir.join("config");
        std::fs::write(
            &main,
            format!("Host top\n  HostName t\nInclude {}\n", frag.display()),
        )
        .unwrap();

        let hosts = parse_ssh_config(&main);
        let aliases: Vec<_> = hosts.iter().map(|h| h.alias.as_str()).collect();
        assert!(aliases.contains(&"top"), "{aliases:?}");
        assert!(aliases.contains(&"included-alias"), "{aliases:?}");
    }

    #[test]
    fn parse_include_with_glob_pattern() {
        let dir = tmpdir();
        let sub = dir.join("conf.d");
        std::fs::create_dir(&sub).unwrap();
        std::fs::write(sub.join("a.conf"), "Host a\n  HostName ah\n").unwrap();
        std::fs::write(sub.join("b.conf"), "Host b\n  HostName bh\n").unwrap();
        // Declares a host so that wrongly including it would be observable.
        std::fs::write(sub.join("ignore.txt"), "Host ignore\n  HostName ih\n").unwrap();

        let main = dir.join("config");
        std::fs::write(&main, format!("Include {}/*.conf\n", sub.display())).unwrap();

        let hosts = parse_ssh_config(&main);
        let aliases: Vec<_> = hosts.iter().map(|h| h.alias.as_str()).collect();
        assert!(aliases.contains(&"a"), "{aliases:?}");
        assert!(aliases.contains(&"b"), "{aliases:?}");
        assert!(!aliases.contains(&"ignore"), "{aliases:?}");
        // The fixture directory is freshly created, so `a.conf` and `b.conf`
        // are the only entries matching `*.conf` (filesystem metadata such
        // as macOS `.DS_Store` does not match the glob).
        assert_eq!(aliases.len(), 2, "{aliases:?}");
    }

    #[test]
    fn parse_include_resolves_relative_to_ssh_root() {
        // Layout:
        //   <root>/config       — top-level (root_dir = <root>)
        //   <root>/sibling.conf — referenced as `Include sibling.conf`
        //                         from inside <root>/sub/nested.conf
        //   <root>/sub/nested.conf — referenced from top-level
        //
        // OpenSSH resolves relative includes against the SSH config root
        // (parent of top-level file), regardless of where the include
        // statement appears. Resolving against the parent of the *current*
        // file would look for <root>/sub/sibling.conf and miss the alias.
        let dir = tmpdir();
        let sub = dir.join("sub");
        std::fs::create_dir(&sub).unwrap();
        std::fs::write(dir.join("sibling.conf"), "Host sibling\n  HostName s\n").unwrap();
        std::fs::write(
            sub.join("nested.conf"),
            "Host inner\n  HostName i\nInclude sibling.conf\n",
        )
        .unwrap();
        let main = dir.join("config");
        std::fs::write(
            &main,
            format!("Include {}\n", sub.join("nested.conf").display()),
        )
        .unwrap();

        let hosts = parse_ssh_config(&main);
        let aliases: Vec<_> = hosts.iter().map(|h| h.alias.as_str()).collect();
        assert!(aliases.contains(&"inner"), "{aliases:?}");
        assert!(
            aliases.contains(&"sibling"),
            "relative Include resolved against wrong root: {aliases:?}"
        );
    }

    #[test]
    fn parse_include_inside_host_keeps_block_active() {
        // `Host prod` is followed by `Include` and then `Port 2222`.
        // The trailing Port must apply to prod despite the Include.
        let dir = tmpdir();
        let frag = dir.join("frag.conf");
        std::fs::write(&frag, "Host other\n  HostName o\n").unwrap();
        let main = dir.join("config");
        std::fs::write(
            &main,
            format!(
                "Host prod\n  HostName p\nInclude {}\n  Port 2222\n  User deploy\n",
                frag.display()
            ),
        )
        .unwrap();

        let hosts = parse_ssh_config(&main);
        let prod = hosts
            .iter()
            .find(|h| h.alias == "prod")
            .expect("prod missing");
        assert_eq!(prod.hostname.as_deref(), Some("p"));
        assert_eq!(prod.port, Some(2222), "Port lost after Include");
        assert_eq!(
            prod.user.as_deref(),
            Some("deploy"),
            "User lost after Include"
        );

        // Included `Host other` must not contaminate prod nor disappear.
        let other = hosts
            .iter()
            .find(|h| h.alias == "other")
            .expect("other missing");
        assert_eq!(other.hostname.as_deref(), Some("o"));
        // Outer Port must NOT have leaked into the included host.
        assert_eq!(other.port, None);
    }

    #[test]
    fn match_glob_basics() {
        assert!(match_glob("*.conf", "a.conf"));
        assert!(match_glob("*.conf", ".conf"));
        assert!(!match_glob("*.conf", "a.txt"));
        assert!(match_glob("?.conf", "a.conf"));
        assert!(!match_glob("?.conf", "ab.conf"));
        assert!(match_glob("a*b", "axyzb"));
        assert!(match_glob("a*", "abc"));
        assert!(match_glob("*", "anything"));
    }

    fn tmpdir() -> std::path::PathBuf {
        // Atomic counter avoids collisions when several tests in the same
        // process call `tmpdir()` within the same nanosecond. macOS test
        // runners parallelise aggressively and we observed a CI-only
        // failure that this replaces with a deterministic unique path.
        use std::sync::atomic::{AtomicU64, Ordering};
        static SEQ: AtomicU64 = AtomicU64::new(0);
        let n = SEQ.fetch_add(1, Ordering::Relaxed);
        let mut p = std::env::temp_dir();
        p.push(format!(
            "prt-ssh-cfg-{}-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
            n,
        ));
        std::fs::create_dir_all(&p).unwrap();
        p
    }

    #[test]
    fn parse_strips_inline_comments() {
        let cfg = "Host prod # primary db\n  HostName 10.0.0.5  # internal\n  Port 22 # ssh\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].alias, "prod");
        assert_eq!(hosts[0].hostname.as_deref(), Some("10.0.0.5"));
        assert_eq!(hosts[0].port, Some(22));
    }

    #[test]
    fn parse_keeps_hash_inside_quotes() {
        let cfg = "Host abc\n  HostName \"h#1.example\"\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].hostname.as_deref(), Some("h#1.example"));
    }

    #[test]
    fn parse_handles_comments_and_indent() {
        let cfg = "# comment\n\n   Host foo\n     # nested\n     HostName f.example\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].alias, "foo");
        assert_eq!(hosts[0].hostname.as_deref(), Some("f.example"));
    }

    #[test]
    fn parse_multiple_aliases_share_block() {
        let cfg = "Host a b c\n  HostName shared\n  User root\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 3);
        for h in &hosts {
            assert_eq!(h.hostname.as_deref(), Some("shared"));
            assert_eq!(h.user.as_deref(), Some("root"));
        }
    }

    #[test]
    fn parse_case_insensitive_keys_and_equals() {
        let cfg = "Host abc\n  HOSTNAME=h.example\n  user=joe\n  PORT = 22\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].hostname.as_deref(), Some("h.example"));
        assert_eq!(hosts[0].user.as_deref(), Some("joe"));
        assert_eq!(hosts[0].port, Some(22));
    }

    #[test]
    fn parse_quoted_values() {
        let cfg = "Host abc\n  HostName \"example.com\"\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts[0].hostname.as_deref(), Some("example.com"));
    }

    #[test]
    fn parse_unknown_keys_ignored() {
        let cfg = "Host foo\n  ProxyCommand whatever\n  HostName ok\n";
        let hosts = parse_ssh_config_str(cfg);
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].hostname.as_deref(), Some("ok"));
    }

    #[test]
    fn parse_empty_returns_empty() {
        assert!(parse_ssh_config_str("").is_empty());
    }

    #[test]
    fn parse_missing_file_returns_empty() {
        let path = PathBuf::from("/nonexistent/.ssh/config_xxx");
        assert!(parse_ssh_config(&path).is_empty());
    }

    #[test]
    fn merge_prt_config_overrides_ssh_config() {
        let prt = vec![SshHostConfig {
            alias: "prod".into(),
            hostname: Some("override".into()),
            user: None,
            port: None,
            identity_file: None,
        }];
        // Simulate by manually parsing then merging
        let mut hosts = parse_ssh_config_str("Host prod\n  HostName original\n");
        for cfg in &prt {
            if let Some(host) = from_prt_config(cfg) {
                if let Some(pos) = hosts.iter().position(|h| h.alias == host.alias) {
                    hosts[pos] = host;
                } else {
                    hosts.push(host);
                }
            }
        }
        assert_eq!(hosts.len(), 1);
        assert_eq!(hosts[0].hostname.as_deref(), Some("override"));
        assert_eq!(hosts[0].source, SshHostSource::PrtConfig);
    }

    #[test]
    fn target_formats_user_host_port() {
        let host = SshHost {
            alias: "prod".into(),
            hostname: Some("h".into()),
            user: Some("u".into()),
            port: Some(2222),
            identity_file: None,
            source: SshHostSource::SshConfig,
        };
        assert_eq!(host.target(), "u@h:2222");
    }

    #[test]
    fn target_falls_back_to_alias() {
        let host = SshHost {
            alias: "prod".into(),
            hostname: None,
            user: None,
            port: None,
            identity_file: None,
            source: SshHostSource::SshConfig,
        };
        assert_eq!(host.target(), "prod");
    }
}