tl_cli/config/
manager.rs

1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::PathBuf;
6
7use crate::paths;
8use crate::style;
9use crate::ui::Style;
10
11/// Default settings in the `[tl]` section of config.toml.
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
13pub struct TlConfig {
14    /// Default provider name.
15    pub provider: Option<String>,
16    /// Default model name.
17    pub model: Option<String>,
18    /// Default target language (ISO 639-1 code).
19    pub to: Option<String>,
20    /// Default translation style.
21    pub style: Option<String>,
22}
23
24/// Configuration for a translation provider.
25///
26/// Each provider has an endpoint and optional API key settings.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ProviderConfig {
29    /// The OpenAI-compatible API endpoint URL.
30    pub endpoint: String,
31    /// API key stored directly in config (not recommended).
32    #[serde(default)]
33    pub api_key: Option<String>,
34    /// Environment variable name containing the API key.
35    #[serde(default)]
36    pub api_key_env: Option<String>,
37    /// List of available models for this provider.
38    #[serde(default)]
39    pub models: Vec<String>,
40}
41
42impl ProviderConfig {
43    /// Gets the API key, preferring environment variable over config file.
44    pub fn get_api_key(&self) -> Option<String> {
45        if let Some(env_var) = &self.api_key_env
46            && let Ok(key) = std::env::var(env_var)
47            && !key.is_empty()
48        {
49            return Some(key);
50        }
51        self.api_key.clone()
52    }
53
54    /// Returns `true` if this provider requires an API key.
55    pub const fn requires_api_key(&self) -> bool {
56        self.api_key.is_some() || self.api_key_env.is_some()
57    }
58}
59
60/// A custom translation style defined by the user.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct CustomStyle {
63    /// Short description displayed in lists.
64    pub description: String,
65    /// The actual prompt sent to the LLM.
66    pub prompt: String,
67}
68
69/// The complete configuration file structure.
70///
71/// Corresponds to `~/.config/tl/config.toml`.
72#[derive(Debug, Clone, Default, Serialize, Deserialize)]
73pub struct ConfigFile {
74    /// Default settings.
75    #[serde(default)]
76    pub tl: TlConfig,
77    /// Provider configurations keyed by name.
78    #[serde(default)]
79    pub providers: HashMap<String, ProviderConfig>,
80    /// Custom translation styles keyed by name.
81    #[serde(default)]
82    pub styles: HashMap<String, CustomStyle>,
83}
84
/// Resolved configuration after merging CLI arguments and config file.
///
/// `Debug` is implemented by hand so the resolved API key is never printed
/// verbatim; only its presence (`Some("<redacted>")`) or absence (`None`) is shown.
#[derive(Clone)]
pub struct ResolvedConfig {
    /// The selected provider name.
    pub provider_name: String,
    /// The API endpoint URL.
    pub endpoint: String,
    /// The model to use for translation.
    pub model: String,
    /// The API key (if required).
    pub api_key: Option<String>,
    /// The target language code.
    pub target_language: String,
    /// The style name (for display).
    pub style_name: Option<String>,
    /// The resolved translation style prompt (for LLM).
    pub style_prompt: Option<String>,
}

impl std::fmt::Debug for ResolvedConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedConfig")
            .field("provider_name", &self.provider_name)
            .field("endpoint", &self.endpoint)
            .field("model", &self.model)
            // Secrets must not end up in logs or debug output.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("target_language", &self.target_language)
            .field("style_name", &self.style_name)
            .field("style_prompt", &self.style_prompt)
            .finish()
    }
}
103
/// Command-line overrides consulted by [`resolve_config`].
///
/// A `Some` value here wins over the matching entry in the config file;
/// `None` defers to the config file.
#[derive(Clone, Debug, Default)]
pub struct ResolveOptions {
    /// Overrides the target language code.
    pub to: Option<String>,
    /// Overrides the provider name.
    pub provider: Option<String>,
    /// Overrides the model name.
    pub model: Option<String>,
    /// Overrides the style name.
    pub style: Option<String>,
}
118
119/// Resolves configuration by merging CLI options with config file settings.
120///
121/// CLI options take precedence over config file values.
122///
123/// # Errors
124///
125/// Returns an error if required configuration (provider, model, target language)
126/// is missing or if the specified provider is not found.
127pub fn resolve_config(
128    options: &ResolveOptions,
129    config_file: &ConfigFile,
130) -> Result<ResolvedConfig> {
131    // Resolve provider
132    let provider_name = options
133        .provider
134        .as_ref()
135        .or(config_file.tl.provider.as_ref())
136        .cloned()
137        .ok_or_else(|| {
138            anyhow::anyhow!(
139                "Missing required configuration: 'provider'\n\n\
140                 Please provide it via:\n  \
141                 - CLI option: tl --provider <name>\n  \
142                 - Config file: ~/.config/tl/config.toml"
143            )
144        })?;
145
146    // Get provider config
147    let provider_config = config_file.providers.get(&provider_name).ok_or_else(|| {
148        let available: Vec<_> = config_file.providers.keys().collect();
149        if available.is_empty() {
150            anyhow::anyhow!(
151                "Provider '{provider_name}' not found\n\n\
152                 No providers configured. Add providers to ~/.config/tl/config.toml"
153            )
154        } else {
155            anyhow::anyhow!(
156                "Provider '{provider_name}' not found\n\n\
157                 Available providers:\n  \
158                 - {}\n\n\
159                 Add providers to ~/.config/tl/config.toml",
160                available
161                    .iter()
162                    .map(|s| s.as_str())
163                    .collect::<Vec<_>>()
164                    .join("\n  - ")
165            )
166        }
167    })?;
168
169    // Resolve model
170    let model = options
171        .model
172        .as_ref()
173        .or(config_file.tl.model.as_ref())
174        .cloned()
175        .ok_or_else(|| {
176            anyhow::anyhow!(
177                "Missing required configuration: 'model'\n\n\
178                 Please provide it via:\n  \
179                 - CLI option: tl --model <name>\n  \
180                 - Config file: ~/.config/tl/config.toml"
181            )
182        })?;
183
184    // Warn if model is not in provider's models list
185    if !provider_config.models.is_empty() && !provider_config.models.contains(&model) {
186        eprintln!(
187            "{} Model '{}' is not in the configured models list for '{}'\n\
188             Configured models: {}\n\
189             Proceeding anyway...\n",
190            Style::warning("Warning:"),
191            model,
192            provider_name,
193            provider_config.models.join(", ")
194        );
195    }
196
197    // Resolve target language
198    let target_language = options
199        .to
200        .as_ref()
201        .or(config_file.tl.to.as_ref())
202        .cloned()
203        .ok_or_else(|| {
204            anyhow::anyhow!(
205                "Missing required configuration: 'to' (target language)\n\n\
206                 Please provide it via:\n  \
207                 - CLI option: tl --to <lang>\n  \
208                 - Config file: ~/.config/tl/config.toml"
209            )
210        })?;
211
212    // Get API key
213    let api_key = provider_config.get_api_key();
214
215    // Check if API key is required but missing
216    if provider_config.requires_api_key() && api_key.is_none() {
217        let env_var = provider_config.api_key_env.as_deref().unwrap_or("API_KEY");
218        bail!(
219            "Provider '{provider_name}' requires an API key\n\n\
220             Set the {env_var} environment variable:\n  \
221             export {env_var}=\"your-api-key\"\n\n\
222             Or set api_key in ~/.config/tl/config.toml"
223        );
224    }
225
226    // Resolve style (optional)
227    let style_key = options.style.as_ref().or(config_file.tl.style.as_ref());
228
229    let (style_name, style_prompt) = if let Some(key) = style_key {
230        let resolved =
231            style::resolve_style(key, &config_file.styles).map_err(|e| anyhow::anyhow!("{e}"))?;
232        (Some(key.clone()), Some(resolved.prompt().to_string()))
233    } else {
234        (None, None)
235    };
236
237    Ok(ResolvedConfig {
238        provider_name,
239        endpoint: provider_config.endpoint.clone(),
240        model,
241        api_key,
242        target_language,
243        style_name,
244        style_prompt,
245    })
246}
247
/// Reads and writes the tl `config.toml` file at a single, fixed location.
pub struct ConfigManager {
    /// Full path of the config file this manager operates on.
    config_path: PathBuf,
}
252
253impl ConfigManager {
254    /// Creates a new config manager.
255    ///
256    /// Configuration is stored at `$XDG_CONFIG_HOME/tl/config.toml`
257    /// or `~/.config/tl/config.toml` if `XDG_CONFIG_HOME` is not set.
258    pub fn new() -> Result<Self> {
259        Ok(Self {
260            config_path: paths::config_dir()?.join("config.toml"),
261        })
262    }
263
264    pub const fn config_path(&self) -> &PathBuf {
265        &self.config_path
266    }
267
268    pub fn load(&self) -> Result<ConfigFile> {
269        let contents = fs::read_to_string(&self.config_path).with_context(|| {
270            format!("Failed to read config file: {}", self.config_path.display())
271        })?;
272
273        let config_file: ConfigFile =
274            toml::from_str(&contents).with_context(|| "Failed to parse config file")?;
275
276        Ok(config_file)
277    }
278
279    pub fn save(&self, config: &ConfigFile) -> Result<()> {
280        if let Some(parent) = self.config_path.parent() {
281            fs::create_dir_all(parent).with_context(|| {
282                format!("Failed to create config directory: {}", parent.display())
283            })?;
284        }
285
286        let contents = toml::to_string_pretty(config).context("Failed to serialize config")?;
287
288        fs::write(&self.config_path, contents).with_context(|| {
289            format!(
290                "Failed to write config file: {}",
291                self.config_path.display()
292            )
293        })?;
294
295        Ok(())
296    }
297
298    /// Loads the config file, returning defaults only if the file doesn't exist.
299    ///
300    /// Returns an error if the config file exists but cannot be parsed.
301    /// This prevents accidental overwrite of a malformed config file when saving.
302    pub fn load_or_default(&self) -> Result<ConfigFile> {
303        match fs::read_to_string(&self.config_path) {
304            Ok(contents) => {
305                toml::from_str(&contents).with_context(|| "Failed to parse config file")
306            }
307            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(ConfigFile::default()),
308            Err(e) => Err(anyhow::anyhow!(
309                "Failed to read config file: {}: {}",
310                self.config_path.display(),
311                e
312            )),
313        }
314    }
315}
316
#[cfg(test)]
// Unwrapping is the intended failure mode in tests: a panic fails the test with context.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn create_test_manager(temp_dir: &TempDir) -> ConfigManager {
        ConfigManager {
            config_path: temp_dir.path().join("config.toml"),
        }
    }

    /// Returns the name and value of an environment variable that is already set to a
    /// non-empty, valid-Unicode value, if there is one.
    ///
    /// The tests read existing variables instead of calling `std::env::set_var` /
    /// `remove_var`. The test harness runs tests on parallel threads, and mutating the
    /// environment while another thread reads it is a data race. That race is why those
    /// functions are `unsafe`.
    fn existing_env_var() -> Option<(String, String)> {
        std::env::vars_os().find_map(|(key, value)| {
            let key = key.into_string().ok()?;
            let value = value.into_string().ok()?;
            // Names containing '=' or NUL (e.g. Windows' hidden "=C:" entries) can't be
            // looked up reliably through `std::env::var`.
            let usable = !key.is_empty() && !key.contains(['=', '\0']) && !value.is_empty();
            usable.then_some((key, value))
        })
    }

    #[test]
    fn test_save_and_load_config() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let mut providers = HashMap::new();
        providers.insert(
            "ollama".to_string(),
            ProviderConfig {
                endpoint: "http://localhost:11434".to_string(),
                api_key: None,
                api_key_env: None,
                models: vec!["gemma3:12b".to_string(), "llama3.2".to_string()],
            },
        );

        let config = ConfigFile {
            tl: TlConfig {
                provider: Some("ollama".to_string()),
                model: Some("gemma3:12b".to_string()),
                to: Some("ja".to_string()),
                style: None,
            },
            providers,
            styles: HashMap::new(),
        };

        manager.save(&config).unwrap();
        let loaded = manager.load().unwrap();

        assert_eq!(loaded.tl.provider, Some("ollama".to_string()));
        assert_eq!(loaded.tl.model, Some("gemma3:12b".to_string()));
        assert_eq!(loaded.tl.to, Some("ja".to_string()));
        assert!(loaded.providers.contains_key("ollama"));
    }

    #[test]
    fn test_load_nonexistent_config() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let result = manager.load();
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_get_api_key_from_env() {
        let Some((env_name, env_value)) = existing_env_var() else {
            // No suitable variable to read in this environment; the fallback path is
            // covered by `test_provider_get_api_key_fallback`.
            return;
        };

        let provider = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("fallback-key".to_string()),
            api_key_env: Some(env_name),
            models: vec![],
        };

        // Environment variable takes priority
        assert_eq!(provider.get_api_key(), Some(env_value));
    }

    #[test]
    fn test_provider_get_api_key_fallback() {
        // A name no real environment defines. It is checked rather than removed so this
        // test never mutates process-wide state.
        const UNSET_VAR: &str = "TL_TEST_UNSET_API_KEY_FOR_FALLBACK";
        assert!(
            std::env::var_os(UNSET_VAR).is_none(),
            "{UNSET_VAR} must not be set when running tests"
        );

        let provider = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("fallback-key".to_string()),
            api_key_env: Some(UNSET_VAR.to_string()),
            models: vec![],
        };

        // Falls back to api_key when env var not set
        assert_eq!(provider.get_api_key(), Some("fallback-key".to_string()));
    }

    #[test]
    fn test_provider_requires_api_key() {
        let provider_with_key = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("key".to_string()),
            api_key_env: None,
            models: vec![],
        };
        assert!(provider_with_key.requires_api_key());

        let provider_with_env = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: None,
            api_key_env: Some("API_KEY".to_string()),
            models: vec![],
        };
        assert!(provider_with_env.requires_api_key());

        let provider_without = ProviderConfig {
            endpoint: "http://localhost:11434".to_string(),
            api_key: None,
            api_key_env: None,
            models: vec![],
        };
        assert!(!provider_without.requires_api_key());
    }

    // resolve_config tests

    fn create_test_options() -> ResolveOptions {
        ResolveOptions {
            to: Some("ja".to_string()),
            provider: Some("ollama".to_string()),
            model: Some("gemma3:12b".to_string()),
            style: None,
        }
    }

    fn create_test_config() -> ConfigFile {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama".to_string(),
            ProviderConfig {
                endpoint: "http://localhost:11434".to_string(),
                api_key: None,
                api_key_env: None,
                models: vec!["gemma3:12b".to_string()],
            },
        );
        providers.insert(
            "openrouter".to_string(),
            ProviderConfig {
                endpoint: "https://openrouter.ai/api".to_string(),
                api_key: None,
                api_key_env: Some("TL_TEST_NONEXISTENT_API_KEY".to_string()),
                models: vec!["gpt-4o".to_string()],
            },
        );

        ConfigFile {
            tl: TlConfig {
                provider: Some("ollama".to_string()),
                model: Some("gemma3:12b".to_string()),
                to: Some("ja".to_string()),
                style: None,
            },
            providers,
            styles: HashMap::new(),
        }
    }

    #[test]
    fn test_resolve_config_with_cli_options() {
        let options = create_test_options();
        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.provider_name, "ollama");
        assert_eq!(resolved.endpoint, "http://localhost:11434");
        assert_eq!(resolved.model, "gemma3:12b");
        assert_eq!(resolved.target_language, "ja");
        assert!(resolved.api_key.is_none());
    }

    #[test]
    fn test_resolve_config_cli_overrides_file() {
        let mut options = create_test_options();
        options.to = Some("en".to_string());
        options.model = Some("llama3".to_string());

        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.target_language, "en");
        assert_eq!(resolved.model, "llama3");
    }

    #[test]
    fn test_resolve_config_falls_back_to_file() {
        let options = ResolveOptions::default();
        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.provider_name, "ollama");
        assert_eq!(resolved.model, "gemma3:12b");
        assert_eq!(resolved.target_language, "ja");
    }

    #[test]
    fn test_resolve_config_missing_provider() {
        let options = ResolveOptions {
            to: Some("ja".to_string()),
            provider: None,
            model: Some("model".to_string()),
            style: None,
        };
        let config = ConfigFile::default();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("provider"));
    }

    #[test]
    fn test_resolve_config_provider_not_found() {
        let mut options = create_test_options();
        options.provider = Some("nonexistent".to_string());

        let config = create_test_config();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_resolve_config_missing_model() {
        let mut options = create_test_options();
        options.model = None;

        let mut config = create_test_config();
        config.tl.model = None;

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("model"));
    }

    #[test]
    fn test_resolve_config_missing_target_language() {
        let mut options = create_test_options();
        options.to = None;

        let mut config = create_test_config();
        config.tl.to = None;

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        // "to" alone would also match "config.toml" in any error; check the specific key.
        assert!(result.unwrap_err().to_string().contains("target language"));
    }

    #[test]
    fn test_resolve_config_api_key_required_but_missing() {
        let mut options = create_test_options();
        options.provider = Some("openrouter".to_string());

        let config = create_test_config();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key"));
    }

    #[test]
    fn test_load_or_default_nonexistent_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        // Should return default when file doesn't exist
        let result = manager.load_or_default();
        assert!(result.is_ok());
        let config = result.unwrap();
        assert!(config.providers.is_empty());
    }

    #[test]
    fn test_load_or_default_valid_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        // Create a valid config file
        let config = create_test_config();
        manager.save(&config).unwrap();

        // Should load the valid config
        let result = manager.load_or_default();
        assert!(result.is_ok());
        let loaded = result.unwrap();
        assert_eq!(loaded.tl.provider, Some("ollama".to_string()));
    }

    #[test]
    fn test_load_or_default_invalid_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        // Create an invalid TOML file
        std::fs::write(&manager.config_path, "invalid toml [[[").unwrap();

        // Should return error for malformed config
        let result = manager.load_or_default();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("parse"));
    }

    #[test]
    #[cfg(unix)]
    fn test_load_or_default_unreadable_file() {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        // Create a valid config file
        std::fs::write(&manager.config_path, "[tl]\nprovider = \"test\"").unwrap();

        // Make it unreadable (no read permissions)
        let mut perms = std::fs::metadata(&manager.config_path)
            .unwrap()
            .permissions();
        perms.set_mode(0o000);
        std::fs::set_permissions(&manager.config_path, perms).unwrap();

        // Privileged users (e.g. root in CI containers) bypass permission bits, so the
        // file may still be readable and the scenario cannot be reproduced.
        let bypasses_permissions = std::fs::read(&manager.config_path).is_ok();

        let result = manager.load_or_default();

        // Restore permissions before asserting so cleanup happens even on failure.
        let mut perms = std::fs::metadata(&manager.config_path)
            .unwrap()
            .permissions();
        perms.set_mode(0o644);
        std::fs::set_permissions(&manager.config_path, perms).unwrap();

        // Should return error for unreadable file
        if !bypasses_permissions {
            assert!(result.is_err());
        }
    }
}