Skip to main content

tuillem_config/
lib.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use directories::ProjectDirs;
5use serde::{Deserialize, Serialize};
6
7// ---------------------------------------------------------------------------
8// Errors
9// ---------------------------------------------------------------------------
10
11#[derive(Debug, thiserror::Error)]
12pub enum ConfigError {
13    #[error("IO error: {0}")]
14    Io(#[from] std::io::Error),
15
16    #[error("YAML parse error: {0}")]
17    Parse(#[from] serde_yaml::Error),
18
19    #[error("Validation error: {0}")]
20    Validation(String),
21}
22
23// ---------------------------------------------------------------------------
24// Enums
25// ---------------------------------------------------------------------------
26
27#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum KeybindingPreset {
30    Vim,
31    Emacs,
32    #[default]
33    Default,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37#[serde(rename_all = "lowercase")]
38pub enum ProviderType {
39    Anthropic,
40    Openai,
41    Openrouter,
42    Ollama,
43}
44
45// ---------------------------------------------------------------------------
46// ThemeColors
47// ---------------------------------------------------------------------------
48
49#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
50pub struct ThemeColors {
51    pub bg: Option<String>,
52    pub fg: Option<String>,
53    pub sidebar_bg: Option<String>,
54    pub sidebar_fg: Option<String>,
55    pub sidebar_selected: Option<String>,
56    pub user_msg_bg: Option<String>,
57    pub assistant_msg_bg: Option<String>,
58    pub thinking_fg: Option<String>,
59    pub accent: Option<String>,
60    pub error: Option<String>,
61    pub success: Option<String>,
62    pub warning: Option<String>,
63    pub border: Option<String>,
64    pub code_bg: Option<String>,
65    pub code_fg: Option<String>,
66    pub heading: Option<String>,
67    pub link: Option<String>,
68    pub tag: Option<String>,
69    pub sidebar_selected_bg: Option<String>,
70}
71
72// ---------------------------------------------------------------------------
73// ProviderConfig
74// ---------------------------------------------------------------------------
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct ProviderConfig {
78    pub name: String,
79    pub provider_type: ProviderType,
80    pub api_key: Option<String>,
81    pub base_url: Option<String>,
82    pub default_model: Option<String>,
83    #[serde(default)]
84    pub models: Vec<String>,
85}
86
87// ---------------------------------------------------------------------------
88// DefaultsConfig
89// ---------------------------------------------------------------------------
90
91#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
92pub struct DefaultsConfig {
93    pub provider: Option<String>,
94    pub model: Option<String>,
95    pub system_prompt: Option<String>,
96}
97
98// ---------------------------------------------------------------------------
99// ToolConfig
100// ---------------------------------------------------------------------------
101
102#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
103pub struct ToolConfig {
104    pub name: String,
105    pub description: String,
106    pub command: String,
107    pub input_schema: Option<serde_json::Value>,
108    #[serde(default = "default_timeout")]
109    pub timeout: String,
110    #[serde(default)]
111    pub confirm: bool,
112    #[serde(default)]
113    pub env: HashMap<String, String>,
114}
115
116fn default_timeout() -> String {
117    "30s".to_string()
118}
119
120// ---------------------------------------------------------------------------
121// DatabaseConfig
122// ---------------------------------------------------------------------------
123
124#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
125pub struct DatabaseConfig {
126    #[serde(default = "default_database_path")]
127    pub path: String,
128}
129
130impl Default for DatabaseConfig {
131    fn default() -> Self {
132        Self {
133            path: default_database_path(),
134        }
135    }
136}
137
138fn default_database_path() -> String {
139    ProjectDirs::from("com", "tuillem", "tuillem")
140        .map(|dirs| {
141            dirs.data_dir()
142                .join("tuillem.db")
143                .to_string_lossy()
144                .into_owned()
145        })
146        .unwrap_or_else(|| "tuillem.db".to_string())
147}
148
149// ---------------------------------------------------------------------------
150// UiConfig
151// ---------------------------------------------------------------------------
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154pub struct UiConfig {
155    #[serde(default = "default_sidebar_width")]
156    pub sidebar_width: u16,
157    #[serde(default)]
158    pub show_thinking: bool,
159    #[serde(default = "default_true")]
160    pub show_token_usage: bool,
161    #[serde(default = "default_true")]
162    pub mouse: bool,
163    #[serde(default)]
164    pub show_stats: bool,
165    #[serde(default = "default_layout")]
166    pub layout: String,
167    #[serde(default = "default_date_format")]
168    pub date_format: String,
169    #[serde(default = "default_scroll_lines")]
170    pub scroll_lines: u16,
171    #[serde(default = "default_command_prefix")]
172    pub command_prefix: String,
173    #[serde(default = "default_true")]
174    pub nerd_fonts: bool,
175    #[serde(default = "default_color_mode")]
176    pub color_mode: String,
177    #[serde(default = "default_stream_visible_lines")]
178    pub stream_visible_lines: u16,
179}
180
181fn default_stream_visible_lines() -> u16 {
182    10
183}
184
185fn default_color_mode() -> String {
186    "auto".to_string()
187}
188
189fn default_command_prefix() -> String {
190    "/".to_string()
191}
192
193fn default_scroll_lines() -> u16 {
194    5
195}
196
197fn default_date_format() -> String {
198    "dd/mm/yyyy".to_string()
199}
200
201impl Default for UiConfig {
202    fn default() -> Self {
203        Self {
204            sidebar_width: 30,
205            show_thinking: false,
206            show_token_usage: true,
207            mouse: true,
208            show_stats: false,
209            layout: default_layout(),
210            date_format: default_date_format(),
211            scroll_lines: default_scroll_lines(),
212            command_prefix: default_command_prefix(),
213            nerd_fonts: true,
214            color_mode: default_color_mode(),
215            stream_visible_lines: default_stream_visible_lines(),
216        }
217    }
218}
219
220fn default_layout() -> String {
221    "loose".to_string()
222}
223
224fn default_sidebar_width() -> u16 {
225    30
226}
227
228fn default_true() -> bool {
229    true
230}
231
232// ---------------------------------------------------------------------------
233// Config (top-level)
234// ---------------------------------------------------------------------------
235
236fn default_editor() -> String {
237    std::env::var("VISUAL")
238        .or_else(|_| std::env::var("EDITOR"))
239        .unwrap_or_else(|_| "vi".to_string())
240}
241
242fn default_theme() -> String {
243    "dark".to_string()
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct Config {
248    #[serde(default = "default_editor")]
249    pub editor: String,
250
251    #[serde(default)]
252    pub keybindings: KeybindingPreset,
253
254    #[serde(default = "default_theme")]
255    pub theme: String,
256
257    #[serde(default)]
258    pub themes: HashMap<String, ThemeColors>,
259
260    #[serde(default)]
261    pub providers: Vec<ProviderConfig>,
262
263    #[serde(default)]
264    pub defaults: DefaultsConfig,
265
266    #[serde(default)]
267    pub tools: Vec<ToolConfig>,
268
269    #[serde(default)]
270    pub database: DatabaseConfig,
271
272    #[serde(default)]
273    pub ui: UiConfig,
274}
275
276impl Default for Config {
277    fn default() -> Self {
278        Self {
279            editor: default_editor(),
280            keybindings: KeybindingPreset::Default,
281            theme: "dark".to_string(),
282            themes: HashMap::new(),
283            providers: Vec::new(),
284            defaults: DefaultsConfig::default(),
285            tools: Vec::new(),
286            database: DatabaseConfig::default(),
287            ui: UiConfig::default(),
288        }
289    }
290}
291
292impl Config {
293    /// Parse a YAML string into a `Config`, then validate it.
294    /// Expands `${VAR}` patterns from environment variables before parsing.
295    pub fn from_yaml(yaml: &str) -> Result<Config, ConfigError> {
296        let expanded = expand_env_vars(yaml);
297        let config: Config = serde_yaml::from_str(&expanded)?;
298        config.validate()?;
299        Ok(config)
300    }
301
302    /// Read a file and parse it as YAML config.
303    pub fn from_file(path: &Path) -> Result<Config, ConfigError> {
304        let contents = std::fs::read_to_string(path)?;
305        Self::from_yaml(&contents)
306    }
307
308    /// Return the default XDG config path for the config file.
309    pub fn default_path() -> PathBuf {
310        ProjectDirs::from("com", "tuillem", "tuillem")
311            .map(|dirs| dirs.config_dir().join("config.yaml"))
312            .unwrap_or_else(|| PathBuf::from("config.yaml"))
313    }
314
315    /// Validate the configuration.
316    pub fn validate(&self) -> Result<(), ConfigError> {
317        // API-based providers must have an api_key.
318        for provider in &self.providers {
319            let needs_key = matches!(
320                provider.provider_type,
321                ProviderType::Anthropic | ProviderType::Openai | ProviderType::Openrouter
322            );
323            if needs_key && provider.api_key.is_none() {
324                return Err(ConfigError::Validation(format!(
325                    "Provider '{}' requires an api_key",
326                    provider.name
327                )));
328            }
329        }
330
331        // Default provider must exist in the providers list.
332        if let Some(ref default_provider) = self.defaults.provider {
333            let exists = self.providers.iter().any(|p| &p.name == default_provider);
334            if !exists {
335                return Err(ConfigError::Validation(format!(
336                    "Default provider '{}' not found in providers list",
337                    default_provider
338                )));
339            }
340        }
341
342        Ok(())
343    }
344}
345
346pub fn version() -> &'static str {
347    env!("CARGO_PKG_VERSION")
348}
349
350// ---------------------------------------------------------------------------
351// Tests
352// ---------------------------------------------------------------------------
353
354/// Expand `${VAR}` and `${VAR:-default}` patterns from environment variables.
355/// Leaves the pattern as-is if the variable is not set and no default is given.
356fn expand_env_vars(input: &str) -> String {
357    let mut result = String::with_capacity(input.len());
358    let mut chars = input.chars().peekable();
359
360    while let Some(c) = chars.next() {
361        if c == '$' && chars.peek() == Some(&'{') {
362            chars.next(); // consume '{'
363            let mut var_expr = String::new();
364            let mut found_close = false;
365            for ch in chars.by_ref() {
366                if ch == '}' {
367                    found_close = true;
368                    break;
369                }
370                var_expr.push(ch);
371            }
372            if found_close {
373                // Check for default: ${VAR:-default}
374                let (var_name, default_val) = if let Some(pos) = var_expr.find(":-") {
375                    (&var_expr[..pos], Some(&var_expr[pos + 2..]))
376                } else {
377                    (var_expr.as_str(), None)
378                };
379
380                match std::env::var(var_name) {
381                    Ok(val) if !val.is_empty() => result.push_str(&val),
382                    _ => {
383                        if let Some(def) = default_val {
384                            result.push_str(def);
385                        } else {
386                            // Leave unexpanded so the user sees it's not set
387                            result.push_str(&format!("${{{}}}", var_expr));
388                        }
389                    }
390                }
391            } else {
392                // Unclosed ${, write it literally
393                result.push('$');
394                result.push('{');
395                result.push_str(&var_expr);
396            }
397        } else {
398            result.push(c);
399        }
400    }
401    result
402}
403
404#[cfg(test)]
405mod tests {
406    use super::*;
407
408    #[test]
409    fn test_minimal_config() {
410        let config = Config::from_yaml("{}").expect("should parse empty config");
411        assert_eq!(config.theme, "dark");
412        assert_eq!(config.keybindings, KeybindingPreset::Default);
413        assert!(config.providers.is_empty());
414        assert!(config.tools.is_empty());
415        assert_eq!(config.ui.sidebar_width, 30);
416        assert!(!config.ui.show_thinking);
417        assert!(config.ui.show_token_usage);
418        assert!(config.ui.mouse);
419        assert_eq!(config.ui.layout, "loose");
420    }
421
422    #[test]
423    fn test_full_config() {
424        let yaml = r##"
425editor: nvim
426keybindings: vim
427theme: dark
428themes:
429  dark:
430    bg: "#1e1e2e"
431    fg: "#cdd6f4"
432providers:
433  - name: anthropic
434    provider_type: anthropic
435    api_key: "sk-ant-test"
436    default_model: claude-sonnet-4-20250514
437    models:
438      - claude-sonnet-4-20250514
439      - claude-3-haiku-20240307
440  - name: local
441    provider_type: ollama
442    base_url: "http://localhost:11434"
443    models:
444      - llama3
445defaults:
446  provider: anthropic
447  model: claude-sonnet-4-20250514
448  system_prompt: "You are a helpful assistant."
449tools:
450  - name: grep_tool
451    description: "Search files"
452    command: "grep -rn"
453    timeout: "10s"
454    confirm: true
455    env:
456      LANG: "en_US.UTF-8"
457database:
458  path: "/tmp/test.db"
459ui:
460  sidebar_width: 40
461  show_thinking: true
462  show_token_usage: false
463  mouse: false
464"##;
465        let config = Config::from_yaml(yaml).expect("should parse full config");
466        assert_eq!(config.editor, "nvim");
467        assert_eq!(config.keybindings, KeybindingPreset::Vim);
468        assert_eq!(config.theme, "dark");
469        assert_eq!(config.providers.len(), 2);
470        assert_eq!(config.providers[0].name, "anthropic");
471        assert_eq!(config.providers[0].api_key.as_deref(), Some("sk-ant-test"));
472        assert_eq!(config.providers[1].provider_type, ProviderType::Ollama);
473        assert_eq!(config.defaults.provider.as_deref(), Some("anthropic"));
474        assert_eq!(config.tools.len(), 1);
475        assert_eq!(config.tools[0].timeout, "10s");
476        assert!(config.tools[0].confirm);
477        assert_eq!(config.database.path, "/tmp/test.db");
478        assert_eq!(config.ui.sidebar_width, 40);
479        assert!(config.ui.show_thinking);
480        assert!(!config.ui.show_token_usage);
481        assert!(!config.ui.mouse);
482
483        // Theme check
484        let dark = config.themes.get("dark").expect("dark theme should exist");
485        assert_eq!(dark.bg.as_deref(), Some("#1e1e2e"));
486    }
487
488    #[test]
489    fn test_validation_missing_api_key() {
490        let yaml = "
491providers:
492  - name: anthropic
493    provider_type: anthropic
494";
495        let result = Config::from_yaml(yaml);
496        assert!(result.is_err());
497        let err = result.unwrap_err().to_string();
498        assert!(
499            err.contains("requires an api_key"),
500            "Expected api_key error, got: {err}",
501        );
502    }
503
504    #[test]
505    fn test_validation_invalid_default_provider() {
506        let yaml = "
507providers:
508  - name: anthropic
509    provider_type: anthropic
510    api_key: sk-test
511defaults:
512  provider: nonexistent
513";
514        let result = Config::from_yaml(yaml);
515        assert!(result.is_err());
516        let err = result.unwrap_err().to_string();
517        assert!(
518            err.contains("not found in providers list"),
519            "Expected provider-not-found error, got: {err}",
520        );
521    }
522
523    #[test]
524    fn test_ollama_no_api_key_required() {
525        let yaml = "
526providers:
527  - name: local
528    provider_type: ollama
529    base_url: http://localhost:11434
530    models:
531      - llama3
532";
533        let config = Config::from_yaml(yaml).expect("ollama should not require api_key");
534        assert_eq!(config.providers.len(), 1);
535        assert_eq!(config.providers[0].provider_type, ProviderType::Ollama);
536        assert!(config.providers[0].api_key.is_none());
537    }
538
539    #[test]
540    fn test_env_var_expansion() {
541        // SAFETY: test runs single-threaded
542        unsafe { std::env::set_var("TUILLEM_TEST_KEY", "sk-test-12345") };
543        let result = expand_env_vars("api_key: ${TUILLEM_TEST_KEY}");
544        assert_eq!(result, "api_key: sk-test-12345");
545        unsafe { std::env::remove_var("TUILLEM_TEST_KEY") };
546    }
547
548    #[test]
549    fn test_env_var_default() {
550        unsafe { std::env::remove_var("TUILLEM_UNSET_VAR") };
551        let result = expand_env_vars("key: ${TUILLEM_UNSET_VAR:-fallback_value}");
552        assert_eq!(result, "key: fallback_value");
553    }
554
555    #[test]
556    fn test_env_var_unset_no_default() {
557        unsafe { std::env::remove_var("TUILLEM_MISSING") };
558        let result = expand_env_vars("key: ${TUILLEM_MISSING}");
559        assert_eq!(result, "key: ${TUILLEM_MISSING}");
560    }
561
562    #[test]
563    fn test_env_var_in_config() {
564        unsafe { std::env::set_var("TUILLEM_TEST_API", "sk-ant-real-key") };
565        let yaml = r#"
566providers:
567  - name: anthropic
568    provider_type: anthropic
569    api_key: "${TUILLEM_TEST_API}"
570    models:
571      - claude-sonnet-4-20250514
572"#;
573        let config = Config::from_yaml(yaml).unwrap();
574        assert_eq!(
575            config.providers[0].api_key.as_deref(),
576            Some("sk-ant-real-key")
577        );
578        unsafe { std::env::remove_var("TUILLEM_TEST_API") };
579    }
580}