Skip to main content

gritty/
config.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use serde::Deserialize;
5
/// Embedded default config template (from repo root config.toml).
///
/// The file contents are baked into the binary at compile time by `include_str!`, using a
/// path relative to this source file. Editing `config.toml` therefore requires a rebuild
/// for the change to show up here.
pub const DEFAULT_CONFIG: &str = include_str!("../config.toml");
8
/// Resolved session settings after merging all config layers.
///
/// `ConfigFile::resolve_session` fills each field from the first layer that sets it:
/// the `[host.<name>]` section, then `[defaults]`, then the value from this type's
/// [`Default`] impl.
///
/// NOTE(review): the timeout and interval fields carry no unit in this file; they are
/// presumably seconds. Confirm against the code that consumes them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionSettings {
    /// `forward-agent` (built-in default: `false`).
    pub forward_agent: bool,
    /// `forward-open` (built-in default: `false`).
    pub forward_open: bool,
    /// `no-escape` (built-in default: `false`).
    pub no_escape: bool,
    /// `no-redraw` (built-in default: `false`).
    pub no_redraw: bool,
    /// `oauth-redirect` (built-in default: `true`).
    pub oauth_redirect: bool,
    /// `oauth-timeout` (built-in default: `180`).
    pub oauth_timeout: u64,
    /// `heartbeat-interval` (built-in default: `5`).
    pub heartbeat_interval: u64,
    /// `heartbeat-timeout` (built-in default: `15`).
    pub heartbeat_timeout: u64,
    /// `ring-buffer-size` in bytes (built-in default: 1 MiB).
    pub ring_buffer_size: u64,
    /// `oauth-tunnel-idle-timeout` (built-in default: `5`).
    pub oauth_tunnel_idle_timeout: u64,
}

impl Default for SessionSettings {
    /// Built-in values used for any field that no config layer sets.
    fn default() -> Self {
        const ONE_MIB: u64 = 1024 * 1024;

        SessionSettings {
            // OAuth-related settings.
            oauth_redirect: true,
            oauth_timeout: 180,
            oauth_tunnel_idle_timeout: 5,
            // Liveness and scrollback.
            heartbeat_interval: 5,
            heartbeat_timeout: 15,
            ring_buffer_size: ONE_MIB,
            // Opt-in feature flags, all off unless configured.
            forward_agent: false,
            forward_open: false,
            no_escape: false,
            no_redraw: false,
        }
    }
}
40
/// Resolved connect settings after merging all config layers.
///
/// `ConfigFile::resolve_connect` returns this type.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnectSettings {
    /// Session settings resolved for the same host.
    pub session: SessionSettings,
    /// Concatenated `ssh-options`. Host-specific entries come first, then the entries from
    /// `[defaults.connect]`, because SSH keeps the first value it sees for an option.
    pub ssh_options: Vec<String>,
    /// `no-server-start`: the host override, else `[defaults.connect]`, else `false`.
    pub no_server_start: bool,
}
48
/// Top-level config file structure.
///
/// `#[serde(default)]` makes both tables optional. An empty or missing table
/// deserializes to its `Default`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct ConfigFile {
    /// The `[defaults]` table, which applies to every host.
    pub defaults: Defaults,
    /// `[host.<name>]` tables, keyed by the host name passed to the resolve methods.
    pub host: HashMap<String, HostConfig>,
}
56
/// Global defaults section (`[defaults]`).
///
/// TOML keys are the kebab-case form of the field names (e.g. `forward-agent`). A `None`
/// field means "not set at this layer", and resolution then falls through to the built-in
/// value. Unknown keys are ignored rather than rejected.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default, rename_all = "kebab-case")]
pub struct Defaults {
    pub forward_agent: Option<bool>,
    pub forward_open: Option<bool>,
    pub no_escape: Option<bool>,
    pub no_redraw: Option<bool>,
    pub oauth_redirect: Option<bool>,
    pub oauth_timeout: Option<u64>,
    pub heartbeat_interval: Option<u64>,
    pub heartbeat_timeout: Option<u64>,
    pub ring_buffer_size: Option<u64>,
    pub oauth_tunnel_idle_timeout: Option<u64>,
    /// The nested `[defaults.connect]` table.
    pub connect: Option<ConnectDefaults>,
}
73
/// Connect-specific settings.
///
/// This table appears both as `[defaults.connect]` and as `[host.<name>.connect]`.
/// Keys are kebab-case (`ssh-options`, `no-server-start`), and `None` means unset at
/// this layer.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default, rename_all = "kebab-case")]
pub struct ConnectDefaults {
    /// Extra SSH options. Host and default lists are concatenated, not replaced.
    pub ssh_options: Option<Vec<String>>,
    pub no_server_start: Option<bool>,
}
81
/// Per-host override section (`[host.<name>]`).
///
/// Has the same keys as [`Defaults`]. Any field set here takes precedence over
/// `[defaults]` for that host. `None` means "inherit".
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default, rename_all = "kebab-case")]
pub struct HostConfig {
    pub forward_agent: Option<bool>,
    pub forward_open: Option<bool>,
    pub no_escape: Option<bool>,
    pub no_redraw: Option<bool>,
    pub oauth_redirect: Option<bool>,
    pub oauth_timeout: Option<u64>,
    pub heartbeat_interval: Option<u64>,
    pub heartbeat_timeout: Option<u64>,
    pub ring_buffer_size: Option<u64>,
    pub oauth_tunnel_idle_timeout: Option<u64>,
    /// The nested `[host.<name>.connect]` table.
    pub connect: Option<ConnectDefaults>,
}
98
99/// Return the config file path: $XDG_CONFIG_HOME/gritty/config.toml
100pub fn config_path() -> PathBuf {
101    if let Some(proj) = directories::ProjectDirs::from("", "", "gritty") {
102        return proj.config_dir().join("config.toml");
103    }
104    PathBuf::from(".config").join("gritty").join("config.toml")
105}
106
107impl ConfigFile {
108    /// Load config from the default path. Returns default on missing or malformed file.
109    pub fn load() -> Self {
110        Self::load_from(&config_path())
111    }
112
113    /// Load config from a specific path. Returns default on missing or malformed file.
114    pub fn load_from(path: &std::path::Path) -> Self {
115        let content = match std::fs::read_to_string(path) {
116            Ok(c) => c,
117            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Self::default(),
118            Err(e) => {
119                eprintln!("warning: cannot read config {}: {e}", path.display());
120                return Self::default();
121            }
122        };
123        match toml::from_str(&content) {
124            Ok(c) => c,
125            Err(e) => {
126                eprintln!("warning: malformed config at {}: {e}", path.display());
127                Self::default()
128            }
129        }
130    }
131
132    /// Resolve session settings for a given host (or None for local).
133    pub fn resolve_session(&self, host: Option<&str>) -> SessionSettings {
134        let d = &self.defaults;
135        let h = host.and_then(|name| self.host.get(name));
136
137        SessionSettings {
138            forward_agent: pick(h.and_then(|h| h.forward_agent), d.forward_agent),
139            forward_open: pick(h.and_then(|h| h.forward_open), d.forward_open),
140            no_escape: pick(h.and_then(|h| h.no_escape), d.no_escape),
141            no_redraw: pick(h.and_then(|h| h.no_redraw), d.no_redraw),
142            oauth_redirect: h.and_then(|h| h.oauth_redirect).or(d.oauth_redirect).unwrap_or(true),
143            oauth_timeout: h.and_then(|h| h.oauth_timeout).or(d.oauth_timeout).unwrap_or(180),
144            heartbeat_interval: h
145                .and_then(|h| h.heartbeat_interval)
146                .or(d.heartbeat_interval)
147                .unwrap_or(5),
148            heartbeat_timeout: h
149                .and_then(|h| h.heartbeat_timeout)
150                .or(d.heartbeat_timeout)
151                .unwrap_or(15),
152            ring_buffer_size: h
153                .and_then(|h| h.ring_buffer_size)
154                .or(d.ring_buffer_size)
155                .unwrap_or(1 << 20),
156            oauth_tunnel_idle_timeout: h
157                .and_then(|h| h.oauth_tunnel_idle_timeout)
158                .or(d.oauth_tunnel_idle_timeout)
159                .unwrap_or(5),
160        }
161    }
162
163    /// Resolve connect settings for a given host.
164    pub fn resolve_connect(&self, host: &str) -> ConnectSettings {
165        let d = &self.defaults;
166        let dc = d.connect.as_ref();
167        let h = self.host.get(host);
168        let hc = h.and_then(|h| h.connect.as_ref());
169
170        // ssh-options: host-specific first, then defaults (SSH uses first-match)
171        let mut ssh_options = Vec::new();
172        if let Some(opts) = hc.and_then(|c| c.ssh_options.as_ref()) {
173            ssh_options.extend(opts.iter().cloned());
174        }
175        if let Some(opts) = dc.and_then(|c| c.ssh_options.as_ref()) {
176            ssh_options.extend(opts.iter().cloned());
177        }
178
179        ConnectSettings {
180            session: self.resolve_session(Some(host)),
181            ssh_options,
182            no_server_start: pick(
183                hc.and_then(|c| c.no_server_start),
184                dc.and_then(|c| c.no_server_start),
185            ),
186        }
187    }
188}
189
/// Resolve a layered boolean flag.
///
/// The host value wins if present. Otherwise the default value is used. If neither is
/// set, the result is `false`.
fn pick(host_val: Option<bool>, default_val: Option<bool>) -> bool {
    let most_specific = host_val.or(default_val);
    matches!(most_specific, Some(true))
}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an inline TOML snippet, panicking if the test input itself is invalid.
    fn parse(src: &str) -> ConfigFile {
        toml::from_str(src).expect("test config must be valid")
    }

    #[test]
    fn empty_config_returns_defaults() {
        let s = parse("").resolve_session(None);
        assert_eq!(s, SessionSettings::default());
    }

    #[test]
    fn default_config_template_parses() {
        // The embedded template must deserialize into `ConfigFile`.
        let _: ConfigFile =
            toml::from_str(DEFAULT_CONFIG).expect("embedded DEFAULT_CONFIG must parse");
    }

    #[test]
    fn defaults_apply_when_no_host() {
        let cfg = parse(
            r#"
            [defaults]
            forward-agent = true
            forward-open = true
            "#,
        );
        let s = cfg.resolve_session(None);
        assert!(s.forward_agent);
        assert!(s.forward_open);
        assert!(!s.no_escape);
    }

    #[test]
    fn host_overrides_defaults() {
        let cfg = parse(
            r#"
            [defaults]
            forward-agent = true
            forward-open = false

            [host.devbox]
            forward-agent = false
            forward-open = true
            "#,
        );
        let s = cfg.resolve_session(Some("devbox"));
        assert!(!s.forward_agent);
        assert!(s.forward_open);
    }

    #[test]
    fn unknown_host_uses_defaults() {
        let cfg = parse(
            r#"
            [defaults]
            forward-agent = true

            [host.devbox]
            forward-open = true
            "#,
        );
        let s = cfg.resolve_session(Some("unknown"));
        assert!(s.forward_agent);
        assert!(!s.forward_open);
    }

    #[test]
    fn host_partial_override_inherits_defaults() {
        let cfg = parse(
            r#"
            [defaults]
            forward-agent = true
            no-escape = true

            [host.devbox]
            forward-open = true
            "#,
        );
        let s = cfg.resolve_session(Some("devbox"));
        assert!(s.forward_agent); // from defaults
        assert!(s.forward_open); // from host
        assert!(s.no_escape); // from defaults
    }

    #[test]
    fn configured_host_without_overrides_matches_builtin_defaults() {
        let cfg = parse("[host.devbox]\n");
        assert_eq!(cfg.resolve_session(Some("devbox")), SessionSettings::default());
    }

    #[test]
    fn heartbeat_and_ring_buffer_configurable() {
        let cfg = parse(
            r#"
            [defaults]
            heartbeat-interval = 10
            heartbeat-timeout = 30
            ring-buffer-size = 4096

            [host.devbox]
            heartbeat-timeout = 60
            "#,
        );
        let s = cfg.resolve_session(None);
        assert_eq!(s.heartbeat_interval, 10);
        assert_eq!(s.heartbeat_timeout, 30);
        assert_eq!(s.ring_buffer_size, 4096);

        let s = cfg.resolve_session(Some("devbox"));
        assert_eq!(s.heartbeat_interval, 10); // from defaults
        assert_eq!(s.heartbeat_timeout, 60); // from host
        assert_eq!(s.ring_buffer_size, 4096); // from defaults
    }

    #[test]
    fn connect_settings_merge_ssh_options() {
        let cfg = parse(
            r#"
            [defaults.connect]
            ssh-options = ["Compression=yes"]

            [host.devbox.connect]
            ssh-options = ["IdentityFile=~/.ssh/key"]
            "#,
        );
        let c = cfg.resolve_connect("devbox");
        // Host-specific first, then defaults
        assert_eq!(c.ssh_options, vec!["IdentityFile=~/.ssh/key", "Compression=yes"]);
    }

    #[test]
    fn connect_settings_no_host_ssh_options() {
        let cfg = parse(
            r#"
            [defaults.connect]
            ssh-options = ["Compression=yes"]
            "#,
        );
        let c = cfg.resolve_connect("unknown");
        assert_eq!(c.ssh_options, vec!["Compression=yes"]);
    }

    #[test]
    fn connect_no_server_start() {
        let cfg = parse(
            r#"
            [host.prod.connect]
            no-server-start = true
            "#,
        );
        let c = cfg.resolve_connect("prod");
        assert!(c.no_server_start);
        assert!(!cfg.resolve_connect("devbox").no_server_start);
    }

    #[test]
    fn connect_host_overrides_default_no_server_start() {
        let cfg = parse(
            r#"
            [defaults.connect]
            no-server-start = true

            [host.devbox.connect]
            no-server-start = false
            "#,
        );
        assert!(!cfg.resolve_connect("devbox").no_server_start);
        assert!(cfg.resolve_connect("other").no_server_start);
    }

    #[test]
    fn missing_file_returns_default() {
        let cfg = ConfigFile::load_from(std::path::Path::new("/nonexistent/config.toml"));
        assert_eq!(cfg.resolve_session(None), SessionSettings::default());
    }

    #[test]
    fn malformed_file_returns_default() {
        // Unique per process so parallel test runs do not collide.
        let path = std::env::temp_dir()
            .join(format!("gritty-malformed-config-{}.toml", std::process::id()));
        std::fs::write(&path, "this is = = not toml").unwrap();
        let cfg = ConfigFile::load_from(&path);
        let _ = std::fs::remove_file(&path);
        assert_eq!(cfg.resolve_session(None), SessionSettings::default());
    }

    #[test]
    fn config_path_ends_with_expected_suffix() {
        // Can't safely set env vars in tests (Rust 2024), but we can verify the
        // function returns a path ending in gritty/config.toml
        let p = config_path();
        assert!(p.ends_with("gritty/config.toml"), "got: {}", p.display());
    }

    #[test]
    fn no_redraw_configurable() {
        let cfg = parse(
            r#"
            [defaults]
            no-redraw = true

            [host.devbox]
            no-redraw = false
            "#,
        );
        assert!(cfg.resolve_session(None).no_redraw);
        assert!(cfg.resolve_session(Some("unknown")).no_redraw);
        assert!(!cfg.resolve_session(Some("devbox")).no_redraw);
    }

    #[test]
    fn oauth_settings_defaults() {
        let s = parse("").resolve_session(None);
        assert!(s.oauth_redirect);
        assert_eq!(s.oauth_timeout, 180);
    }

    #[test]
    fn oauth_settings_configurable() {
        let cfg = parse(
            r#"
            [defaults]
            oauth-redirect = false
            oauth-timeout = 60

            [host.devbox]
            oauth-redirect = true
            oauth-timeout = 300
            "#,
        );
        let s = cfg.resolve_session(None);
        assert!(!s.oauth_redirect);
        assert_eq!(s.oauth_timeout, 60);

        let s = cfg.resolve_session(Some("devbox"));
        assert!(s.oauth_redirect);
        assert_eq!(s.oauth_timeout, 300);
    }

    #[test]
    fn oauth_settings_host_partial_override() {
        let cfg = parse(
            r#"
            [defaults]
            oauth-timeout = 90

            [host.devbox]
            oauth-redirect = false
            "#,
        );
        let s = cfg.resolve_session(Some("devbox"));
        assert!(!s.oauth_redirect); // from host
        assert_eq!(s.oauth_timeout, 90); // from defaults
    }

    #[test]
    fn oauth_tunnel_idle_timeout_configurable() {
        let cfg = parse(
            r#"
            [defaults]
            oauth-tunnel-idle-timeout = 10

            [host.devbox]
            oauth-tunnel-idle-timeout = 30
            "#,
        );
        assert_eq!(cfg.resolve_session(None).oauth_tunnel_idle_timeout, 10);
        assert_eq!(cfg.resolve_session(Some("devbox")).oauth_tunnel_idle_timeout, 30);
        assert_eq!(cfg.resolve_session(Some("unknown")).oauth_tunnel_idle_timeout, 10);
    }

    #[test]
    fn unknown_keys_ignored() {
        let cfg = parse(
            r#"
            [defaults]
            forward-agent = true
            some-future-setting = "ignored"
            "#,
        );
        assert!(cfg.resolve_session(None).forward_agent);
    }

    #[test]
    fn connect_session_settings_resolved() {
        let cfg = parse(
            r#"
            [defaults]
            forward-agent = true

            [host.devbox]
            forward-open = true
            "#,
        );
        let c = cfg.resolve_connect("devbox");
        assert!(c.session.forward_agent);
        assert!(c.session.forward_open);
    }
}