1use serde::Deserialize;
8use std::collections::HashMap;
9use std::path::PathBuf;
10
/// On-disk shape of `config.toml`, deserialized directly with `toml`.
///
/// TOML table keys are always strings, so `known_ports` is kept as
/// `String -> String` here and only converted to numeric ports when a
/// `PrtConfig` is built from it. The container-level `#[serde(default)]`
/// lets any top-level section be omitted from the file (an empty file
/// deserializes to empty collections).
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
struct RawConfig {
    /// `[known_ports]` table: port number written as a key → label,
    /// e.g. `9090 = "prometheus"`.
    known_ports: HashMap<String, String>,
    /// `[[alerts]]` array of tables, one entry per rule.
    alerts: Vec<AlertRuleConfig>,
}
18
/// Typed user configuration, produced from the raw TOML form via `From`.
///
/// `Default` is an empty configuration; it is also what [`load_config`]
/// returns whenever no usable config file is available.
#[derive(Debug, Clone, Default)]
pub struct PrtConfig {
    /// Port number → user-supplied label, from the `[known_ports]` table
    /// (e.g. `9090 = "prometheus"`). Keys that are not valid `u16` values
    /// are not present here; they are dropped during conversion.
    pub known_ports: HashMap<u16, String>,

    /// Alert rules from the `[[alerts]]` array, in the order they appear
    /// in the file.
    pub alerts: Vec<AlertRuleConfig>,
}
41
42impl From<RawConfig> for PrtConfig {
43 fn from(raw: RawConfig) -> Self {
44 let known_ports = raw
45 .known_ports
46 .into_iter()
47 .filter_map(|(k, v)| k.parse::<u16>().ok().map(|port| (port, v)))
48 .collect();
49 Self {
50 known_ports,
51 alerts: raw.alerts,
52 }
53 }
54}
55
56#[derive(Debug, Clone, Default, Deserialize)]
58#[serde(default)]
59pub struct AlertRuleConfig {
60 pub port: Option<u16>,
61 pub process: Option<String>,
62 pub state: Option<String>,
63 pub connections_gt: Option<usize>,
64 #[serde(default = "default_action")]
65 pub action: String,
66}
67
68fn default_action() -> String {
69 "highlight".into()
70}
71
72pub fn config_dir() -> Option<PathBuf> {
74 dirs::config_dir().map(|d| d.join("prt"))
75}
76
77pub fn config_path() -> Option<PathBuf> {
79 config_dir().map(|d| d.join("config.toml"))
80}
81
82pub fn load_config() -> PrtConfig {
88 let path = match config_path() {
89 Some(p) => p,
90 None => return PrtConfig::default(),
91 };
92
93 let content = match std::fs::read_to_string(&path) {
94 Ok(c) => c,
95 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
96 return PrtConfig::default();
97 }
98 Err(e) => {
99 eprintln!("prt: warning: cannot read {}: {e}", path.display());
100 return PrtConfig::default();
101 }
102 };
103
104 match toml::from_str::<RawConfig>(&content) {
105 Ok(raw) => raw.into(),
106 Err(e) => {
107 eprintln!("prt: warning: cannot parse {}: {e}", path.display());
108 PrtConfig::default()
109 }
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_is_empty() {
        let config = PrtConfig::default();
        assert!(config.known_ports.is_empty());
        assert!(config.alerts.is_empty());
    }

    #[test]
    fn parse_known_ports() {
        let toml_str = r#"
[known_ports]
9090 = "prometheus"
3000 = "grafana"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.known_ports.get(&9090).unwrap(), "prometheus");
        assert_eq!(config.known_ports.get(&3000).unwrap(), "grafana");
    }

    #[test]
    fn parse_alert_rules() {
        let toml_str = r#"
[[alerts]]
port = 22
action = "bell"

[[alerts]]
process = "python"
state = "LISTEN"
action = "highlight"

[[alerts]]
connections_gt = 100
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.alerts.len(), 3);
        assert_eq!(config.alerts[0].port, Some(22));
        assert_eq!(config.alerts[0].action, "bell");
        assert_eq!(config.alerts[1].process.as_deref(), Some("python"));
        assert_eq!(config.alerts[1].state.as_deref(), Some("LISTEN"));
        assert_eq!(config.alerts[2].connections_gt, Some(100));
        // A rule without `action` falls back to the serde default.
        assert_eq!(config.alerts[2].action, "highlight");
    }

    #[test]
    fn parse_empty_toml_returns_defaults() {
        let raw: RawConfig = toml::from_str("").unwrap();
        let config: PrtConfig = raw.into();
        assert!(config.known_ports.is_empty());
        assert!(config.alerts.is_empty());
    }

    #[test]
    fn parse_invalid_port_key_is_skipped() {
        let toml_str = r#"
[known_ports]
9090 = "prometheus"
not_a_port = "ignored"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.known_ports.len(), 1);
        assert_eq!(config.known_ports.get(&9090).unwrap(), "prometheus");
    }

    #[test]
    fn parse_out_of_range_port_key_is_skipped() {
        // Numeric keys outside the u16 range must not wrap or truncate.
        let toml_str = r#"
[known_ports]
8080 = "http-alt"
70000 = "too-large"
-1 = "negative"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let config: PrtConfig = raw.into();
        assert_eq!(config.known_ports.len(), 1);
        assert_eq!(config.known_ports.get(&8080).unwrap(), "http-alt");
    }

    #[test]
    fn load_config_returns_defaults_when_no_file() {
        // `load_config` always reads the real per-user config location, so
        // the "defaults" expectation only holds when no file exists there.
        // Guarding the assertions keeps this test from failing on machines
        // where prt has actually been configured, while still checking that
        // loading never panics.
        let file_exists = matches!(config_path(), Some(p) if p.exists());
        let config = load_config();
        if !file_exists {
            assert!(config.known_ports.is_empty());
            assert!(config.alerts.is_empty());
        }
    }
}