1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::io;
11use std::path::{Path, PathBuf};
12
13use crate::core::ssh_tunnel::{SshTunnelSpec, TunnelKind};
14
15#[derive(Debug, Clone, Default, Deserialize)]
17#[serde(default)]
18struct RawConfig {
19 known_ports: HashMap<String, String>,
20 alerts: Vec<AlertRuleConfig>,
21 ssh_hosts: Vec<SshHostConfig>,
22 ssh_tunnels: Vec<SshTunnelConfig>,
23}
24
25#[derive(Debug, Clone, Default)]
27pub struct PrtConfig {
28 pub known_ports: HashMap<u16, String>,
37
38 pub alerts: Vec<AlertRuleConfig>,
46
47 pub ssh_hosts: Vec<SshHostConfig>,
49
50 pub ssh_tunnels: Vec<SshTunnelConfig>,
52}
53
54impl From<RawConfig> for PrtConfig {
55 fn from(raw: RawConfig) -> Self {
56 let known_ports = raw
57 .known_ports
58 .into_iter()
59 .filter_map(|(k, v)| k.parse::<u16>().ok().map(|port| (port, v)))
60 .collect();
61 Self {
62 known_ports,
63 alerts: raw.alerts,
64 ssh_hosts: raw.ssh_hosts,
65 ssh_tunnels: raw.ssh_tunnels,
66 }
67 }
68}
69
70#[derive(Debug, Clone, Default, Deserialize)]
72#[serde(default)]
73pub struct AlertRuleConfig {
74 pub port: Option<u16>,
75 pub process: Option<String>,
76 pub state: Option<String>,
77 pub connections_gt: Option<usize>,
78 #[serde(default = "default_action")]
79 pub action: String,
80}
81
82fn default_action() -> String {
83 "highlight".into()
84}
85
86#[derive(Debug, Clone, Default, Deserialize, Serialize)]
97#[serde(default)]
98pub struct SshHostConfig {
99 pub alias: String,
100 pub hostname: Option<String>,
101 pub user: Option<String>,
102 pub port: Option<u16>,
103 pub identity_file: Option<String>,
104}
105
106#[derive(Debug, Clone, Default, Deserialize, Serialize)]
118#[serde(default)]
119pub struct SshTunnelConfig {
120 pub name: Option<String>,
121 #[serde(default = "default_tunnel_kind")]
122 pub kind: String,
123 pub local_port: u16,
124 pub remote_host: Option<String>,
125 pub remote_port: Option<u16>,
126 pub host_alias: String,
127}
128
129fn default_tunnel_kind() -> String {
130 "local".into()
131}
132
133impl SshTunnelConfig {
134 pub fn to_spec(&self) -> Option<SshTunnelSpec> {
137 let kind = match self.kind.to_ascii_lowercase().as_str() {
138 "local" => TunnelKind::Local,
139 "dynamic" => TunnelKind::Dynamic,
140 _ => return None,
141 };
142 Some(SshTunnelSpec {
143 name: self.name.clone(),
144 kind,
145 local_port: self.local_port,
146 remote_host: self.remote_host.clone(),
147 remote_port: self.remote_port,
148 host_alias: self.host_alias.clone(),
149 })
150 }
151
152 pub fn from_spec(spec: &SshTunnelSpec) -> Self {
153 Self {
154 name: spec.name.clone(),
155 kind: spec.kind.label().into(),
156 local_port: spec.local_port,
157 remote_host: spec.remote_host.clone(),
158 remote_port: spec.remote_port,
159 host_alias: spec.host_alias.clone(),
160 }
161 }
162}
163
164pub fn config_dir() -> Option<PathBuf> {
166 dirs::config_dir().map(|d| d.join("prt"))
167}
168
169pub fn config_path() -> Option<PathBuf> {
171 config_dir().map(|d| d.join("config.toml"))
172}
173
174pub fn load_config() -> PrtConfig {
180 let path = match config_path() {
181 Some(p) => p,
182 None => return PrtConfig::default(),
183 };
184
185 let content = match std::fs::read_to_string(&path) {
186 Ok(c) => c,
187 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
188 return PrtConfig::default();
189 }
190 Err(e) => {
191 eprintln!("prt: warning: cannot read {}: {e}", path.display());
192 return PrtConfig::default();
193 }
194 };
195
196 match toml::from_str::<RawConfig>(&content) {
197 Ok(raw) => raw.into(),
198 Err(e) => {
199 eprintln!("prt: warning: cannot parse {}: {e}", path.display());
200 PrtConfig::default()
201 }
202 }
203}
204
205pub fn write_tunnels(path: &Path, specs: &[SshTunnelSpec]) -> io::Result<()> {
211 if let Some(parent) = path.parent() {
212 std::fs::create_dir_all(parent)?;
213 }
214 let existing = match std::fs::read_to_string(path) {
219 Ok(c) => c,
220 Err(e) if e.kind() == io::ErrorKind::NotFound => String::new(),
221 Err(e) => return Err(e),
222 };
223 let stripped = strip_ssh_tunnels_section(&existing);
224
225 let configs: Vec<SshTunnelConfig> = specs.iter().map(SshTunnelConfig::from_spec).collect();
226
227 #[derive(Serialize)]
228 struct Wrap<'a> {
229 ssh_tunnels: &'a [SshTunnelConfig],
230 }
231 let appended = if configs.is_empty() {
232 String::new()
233 } else {
234 toml::to_string(&Wrap {
235 ssh_tunnels: &configs,
236 })
237 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
238 };
239
240 let mut out = stripped.trim_end().to_string();
241 if !out.is_empty() {
242 out.push('\n');
243 out.push('\n');
244 }
245 out.push_str(&appended);
246 if !out.ends_with('\n') {
247 out.push('\n');
248 }
249 std::fs::write(path, out)
250}
251
/// Returns `content` with every `ssh_tunnels` table removed.
///
/// Works line by line, so the rest of the file (comments, ordering,
/// formatting) survives verbatim; every output line is `\n`-terminated.
///
/// A tunnel section starts at any table header whose root key is
/// `ssh_tunnels`. That covers `[[ssh_tunnels]]`, `[[ ssh_tunnels ]]`,
/// `[["ssh_tunnels"]]` and subtables such as `[ssh_tunnels.x]` (see
/// `is_ssh_tunnels_header`). The section runs until the next line starting with
/// `[`. A run of `#` comment lines sitting directly above that next header is
/// kept, since it most likely describes the following section rather than the
/// tunnels. Comments inside the tunnel data, or trailing at end of file, are
/// dropped.
///
/// Limitations: this is not a full TOML parser. An inline
/// `ssh_tunnels = [...]` key is not removed, and a value line inside a tunnel
/// table that begins with `[` (e.g. a nested array) ends the section early.
fn strip_ssh_tunnels_section(content: &str) -> String {
    let mut out = String::with_capacity(content.len());
    let mut skipping = false;
    // Comment lines seen while skipping that may turn out to head the next section.
    let mut pending: Vec<&str> = Vec::new();

    for line in content.lines() {
        let trimmed = line.trim_start();
        if is_ssh_tunnels_header(trimmed) {
            skipping = true;
            // Comments right above another tunnel table belong to the tunnels.
            pending.clear();
            continue;
        }
        if skipping {
            if trimmed.starts_with('[') {
                // A different table begins: the tunnel section is over.
                skipping = false;
                for kept in pending.drain(..) {
                    out.push_str(kept);
                    out.push('\n');
                }
            } else {
                if trimmed.starts_with('#') || (trimmed.is_empty() && !pending.is_empty()) {
                    pending.push(line);
                } else {
                    // Tunnel data (or a blank separator before any comment):
                    // drop it, along with any comments that preceded it.
                    pending.clear();
                }
                continue;
            }
        }
        out.push_str(line);
        out.push('\n');
    }
    out
}

/// Returns `true` when `line` (already left-trimmed) is a TOML table header
/// (`[...]` or `[[...]]`) whose first dotted key segment, after trimming
/// whitespace and surrounding quotes, is `ssh_tunnels`.
fn is_ssh_tunnels_header(line: &str) -> bool {
    let inner = match line.strip_prefix("[[").or_else(|| line.strip_prefix('[')) {
        Some(rest) => rest,
        None => return false,
    };
    let key = match inner.find(']') {
        Some(end) => &inner[..end],
        None => return false,
    };
    let root = key.split('.').next().unwrap_or("").trim();
    root.trim_matches(|c: char| c == '"' || c == '\'') == "ssh_tunnels"
}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_is_empty() {
        let config = PrtConfig::default();
        assert!(config.known_ports.is_empty());
        assert!(config.alerts.is_empty());
        assert!(config.ssh_hosts.is_empty());
        assert!(config.ssh_tunnels.is_empty());
    }

    #[test]
    fn parse_known_ports() {
        let toml_str = r#"
[known_ports]
9090 = "prometheus"
3000 = "grafana"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.known_ports.get(&9090).unwrap(), "prometheus");
        assert_eq!(config.known_ports.get(&3000).unwrap(), "grafana");
    }

    #[test]
    fn parse_alert_rules() {
        let toml_str = r#"
[[alerts]]
port = 22
action = "bell"

[[alerts]]
process = "python"
state = "LISTEN"
action = "highlight"

[[alerts]]
connections_gt = 100
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.alerts.len(), 3);
        assert_eq!(config.alerts[0].port, Some(22));
        assert_eq!(config.alerts[0].action, "bell");
        assert_eq!(config.alerts[1].process.as_deref(), Some("python"));
        assert_eq!(config.alerts[2].connections_gt, Some(100));
        // `action` omitted in the TOML falls back to the serde default.
        assert_eq!(config.alerts[2].action, "highlight");
    }

    #[test]
    fn parse_ssh_hosts() {
        let toml_str = r#"
[[ssh_hosts]]
alias = "prod"
hostname = "10.0.0.5"
user = "deploy"
port = 22
identity_file = "~/.ssh/id_ed25519"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.ssh_hosts.len(), 1);
        assert_eq!(config.ssh_hosts[0].alias, "prod");
        assert_eq!(config.ssh_hosts[0].hostname.as_deref(), Some("10.0.0.5"));
        assert_eq!(config.ssh_hosts[0].port, Some(22));
    }

    #[test]
    fn parse_ssh_tunnels() {
        let toml_str = r#"
[[ssh_tunnels]]
name = "pg"
kind = "local"
local_port = 5433
remote_host = "127.0.0.1"
remote_port = 5432
host_alias = "prod"

[[ssh_tunnels]]
kind = "dynamic"
local_port = 1080
host_alias = "prod"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.ssh_tunnels.len(), 2);
        let s0 = config.ssh_tunnels[0].to_spec().unwrap();
        assert_eq!(s0.kind, TunnelKind::Local);
        assert_eq!(s0.local_port, 5433);
        let s1 = config.ssh_tunnels[1].to_spec().unwrap();
        assert_eq!(s1.kind, TunnelKind::Dynamic);
    }

    #[test]
    fn parse_empty_toml_returns_defaults() {
        let raw: RawConfig = toml::from_str("").unwrap();
        let config: PrtConfig = raw.into();
        assert!(config.known_ports.is_empty());
        assert!(config.alerts.is_empty());
    }

    #[test]
    fn parse_invalid_port_key_is_skipped() {
        let toml_str = r#"
[known_ports]
9090 = "prometheus"
not_a_port = "ignored"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.known_ports.len(), 1);
        assert_eq!(config.known_ports.get(&9090).unwrap(), "prometheus");
    }

    #[test]
    fn load_config_returns_defaults_when_no_file() {
        // `load_config` reads the real per-user config file, so the "defaults"
        // expectation only holds on machines that don't have one. Skip rather
        // than fail on a developer machine with a real config.
        let has_user_config = config_path().map(|p| p.exists()).unwrap_or(false);
        if has_user_config {
            return;
        }
        let config = load_config();
        assert!(config.known_ports.is_empty());
    }

    #[test]
    fn strip_ssh_tunnels_preserves_other_sections() {
        let content = r#"
[known_ports]
9090 = "prom"

[[alerts]]
port = 22

[[ssh_tunnels]]
name = "old"
kind = "local"
local_port = 1
remote_host = "x"
remote_port = 1
host_alias = "y"

[[ssh_hosts]]
alias = "z"
"#;
        let stripped = strip_ssh_tunnels_section(content);
        assert!(stripped.contains("[known_ports]"));
        assert!(stripped.contains("[[alerts]]"));
        assert!(stripped.contains("[[ssh_hosts]]"));
        assert!(!stripped.contains("[[ssh_tunnels]]"));
        assert!(!stripped.contains("\"old\""));
    }

    #[test]
    fn write_tunnels_roundtrip() {
        let dir = tempdir();
        let path = dir.path().join("config.toml");

        let specs = vec![
            SshTunnelSpec {
                name: Some("pg".into()),
                kind: TunnelKind::Local,
                local_port: 5433,
                remote_host: Some("127.0.0.1".into()),
                remote_port: Some(5432),
                host_alias: "prod".into(),
            },
            SshTunnelSpec {
                name: None,
                kind: TunnelKind::Dynamic,
                local_port: 1080,
                remote_host: None,
                remote_port: None,
                host_alias: "prod".into(),
            },
        ];
        write_tunnels(&path, &specs).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        let raw: RawConfig = toml::from_str(&content).unwrap();
        let cfg: PrtConfig = raw.into();
        assert_eq!(cfg.ssh_tunnels.len(), 2);
        let s0 = cfg.ssh_tunnels[0].to_spec().unwrap();
        assert_eq!(s0.local_port, 5433);
        assert_eq!(s0.kind, TunnelKind::Local);
        let s1 = cfg.ssh_tunnels[1].to_spec().unwrap();
        assert_eq!(s1.kind, TunnelKind::Dynamic);

        // Rewriting with fewer tunnels must replace, not append.
        write_tunnels(&path, &specs[..1]).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let raw: RawConfig = toml::from_str(&content).unwrap();
        let cfg: PrtConfig = raw.into();
        assert_eq!(cfg.ssh_tunnels.len(), 1);
    }

    #[test]
    fn write_tunnels_propagates_read_errors() {
        let dir = tempdir();
        // A directory at the target path makes the read fail with something
        // other than NotFound; that must surface instead of overwriting.
        let path = dir.path().join("not-a-file");
        std::fs::create_dir(&path).unwrap();

        let specs = vec![SshTunnelSpec {
            name: None,
            kind: TunnelKind::Local,
            local_port: 1,
            remote_host: Some("h".into()),
            remote_port: Some(2),
            host_alias: "a".into(),
        }];
        let err = write_tunnels(&path, &specs).expect_err("should not silently overwrite");
        assert_ne!(err.kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn write_tunnels_preserves_other_sections() {
        let dir = tempdir();
        let path = dir.path().join("config.toml");
        let initial = "[known_ports]\n9090 = \"prom\"\n\n[[alerts]]\nport = 22\n";
        std::fs::write(&path, initial).unwrap();

        let specs = vec![SshTunnelSpec {
            name: None,
            kind: TunnelKind::Local,
            local_port: 1,
            remote_host: Some("h".into()),
            remote_port: Some(2),
            host_alias: "a".into(),
        }];
        write_tunnels(&path, &specs).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("[known_ports]"));
        assert!(content.contains("[[alerts]]"));
        assert!(content.contains("[[ssh_tunnels]]"));
    }

    /// Unique scratch directory under the system temp dir, removed when dropped.
    struct TempDir(PathBuf);

    impl TempDir {
        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover directory must not fail a test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh directory whose name is unique per process, time and call.
    fn tempdir() -> TempDir {
        use std::sync::atomic::{AtomicU64, Ordering};
        static SEQ: AtomicU64 = AtomicU64::new(0);
        let n = SEQ.fetch_add(1, Ordering::Relaxed);
        let mut p = std::env::temp_dir();
        p.push(format!(
            "prt-test-{}-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
            n,
        ));
        std::fs::create_dir_all(&p).unwrap();
        TempDir(p)
    }
}