Skip to main content

netwatch_rs/
config.rs

1use crate::cli::{Args, DataUnit, TrafficUnit};
2use serde::{Deserialize, Serialize};
3use std::path::PathBuf;
4
/// Default value for `Config::diagnostic_targets`.
///
/// Used by serde when the TOML config omits `DiagnosticTargets`, and by
/// `Config::default`. Both entries are public DNS resolvers:
/// Cloudflare (`1.1.1.1`) and Google (`8.8.8.8`).
fn default_diagnostic_targets() -> Vec<String> {
    const TARGETS: [&str; 2] = ["1.1.1.1", "8.8.8.8"];
    TARGETS.iter().map(|&host| host.to_owned()).collect()
}
11
/// Default value for `Config::dns_domains`.
///
/// Used by serde when the TOML config omits `DNSDomains`, and by
/// `Config::default`.
fn default_dns_domains() -> Vec<String> {
    let domains = vec!["cloudflare.com", "google.com"];
    domains.into_iter().map(String::from).collect()
}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Config {
21    #[serde(rename = "AverageWindow")]
22    pub average_window: u32,
23
24    #[serde(rename = "BarMaxIn")]
25    pub max_incoming: u64,
26
27    #[serde(rename = "BarMaxOut")]
28    pub max_outgoing: u64,
29
30    #[serde(rename = "DataFormat")]
31    pub data_format: String,
32
33    #[serde(rename = "Devices")]
34    pub devices: String,
35
36    #[serde(rename = "MultipleDevices")]
37    pub multiple_devices: bool,
38
39    #[serde(rename = "RefreshInterval")]
40    pub refresh_interval: u64,
41
42    #[serde(rename = "HighPerformance", default)]
43    pub high_performance: bool,
44
45    #[serde(rename = "TrafficFormat")]
46    pub traffic_format: String,
47
48    #[serde(rename = "DiagnosticTargets", default = "default_diagnostic_targets")]
49    pub diagnostic_targets: Vec<String>,
50
51    #[serde(rename = "DNSDomains", default = "default_dns_domains")]
52    pub dns_domains: Vec<String>,
53}
54
55impl Default for Config {
56    fn default() -> Self {
57        Self {
58            average_window: 300,
59            max_incoming: 0,
60            max_outgoing: 0,
61            data_format: "M".to_string(),
62            devices: "all".to_string(),
63            multiple_devices: false,
64            refresh_interval: 1000,
65            high_performance: false,
66            traffic_format: "k".to_string(),
67            diagnostic_targets: default_diagnostic_targets(),
68            dns_domains: default_dns_domains(),
69        }
70    }
71}
72
73impl Config {
74    pub fn load() -> anyhow::Result<Self> {
75        // Try to load from ~/.netwatch (modern) or ~/.nload (compatibility)
76        if let Some(home) = dirs::home_dir() {
77            let modern_config = home.join(".netwatch");
78            let legacy_config = home.join(".nload");
79
80            if modern_config.exists() {
81                let content = std::fs::read_to_string(modern_config)?;
82                return Ok(toml::from_str(&content)?);
83            } else if legacy_config.exists() {
84                // Parse nload format: Key="Value"
85                return Self::parse_nload_format(&legacy_config);
86            }
87        }
88
89        Ok(Self::default())
90    }
91
92    pub fn save(&self) -> anyhow::Result<()> {
93        if let Some(home) = dirs::home_dir() {
94            let config_path = home.join(".netwatch");
95            let content = toml::to_string_pretty(self)?;
96            std::fs::write(config_path, content)?;
97        }
98        Ok(())
99    }
100
101    pub fn apply_args(&mut self, args: &Args) {
102        self.average_window = args.average_window;
103        self.max_incoming = args.max_incoming;
104        self.max_outgoing = args.max_outgoing;
105        self.refresh_interval = args.refresh_interval;
106        self.high_performance = args.high_performance;
107        self.traffic_format = args.traffic_unit.to_string().to_string();
108        self.data_format = args.data_unit.to_string().to_string();
109        self.multiple_devices = args.multiple_devices;
110
111        // Enable high performance security monitoring if high-perf mode is enabled
112        if self.high_performance {
113            crate::security::enable_high_performance_security(true);
114        }
115    }
116
117    #[must_use]
118    pub fn get_traffic_unit(&self) -> TrafficUnit {
119        TrafficUnit::from_string(&self.traffic_format).unwrap_or(TrafficUnit::KiloBit)
120    }
121
122    #[must_use]
123    pub fn get_data_unit(&self) -> DataUnit {
124        DataUnit::from_string(&self.data_format).unwrap_or(DataUnit::MegaByte)
125    }
126
127    fn parse_nload_format(path: &PathBuf) -> anyhow::Result<Self> {
128        let content = std::fs::read_to_string(path)?;
129        let mut config = Self::default();
130
131        for line in content.lines() {
132            let line = line.trim();
133            if line.is_empty() || line.starts_with('#') {
134                continue;
135            }
136
137            if let Some((key, value)) = line.split_once('=') {
138                let key = key.trim();
139                let value = value.trim().trim_matches('"');
140
141                match key {
142                    "AverageWindow" => config.average_window = value.parse().unwrap_or(300),
143                    "BarMaxIn" => config.max_incoming = value.parse().unwrap_or(0),
144                    "BarMaxOut" => config.max_outgoing = value.parse().unwrap_or(0),
145                    "DataFormat" => config.data_format = value.to_string(),
146                    "Devices" => config.devices = value.to_string(),
147                    "MultipleDevices" => config.multiple_devices = value.parse().unwrap_or(false),
148                    "RefreshInterval" => config.refresh_interval = value.parse().unwrap_or(500),
149                    "TrafficFormat" => config.traffic_format = value.to_string(),
150                    _ => {} // Ignore unknown keys
151                }
152            }
153        }
154
155        Ok(config)
156    }
157}