1use crate::error::ConfigError;
2use indexmap::IndexMap;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Default, Serialize, Deserialize)]
7#[serde(deny_unknown_fields)]
8pub struct GlobalConfig {
9 #[serde(default)]
10 pub default: DefaultProfile,
11 #[serde(default)]
12 pub connection: IndexMap<String, ConnectionProfile>,
13 #[serde(default)]
14 pub history: HistoryConfig,
15 #[serde(default)]
16 pub slow_log: SlowLogConfig,
17 #[serde(default)]
18 pub cache: CacheConfig,
19}
20
21impl GlobalConfig {
22 pub fn load(explicit_path: Option<&str>) -> Result<Self, ConfigError> {
27 let path = if let Some(p) = explicit_path {
28 std::path::PathBuf::from(p)
29 } else {
30 Self::find_config_path()?
31 };
32 if !path.exists() {
33 return Ok(Self::default());
34 }
35 Self::load_from(&path)
36 }
37
38 pub fn load_from(path: &std::path::Path) -> Result<Self, ConfigError> {
40 let content = std::fs::read_to_string(path)
41 .map_err(|e| ConfigError::ConfigNotFound(format!("{}: {}", path.display(), e)))?;
42 let mut config: GlobalConfig =
43 toml::from_str(&content).map_err(|e| ConfigError::InvalidConfig(e.to_string()))?;
44 for profile in config.connection.values_mut() {
46 profile.url = crate::registry::interpolate_env_vars(&profile.url);
47 if let Some(host) = &profile.ssh_host {
48 profile.ssh_host = Some(crate::registry::interpolate_env_vars(host));
49 }
50 if let Some(user) = &profile.ssh_user {
51 profile.ssh_user = Some(crate::registry::interpolate_env_vars(user));
52 }
53 if let Some(key) = &profile.ssh_key {
54 profile.ssh_key = Some(crate::registry::interpolate_env_vars(key));
55 }
56 }
57 Ok(config)
58 }
59
60 fn find_config_path() -> Result<std::path::PathBuf, ConfigError> {
61 if let Ok(cwd) = std::env::current_dir() {
63 let local = cwd.join(".ferrule.toml");
64 if local.exists() {
65 return Ok(local);
66 }
67 }
68 let config_dir = dirs::config_dir()
70 .ok_or_else(|| {
71 ConfigError::ConfigNotFound("could not determine config directory".into())
72 })?
73 .join("ferrule");
74 Ok(config_dir.join("ferrule.toml"))
75 }
76
77 pub fn resolve_format(&self, cli: Option<&str>) -> String {
79 cli.map(|s| s.to_string())
80 .unwrap_or_else(|| self.default.format.clone())
81 }
82
83 pub fn resolve_limit(&self, cli: Option<usize>) -> Option<usize> {
85 cli.or(self.default.limit_checked())
86 }
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90#[serde(deny_unknown_fields)]
91pub struct DefaultProfile {
92 #[serde(default = "default_format")]
93 pub format: String,
94 #[serde(default = "default_limit")]
95 pub limit: usize,
96 #[serde(default = "default_timeout")]
97 pub timeout: u64,
98}
99
100impl DefaultProfile {
101 pub fn limit_checked(&self) -> Option<usize> {
103 if self.limit == 0 {
104 None
105 } else {
106 Some(self.limit)
107 }
108 }
109}
110
111impl Default for DefaultProfile {
112 fn default() -> Self {
113 Self {
114 format: default_format(),
115 limit: default_limit(),
116 timeout: default_timeout(),
117 }
118 }
119}
120
/// Default output format for `[default].format`.
fn default_format() -> String {
    String::from("json")
}

/// Default row limit for `[default].limit`.
fn default_limit() -> usize {
    1_000
}

/// Default value for `[default].timeout`.
fn default_timeout() -> u64 {
    30_u64
}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
139#[serde(deny_unknown_fields)]
140pub struct HistoryConfig {
141 #[serde(default = "default_history_enabled")]
142 pub enabled: bool,
143 #[serde(default = "default_history_max_age_days")]
147 pub max_age_days: u32,
148 #[serde(default = "default_history_max_rows")]
152 pub max_rows: u64,
153 #[serde(default)]
157 pub path: Option<String>,
158}
159
160impl Default for HistoryConfig {
161 fn default() -> Self {
162 Self {
163 enabled: default_history_enabled(),
164 max_age_days: default_history_max_age_days(),
165 max_rows: default_history_max_rows(),
166 path: None,
167 }
168 }
169}
170
/// Default for `[history].enabled`: on.
fn default_history_enabled() -> bool {
    true
}

/// Default for `[history].max_age_days`.
fn default_history_max_age_days() -> u32 {
    30_u32
}

/// Default for `[history].max_rows`.
fn default_history_max_rows() -> u64 {
    100_000_u64
}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
195#[serde(deny_unknown_fields)]
196pub struct SlowLogConfig {
197 #[serde(default)]
198 pub enabled: bool,
199 #[serde(default = "default_slow_threshold")]
202 pub threshold: String,
203 #[serde(default)]
206 pub path: Option<String>,
207 #[serde(default)]
211 pub max_size: Option<String>,
212}
213
214impl Default for SlowLogConfig {
215 fn default() -> Self {
216 Self {
217 enabled: false,
218 threshold: default_slow_threshold(),
219 path: None,
220 max_size: None,
221 }
222 }
223}
224
225impl SlowLogConfig {
226 pub fn threshold_ms(&self) -> Result<u64, String> {
230 parse_threshold_ms(&self.threshold)
231 }
232
233 pub fn max_size_bytes(&self) -> Result<Option<u64>, String> {
237 match self.max_size.as_deref() {
238 None => Ok(None),
239 Some(s) => crate::parse::parse_size(s).map(Some),
240 }
241 }
242}
243
/// Default for `[slow_log].threshold`: one second.
fn default_slow_threshold() -> String {
    String::from("1s")
}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
256#[serde(deny_unknown_fields)]
257pub struct CacheConfig {
258 #[serde(default = "default_cache_enabled")]
259 pub enabled: bool,
260 #[serde(default = "default_cache_ttl")]
265 pub default_ttl: String,
266 #[serde(default = "default_cache_max_age_days")]
269 pub max_age_days: u32,
270 #[serde(default = "default_cache_max_rows")]
273 pub max_rows: u64,
274 #[serde(default)]
277 pub path: Option<String>,
278}
279
280impl Default for CacheConfig {
281 fn default() -> Self {
282 Self {
283 enabled: default_cache_enabled(),
284 default_ttl: default_cache_ttl(),
285 max_age_days: default_cache_max_age_days(),
286 max_rows: default_cache_max_rows(),
287 path: None,
288 }
289 }
290}
291
/// Default for `[cache].enabled`: on.
fn default_cache_enabled() -> bool {
    true
}

/// Default for `[cache].default_ttl`: five minutes.
fn default_cache_ttl() -> String {
    String::from("5m")
}

/// Default for `[cache].max_age_days`.
fn default_cache_max_age_days() -> u32 {
    7_u32
}

/// Default for `[cache].max_rows`.
fn default_cache_max_rows() -> u64 {
    10_000_u64
}
307
308fn parse_threshold_ms(s: &str) -> Result<u64, String> {
315 let s = s.trim();
316 if s.is_empty() {
317 return Err("threshold is empty".into());
318 }
319 if let Ok(ms) = s.parse::<u64>() {
321 return Ok(ms);
322 }
323 crate::parse::parse_duration(s)
324 .map(|d| d.num_milliseconds() as u64)
325 .map_err(|e| format!("threshold: {e}"))
326}
327
328#[derive(Debug, Clone, Serialize, Deserialize)]
329#[serde(deny_unknown_fields)]
330pub struct ConnectionProfile {
331 pub url: String,
332 #[serde(default)]
333 pub password_url: Option<String>,
334 #[serde(default)]
335 pub headers: IndexMap<String, String>,
336
337 #[serde(default)]
341 pub ssh_host: Option<String>,
342 #[serde(default)]
344 pub ssh_user: Option<String>,
345 #[serde(default)]
347 pub ssh_port: Option<u16>,
348 #[serde(default)]
352 pub ssh_key: Option<String>,
353
354 #[serde(default)]
356 pub proxy_url: Option<String>,
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes `content` to a temporary file and loads it through `GlobalConfig::load_from`,
    /// so file reading and env-var interpolation are exercised. The temp file stays alive
    /// until loading has finished.
    fn load_toml(content: &str) -> GlobalConfig {
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        tmp.write_all(content.as_bytes()).unwrap();
        GlobalConfig::load_from(tmp.path()).unwrap()
    }

    /// Sets environment variables for the guard's lifetime and removes them on drop.
    /// A failing assertion therefore cannot leak them into other tests in the same process.
    /// Bind it to a named variable (`_env`), not `_`, or it is dropped immediately.
    struct EnvVars(Vec<&'static str>);

    impl EnvVars {
        fn set(pairs: &[(&'static str, &str)]) -> Self {
            for &(key, value) in pairs {
                std::env::set_var(key, value);
            }
            Self(pairs.iter().map(|&(key, _)| key).collect())
        }
    }

    impl Drop for EnvVars {
        fn drop(&mut self) {
            for key in &self.0 {
                std::env::remove_var(key);
            }
        }
    }

    /// Builds an enabled `SlowLogConfig` with threshold `t` and no path or size cap.
    fn slow(t: &str) -> SlowLogConfig {
        SlowLogConfig {
            enabled: true,
            threshold: t.into(),
            path: None,
            max_size: None,
        }
    }

    #[test]
    fn test_load_global_config_defaults() {
        let config = GlobalConfig::load(Some("/nonexistent/path.toml")).unwrap();
        assert_eq!(config.default.format, "json");
        assert_eq!(config.default.limit, 1000);
        assert_eq!(config.default.timeout, 30);
        assert!(config.connection.is_empty());
    }

    #[test]
    fn test_load_global_config_from_file() {
        let config = load_toml(
            r#"
[default]
format = "table"
limit = 500
timeout = 60

[connection.production]
url = "postgres://user:pass@host/db"
"#,
        );
        assert_eq!(config.default.format, "table");
        assert_eq!(config.default.limit, 500);
        assert_eq!(config.default.timeout, 60);
        assert_eq!(config.connection.len(), 1);
        let prod = config.connection.get("production").unwrap();
        assert_eq!(prod.url, "postgres://user:pass@host/db");
    }

    #[test]
    fn unknown_keys_are_rejected() {
        assert!(toml::from_str::<GlobalConfig>("bogus = 1\n").is_err());
        assert!(toml::from_str::<GlobalConfig>("[default]\nformt = \"table\"\n").is_err());
        assert!(toml::from_str::<GlobalConfig>(
            "[connection.x]\nurl = \"u\"\nssh_hots = \"h\"\n"
        )
        .is_err());
    }

    #[test]
    fn connection_without_url_is_rejected() {
        assert!(toml::from_str::<GlobalConfig>("[connection.x]\nssh_host = \"h\"\n").is_err());
    }

    #[test]
    fn slow_log_threshold_parses_humantime_and_bare_ms() {
        assert_eq!(SlowLogConfig::default().threshold_ms().unwrap(), 1_000);
        assert_eq!(slow("250ms").threshold_ms().unwrap(), 250);
        assert_eq!(slow("500").threshold_ms().unwrap(), 500);
        assert_eq!(slow("2s").threshold_ms().unwrap(), 2_000);
        assert_eq!(slow("5m").threshold_ms().unwrap(), 300_000);
        assert_eq!(slow("1h").threshold_ms().unwrap(), 3_600_000);
    }

    #[test]
    fn slow_log_threshold_rejects_bad_input() {
        assert!(slow("").threshold_ms().is_err());
        assert!(slow("fast").threshold_ms().is_err());
        assert!(slow("5x").threshold_ms().is_err());
    }

    #[test]
    fn slow_log_max_size_bytes_resolves() {
        assert_eq!(SlowLogConfig::default().max_size_bytes().unwrap(), None);
        let mut cfg = SlowLogConfig {
            max_size: Some("10MB".into()),
            ..SlowLogConfig::default()
        };
        assert_eq!(cfg.max_size_bytes().unwrap(), Some(10_000_000));
        cfg.max_size = Some("5MiB".into());
        assert_eq!(cfg.max_size_bytes().unwrap(), Some(5 * 1_024 * 1_024));
        cfg.max_size = Some("bad".into());
        assert!(cfg.max_size_bytes().is_err());
    }

    #[test]
    fn slow_log_max_size_round_trips_through_toml() {
        let toml = r#"
[slow_log]
enabled = true
threshold = "1s"
max_size = "5MiB"
"#;
        let cfg: GlobalConfig = toml::from_str(toml).unwrap();
        assert_eq!(
            cfg.slow_log.max_size_bytes().unwrap(),
            Some(5 * 1_024 * 1_024)
        );
        let toml2 = r#"
[slow_log]
enabled = true
"#;
        let cfg2: GlobalConfig = toml::from_str(toml2).unwrap();
        assert_eq!(cfg2.slow_log.max_size_bytes().unwrap(), None);
    }

    #[test]
    fn test_resolve_format_and_limit() {
        let mut config = GlobalConfig::default();
        config.default.format = "csv".into();
        config.default.limit = 50;

        assert_eq!(config.resolve_format(None), "csv");
        assert_eq!(config.resolve_format(Some("json")), "json");
        assert_eq!(config.resolve_limit(None), Some(50));
        assert_eq!(config.resolve_limit(Some(10)), Some(10));

        // A configured limit of 0 means "no limit"; an explicit CLI value still wins.
        config.default.limit = 0;
        assert_eq!(config.resolve_limit(None), None);
        assert_eq!(config.resolve_limit(Some(7)), Some(7));
    }

    #[test]
    fn test_env_interpolation_in_profile_url() {
        let _env = EnvVars::set(&[("FERRULE_TEST_PROFILE_HOST", "myhost")]);
        let config = load_toml(
            r#"
[connection.test]
url = "postgres://user@${FERRULE_TEST_PROFILE_HOST}/db"
"#,
        );
        let test = config.connection.get("test").unwrap();
        assert_eq!(test.url, "postgres://user@myhost/db");
    }

    #[test]
    fn ssh_keys_default_to_none() {
        let config = load_toml(
            r#"
[connection.plain]
url = "postgres://user:pass@host/db"
"#,
        );
        let plain = config.connection.get("plain").unwrap();
        assert!(plain.ssh_host.is_none());
        assert!(plain.ssh_user.is_none());
        assert!(plain.ssh_port.is_none());
        assert!(plain.ssh_key.is_none());
    }

    #[test]
    fn ssh_keys_parse_when_present() {
        let config = load_toml(
            r#"
[connection.tunneled]
url = "postgres://app:pwd@10.0.0.5:5432/myapp"
ssh_host = "bastion.example.com"
ssh_user = "ec2-user"
ssh_port = 2222
ssh_key = "/home/me/.ssh/id_ed25519"
"#,
        );
        let tunneled = config.connection.get("tunneled").unwrap();
        assert_eq!(tunneled.ssh_host.as_deref(), Some("bastion.example.com"));
        assert_eq!(tunneled.ssh_user.as_deref(), Some("ec2-user"));
        assert_eq!(tunneled.ssh_port, Some(2222));
        assert_eq!(
            tunneled.ssh_key.as_deref(),
            Some("/home/me/.ssh/id_ed25519")
        );
    }

    #[test]
    fn ssh_partial_keys_parse_independently() {
        let config = load_toml(
            r#"
[connection.minimal]
url = "postgres://app@db-host/myapp"
ssh_host = "bastion"
"#,
        );
        let minimal = config.connection.get("minimal").unwrap();
        assert_eq!(minimal.ssh_host.as_deref(), Some("bastion"));
        assert!(minimal.ssh_user.is_none());
        assert!(minimal.ssh_port.is_none());
        assert!(minimal.ssh_key.is_none());
    }

    #[test]
    fn ssh_host_and_key_get_env_interpolation() {
        let _env = EnvVars::set(&[
            ("FERRULE_TEST_BASTION", "bastion.prod"),
            ("FERRULE_TEST_KEYDIR", "/keys"),
        ]);
        let config = load_toml(
            r#"
[connection.tmpl]
url = "postgres://app@db/myapp"
ssh_host = "${FERRULE_TEST_BASTION}"
ssh_key = "${FERRULE_TEST_KEYDIR}/id_rsa"
"#,
        );
        let tmpl = config.connection.get("tmpl").unwrap();
        assert_eq!(tmpl.ssh_host.as_deref(), Some("bastion.prod"));
        assert_eq!(tmpl.ssh_key.as_deref(), Some("/keys/id_rsa"));
    }
}