// prt_core/config.rs

//! Configuration loading from `config.toml` in the `prt` config directory
//! (`~/.config/prt/config.toml` on Linux).
//!
//! The config is mostly read-only — `prt` only writes to it through
//! [`write_tunnels`] when the user explicitly requests "save tunnels".
//! A missing file silently yields the defaults; a file that exists but
//! cannot be read or parsed also yields the defaults, after a warning is
//! printed to stderr.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ffi::OsString;
use std::io;
use std::path::{Path, PathBuf};

use crate::core::ssh_tunnel::{SshTunnelSpec, TunnelKind};

15/// Raw TOML representation (TOML table keys are always strings).
16#[derive(Debug, Clone, Default, Deserialize)]
17#[serde(default)]
18struct RawConfig {
19    known_ports: HashMap<String, String>,
20    alerts: Vec<AlertRuleConfig>,
21    ssh_hosts: Vec<SshHostConfig>,
22    ssh_tunnels: Vec<SshTunnelConfig>,
23}
24
25/// Top-level configuration.
26#[derive(Debug, Clone, Default)]
27pub struct PrtConfig {
28    /// User-defined port → service name overrides.
29    /// These take precedence over the built-in known ports database.
30    ///
31    /// ```toml
32    /// [known_ports]
33    /// 9090 = "prometheus"
34    /// 3000 = "grafana"
35    /// ```
36    pub known_ports: HashMap<u16, String>,
37
38    /// Alert rules (populated by the alerts feature).
39    ///
40    /// ```toml
41    /// [[alerts]]
42    /// port = 22
43    /// action = "bell"
44    /// ```
45    pub alerts: Vec<AlertRuleConfig>,
46
47    /// Saved SSH hosts (in addition to `~/.ssh/config`).
48    pub ssh_hosts: Vec<SshHostConfig>,
49
50    /// Saved SSH tunnels (auto-restore on launch).
51    pub ssh_tunnels: Vec<SshTunnelConfig>,
52}
53
54impl From<RawConfig> for PrtConfig {
55    fn from(raw: RawConfig) -> Self {
56        let known_ports = raw
57            .known_ports
58            .into_iter()
59            .filter_map(|(k, v)| k.parse::<u16>().ok().map(|port| (port, v)))
60            .collect();
61        Self {
62            known_ports,
63            alerts: raw.alerts,
64            ssh_hosts: raw.ssh_hosts,
65            ssh_tunnels: raw.ssh_tunnels,
66        }
67    }
68}
69
70/// A single alert rule from the TOML config.
71#[derive(Debug, Clone, Default, Deserialize)]
72#[serde(default)]
73pub struct AlertRuleConfig {
74    pub port: Option<u16>,
75    pub process: Option<String>,
76    pub state: Option<String>,
77    pub connections_gt: Option<usize>,
78    #[serde(default = "default_action")]
79    pub action: String,
80}
81
82fn default_action() -> String {
83    "highlight".into()
84}
85
86/// User-defined SSH host (extends `~/.ssh/config`).
87///
88/// ```toml
89/// [[ssh_hosts]]
90/// alias = "prod-db"
91/// hostname = "10.0.1.5"
92/// user = "deploy"
93/// port = 22
94/// identity_file = "~/.ssh/id_ed25519_prod"
95/// ```
96#[derive(Debug, Clone, Default, Deserialize, Serialize)]
97#[serde(default)]
98pub struct SshHostConfig {
99    pub alias: String,
100    pub hostname: Option<String>,
101    pub user: Option<String>,
102    pub port: Option<u16>,
103    pub identity_file: Option<String>,
104}
105
106/// Saved SSH tunnel.
107///
108/// ```toml
109/// [[ssh_tunnels]]
110/// name = "prod-postgres"
111/// kind = "local"
112/// local_port = 5433
113/// remote_host = "127.0.0.1"
114/// remote_port = 5432
115/// host_alias = "prod-db"
116/// ```
117#[derive(Debug, Clone, Default, Deserialize, Serialize)]
118#[serde(default)]
119pub struct SshTunnelConfig {
120    pub name: Option<String>,
121    #[serde(default = "default_tunnel_kind")]
122    pub kind: String,
123    pub local_port: u16,
124    pub remote_host: Option<String>,
125    pub remote_port: Option<u16>,
126    pub host_alias: String,
127}
128
129fn default_tunnel_kind() -> String {
130    "local".into()
131}
132
133impl SshTunnelConfig {
134    /// Convert to runtime [`SshTunnelSpec`]. Returns `None` for unknown kinds
135    /// or invalid combinations (logged at the call site, not here).
136    pub fn to_spec(&self) -> Option<SshTunnelSpec> {
137        let kind = match self.kind.to_ascii_lowercase().as_str() {
138            "local" => TunnelKind::Local,
139            "dynamic" => TunnelKind::Dynamic,
140            _ => return None,
141        };
142        Some(SshTunnelSpec {
143            name: self.name.clone(),
144            kind,
145            local_port: self.local_port,
146            remote_host: self.remote_host.clone(),
147            remote_port: self.remote_port,
148            host_alias: self.host_alias.clone(),
149        })
150    }
151
152    pub fn from_spec(spec: &SshTunnelSpec) -> Self {
153        Self {
154            name: spec.name.clone(),
155            kind: spec.kind.label().into(),
156            local_port: spec.local_port,
157            remote_host: spec.remote_host.clone(),
158            remote_port: spec.remote_port,
159            host_alias: spec.host_alias.clone(),
160        }
161    }
162}
163
164/// Returns the config directory path: `~/.config/prt/`.
165pub fn config_dir() -> Option<PathBuf> {
166    dirs::config_dir().map(|d| d.join("prt"))
167}
168
169/// Returns the path to the main config file: `~/.config/prt/config.toml`.
170pub fn config_path() -> Option<PathBuf> {
171    config_dir().map(|d| d.join("config.toml"))
172}
173
174/// Load configuration from `~/.config/prt/config.toml`.
175///
176/// Returns [`PrtConfig::default()`] if the file does not exist.
177/// Prints a warning to stderr and returns defaults if the file
178/// exists but cannot be parsed.
179pub fn load_config() -> PrtConfig {
180    let path = match config_path() {
181        Some(p) => p,
182        None => return PrtConfig::default(),
183    };
184
185    let content = match std::fs::read_to_string(&path) {
186        Ok(c) => c,
187        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
188            return PrtConfig::default();
189        }
190        Err(e) => {
191            eprintln!("prt: warning: cannot read {}: {e}", path.display());
192            return PrtConfig::default();
193        }
194    };
195
196    match toml::from_str::<RawConfig>(&content) {
197        Ok(raw) => raw.into(),
198        Err(e) => {
199            eprintln!("prt: warning: cannot parse {}: {e}", path.display());
200            PrtConfig::default()
201        }
202    }
203}
204
205/// Persist `[[ssh_tunnels]]` to the config file at `path`.
206///
207/// Reads any existing TOML, strips all existing `[[ssh_tunnels]]` blocks,
208/// and appends fresh ones rebuilt from `specs`. The rest of the file is
209/// preserved verbatim. Creates the file (and parent directory) if missing.
210pub fn write_tunnels(path: &Path, specs: &[SshTunnelSpec]) -> io::Result<()> {
211    if let Some(parent) = path.parent() {
212        std::fs::create_dir_all(parent)?;
213    }
214    // Only fall back to empty content when the file genuinely doesn't exist.
215    // Other I/O errors (permission denied, transient disk error, decoding)
216    // must propagate so we don't blow away unrelated sections like
217    // `known_ports` / `alerts` / `ssh_hosts`.
218    let existing = match std::fs::read_to_string(path) {
219        Ok(c) => c,
220        Err(e) if e.kind() == io::ErrorKind::NotFound => String::new(),
221        Err(e) => return Err(e),
222    };
223    let stripped = strip_ssh_tunnels_section(&existing);
224
225    let configs: Vec<SshTunnelConfig> = specs.iter().map(SshTunnelConfig::from_spec).collect();
226
227    #[derive(Serialize)]
228    struct Wrap<'a> {
229        ssh_tunnels: &'a [SshTunnelConfig],
230    }
231    let appended = if configs.is_empty() {
232        String::new()
233    } else {
234        toml::to_string(&Wrap {
235            ssh_tunnels: &configs,
236        })
237        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
238    };
239
240    let mut out = stripped.trim_end().to_string();
241    if !out.is_empty() {
242        out.push('\n');
243        out.push('\n');
244    }
245    out.push_str(&appended);
246    if !out.ends_with('\n') {
247        out.push('\n');
248    }
249    std::fs::write(path, out)
250}
251
/// Remove every `ssh_tunnels` table from a raw TOML string.
///
/// Which headers match:
/// * A header matches when its first dotted-key segment is `ssh_tunnels`.
/// * TOML allows spaces and quotes around header keys, so
///   `[[ssh_tunnels]]`, `[[ ssh_tunnels ]]`, `[["ssh_tunnels"]]`,
///   `[ssh_tunnels]` and sub-tables like `[ssh_tunnels.extra]` all match.
///
/// How a removed block is handled:
/// * It runs until the next `[`-prefixed line or EOF.
/// * Comment lines inside it (before its last key/value line) are removed.
/// * Comment lines after its last key/value line are kept, because they
///   usually introduce the section that follows, or are trailing notes at
///   EOF.
/// * Comments between two `ssh_tunnels` blocks are removed.
///
/// All other lines are copied unchanged, each terminated with `\n`.
///
/// The scan is line-based, so an inline `ssh_tunnels = [...]` array is not
/// recognised and is left in place.
fn strip_ssh_tunnels_section(content: &str) -> String {
    let mut out = String::with_capacity(content.len());
    let mut skipping = false;
    // Blank/comment lines seen since the last key/value line of the block
    // being skipped. They are emitted only if the block is followed by a
    // non-tunnel header or EOF.
    let mut trailing: Vec<&str> = Vec::new();
    for line in content.lines() {
        let trimmed = line.trim_start();
        if table_header_root_key(trimmed) == Some("ssh_tunnels") {
            skipping = true;
            // Comments between two tunnel blocks belong to the tunnels.
            trailing.clear();
            continue;
        }
        if skipping {
            if trimmed.starts_with('[') {
                // A new top-level table or array-of-tables ends the skipped block.
                skipping = false;
                push_trailing_comments(&mut out, &mut trailing);
            } else {
                if trimmed.is_empty() || trimmed.starts_with('#') {
                    trailing.push(line);
                } else {
                    // A key/value line: anything buffered was inside the block.
                    trailing.clear();
                }
                continue;
            }
        }
        out.push_str(line);
        out.push('\n');
    }
    if skipping {
        push_trailing_comments(&mut out, &mut trailing);
    }
    out
}

/// Parses a left-trimmed TOML table header line (`[a.b]` or `[[a.b]]`).
///
/// Returns the header's first dotted-key segment, with surrounding whitespace
/// and one pair of matching quotes removed. Returns `None` if the line is not
/// a header.
fn table_header_root_key(trimmed: &str) -> Option<&str> {
    let inner = if let Some(rest) = trimmed.strip_prefix("[[") {
        &rest[..rest.find("]]")?]
    } else if let Some(rest) = trimmed.strip_prefix('[') {
        &rest[..rest.find(']')?]
    } else {
        return None;
    };
    let root = inner.split('.').next().unwrap_or(inner).trim();
    let unquoted = root
        .strip_prefix('"')
        .and_then(|r| r.strip_suffix('"'))
        .or_else(|| root.strip_prefix('\'').and_then(|r| r.strip_suffix('\'')))
        .unwrap_or(root);
    Some(unquoted)
}

/// Appends the buffered lines from the first non-blank one onward, then
/// clears the buffer. Blank lines directly after a removed block are dropped.
fn push_trailing_comments(out: &mut String, lines: &mut Vec<&str>) {
    let first = lines
        .iter()
        .position(|l| !l.trim().is_empty())
        .unwrap_or(lines.len());
    for line in &lines[first..] {
        out.push_str(line);
        out.push('\n');
    }
    lines.clear();
}

#[cfg(test)]
mod tests {
    // Unit tests for config parsing and tunnel persistence. File-system tests
    // use a per-test scratch directory that is removed when it goes out of
    // scope.
    use super::*;

    #[test]
    fn default_config_is_empty() {
        let config = PrtConfig::default();
        assert!(config.known_ports.is_empty());
        assert!(config.alerts.is_empty());
        assert!(config.ssh_hosts.is_empty());
        assert!(config.ssh_tunnels.is_empty());
    }

    #[test]
    fn parse_known_ports() {
        let toml_str = r#"
[known_ports]
9090 = "prometheus"
3000 = "grafana"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.known_ports.get(&9090).unwrap(), "prometheus");
        assert_eq!(config.known_ports.get(&3000).unwrap(), "grafana");
    }

    #[test]
    fn parse_alert_rules() {
        let toml_str = r#"
[[alerts]]
port = 22
action = "bell"

[[alerts]]
process = "python"
state = "LISTEN"
action = "highlight"

[[alerts]]
connections_gt = 100
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.alerts.len(), 3);
        assert_eq!(config.alerts[0].port, Some(22));
        assert_eq!(config.alerts[0].action, "bell");
        assert_eq!(config.alerts[1].process.as_deref(), Some("python"));
        assert_eq!(config.alerts[2].connections_gt, Some(100));
        assert_eq!(config.alerts[2].action, "highlight"); // default
    }

    #[test]
    fn parse_ssh_hosts() {
        let toml_str = r#"
[[ssh_hosts]]
alias = "prod"
hostname = "10.0.0.5"
user = "deploy"
port = 22
identity_file = "~/.ssh/id_ed25519"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.ssh_hosts.len(), 1);
        assert_eq!(config.ssh_hosts[0].alias, "prod");
        assert_eq!(config.ssh_hosts[0].hostname.as_deref(), Some("10.0.0.5"));
        assert_eq!(config.ssh_hosts[0].port, Some(22));
    }

    #[test]
    fn parse_ssh_tunnels() {
        let toml_str = r#"
[[ssh_tunnels]]
name = "pg"
kind = "local"
local_port = 5433
remote_host = "127.0.0.1"
remote_port = 5432
host_alias = "prod"

[[ssh_tunnels]]
kind = "dynamic"
local_port = 1080
host_alias = "prod"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.ssh_tunnels.len(), 2);
        let s0 = config.ssh_tunnels[0].to_spec().unwrap();
        assert_eq!(s0.kind, TunnelKind::Local);
        assert_eq!(s0.local_port, 5433);
        let s1 = config.ssh_tunnels[1].to_spec().unwrap();
        assert_eq!(s1.kind, TunnelKind::Dynamic);
    }

    #[test]
    fn parse_empty_toml_returns_defaults() {
        let raw: RawConfig = toml::from_str("").unwrap();
        let config: PrtConfig = raw.into();
        assert!(config.known_ports.is_empty());
        assert!(config.alerts.is_empty());
    }

    #[test]
    fn parse_invalid_port_key_is_skipped() {
        let toml_str = r#"
[known_ports]
9090 = "prometheus"
not_a_port = "ignored"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.known_ports.len(), 1);
        assert_eq!(config.known_ports.get(&9090).unwrap(), "prometheus");
    }

    #[test]
    fn load_config_returns_defaults_when_no_file() {
        // load_config() always reads the real per-user config location, which
        // the test cannot redirect. Only assert when that file is absent, so
        // a developer's own config does not make the test fail.
        if let Some(path) = config_path() {
            if path.exists() {
                return;
            }
        }
        let config = load_config();
        assert!(config.known_ports.is_empty());
    }

    #[test]
    fn strip_ssh_tunnels_preserves_other_sections() {
        let content = r#"
[known_ports]
9090 = "prom"

[[alerts]]
port = 22

[[ssh_tunnels]]
name = "old"
kind = "local"
local_port = 1
remote_host = "x"
remote_port = 1
host_alias = "y"

[[ssh_hosts]]
alias = "z"
"#;
        let stripped = strip_ssh_tunnels_section(content);
        assert!(stripped.contains("[known_ports]"));
        assert!(stripped.contains("[[alerts]]"));
        assert!(stripped.contains("[[ssh_hosts]]"));
        assert!(!stripped.contains("[[ssh_tunnels]]"));
        assert!(!stripped.contains("\"old\""));
    }

    #[test]
    fn write_tunnels_roundtrip() {
        let dir = tempdir();
        let path = dir.path().join("config.toml");

        let specs = vec![
            SshTunnelSpec {
                name: Some("pg".into()),
                kind: TunnelKind::Local,
                local_port: 5433,
                remote_host: Some("127.0.0.1".into()),
                remote_port: Some(5432),
                host_alias: "prod".into(),
            },
            SshTunnelSpec {
                name: None,
                kind: TunnelKind::Dynamic,
                local_port: 1080,
                remote_host: None,
                remote_port: None,
                host_alias: "prod".into(),
            },
        ];
        write_tunnels(&path, &specs).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        let raw: RawConfig = toml::from_str(&content).unwrap();
        let cfg: PrtConfig = raw.into();
        assert_eq!(cfg.ssh_tunnels.len(), 2);
        let s0 = cfg.ssh_tunnels[0].to_spec().unwrap();
        assert_eq!(s0.local_port, 5433);
        assert_eq!(s0.kind, TunnelKind::Local);
        let s1 = cfg.ssh_tunnels[1].to_spec().unwrap();
        assert_eq!(s1.kind, TunnelKind::Dynamic);

        // Re-write replaces, doesn't duplicate.
        write_tunnels(&path, &specs[..1]).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let raw: RawConfig = toml::from_str(&content).unwrap();
        let cfg: PrtConfig = raw.into();
        assert_eq!(cfg.ssh_tunnels.len(), 1);
    }

    #[test]
    fn write_tunnels_propagates_read_errors() {
        // A directory in place of the config file forces read_to_string to
        // fail with a non-NotFound error (IsADirectory / Other on Linux).
        let dir = tempdir();
        let path = dir.path().join("not-a-file");
        std::fs::create_dir(&path).unwrap();

        let specs = vec![SshTunnelSpec {
            name: None,
            kind: TunnelKind::Local,
            local_port: 1,
            remote_host: Some("h".into()),
            remote_port: Some(2),
            host_alias: "a".into(),
        }];
        let err = write_tunnels(&path, &specs).expect_err("should not silently overwrite");
        assert_ne!(err.kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn write_tunnels_preserves_other_sections() {
        let dir = tempdir();
        let path = dir.path().join("config.toml");
        let initial = "[known_ports]\n9090 = \"prom\"\n\n[[alerts]]\nport = 22\n";
        std::fs::write(&path, initial).unwrap();

        let specs = vec![SshTunnelSpec {
            name: None,
            kind: TunnelKind::Local,
            local_port: 1,
            remote_host: Some("h".into()),
            remote_port: Some(2),
            host_alias: "a".into(),
        }];
        write_tunnels(&path, &specs).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("[known_ports]"));
        assert!(content.contains("[[alerts]]"));
        assert!(content.contains("[[ssh_tunnels]]"));
    }

    /// Scratch directory unique to one test; removed (best effort) when
    /// dropped so repeated test runs don't pile up directories in the system
    /// temp dir.
    struct TempDir(PathBuf);

    impl TempDir {
        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Cleanup failure must not mask the test's own result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh directory under the system temp dir. The name combines
    /// the pid, a timestamp and a per-process counter, so tests running in
    /// parallel never collide.
    fn tempdir() -> TempDir {
        use std::sync::atomic::{AtomicU64, Ordering};
        static SEQ: AtomicU64 = AtomicU64::new(0);
        let n = SEQ.fetch_add(1, Ordering::Relaxed);
        let mut p = std::env::temp_dir();
        p.push(format!(
            "prt-test-{}-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
            n,
        ));
        std::fs::create_dir_all(&p).unwrap();
        TempDir(p)
    }
}