Skip to main content

llm_cascade/
config.rs

1//! Configuration loading and types for `config.toml`.
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use serde::{Deserialize, Serialize};
7
8/// Top-level application configuration.
9#[derive(Debug, Deserialize, Serialize)]
10pub struct AppConfig {
11    /// Named provider definitions (keyed by provider name).
12    #[serde(default)]
13    pub providers: HashMap<String, ProviderConfig>,
14    /// Named cascade definitions (keyed by cascade name).
15    #[serde(default)]
16    pub cascades: HashMap<String, CascadeConfig>,
17    /// SQLite database configuration.
18    #[serde(default)]
19    pub database: DatabaseConfig,
20    /// Failed prompt persistence configuration.
21    #[serde(default)]
22    pub failure_persistence: FailureConfig,
23}
24
/// Configuration for a single LLM provider endpoint.
///
/// Carries the protocol type plus two optional API-key lookup names and an
/// optional base-URL override. All key fields are optional in the TOML;
/// which lookup wins when both are set is decided by the caller — TODO
/// confirm the resolution order against the key-loading code.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ProviderConfig {
    /// The provider protocol type (`r#type` because `type` is a keyword;
    /// serializes as plain `type` in TOML).
    pub r#type: ProviderType,
    /// Keyring service name for API key lookup.
    #[serde(default)]
    pub api_key_service: Option<String>,
    /// Environment variable name for API key fallback.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Override the default base URL for this provider (e.g. proxies or
    /// OpenAI-compatible third-party endpoints).
    #[serde(default)]
    pub base_url: Option<String>,
}
40
41/// Supported provider protocol types.
42#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
43#[serde(rename_all = "lowercase")]
44pub enum ProviderType {
45    /// OpenAI Chat Completions API (and compatible endpoints like Groq, Together, vLLM, Ollama Cloud).
46    Openai,
47    /// Anthropic Messages API.
48    Anthropic,
49    /// Google Gemini generateContent API.
50    Gemini,
51    /// Ollama local inference server.
52    Ollama,
53}
54
/// A single entry in a cascade, referencing a provider and model.
///
/// The pair is resolved at run time: `provider` is a lookup key into
/// `AppConfig::providers`, and `model` is passed through verbatim to that
/// provider's API.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CascadeEntry {
    /// The provider name (must match a key in `[providers]`); not validated
    /// here — an unknown name surfaces when the cascade is resolved.
    pub provider: String,
    /// The model identifier to use with this provider.
    pub model: String,
}
63
64/// An ordered list of provider/model entries to try in sequence.
65#[derive(Debug, Deserialize, Serialize)]
66pub struct CascadeConfig {
67    /// The cascade entries, tried in order until one succeeds.
68    pub entries: Vec<CascadeEntry>,
69}
70
/// SQLite database path configuration.
#[derive(Debug, Deserialize, Serialize)]
pub struct DatabaseConfig {
    /// Path to the SQLite database file. Tilde (`~`) is expanded — by the
    /// consumer via `expand_tilde`, not at parse time; the raw string is
    /// stored here.
    #[serde(default = "default_db_path")]
    pub path: String,
}
78
/// Default on-disk location for the SQLite database (XDG data dir layout;
/// the leading `~` is expanded later by the consumer).
fn default_db_path() -> String {
    String::from("~/.local/share/llm-cascade/db.sqlite")
}
82
83impl Default for DatabaseConfig {
84    fn default() -> Self {
85        Self {
86            path: default_db_path(),
87        }
88    }
89}
90
/// Failed prompt persistence directory configuration.
#[derive(Debug, Deserialize, Serialize)]
pub struct FailureConfig {
    /// Directory where failed conversation `.json` files are saved. Tilde is
    /// expanded by the consumer via `expand_tilde`; the raw string is stored
    /// here.
    #[serde(default = "default_failure_dir")]
    pub dir: String,
}
98
/// Default directory for failed-prompt JSON files (XDG data dir layout;
/// the leading `~` is expanded later by the consumer).
fn default_failure_dir() -> String {
    String::from("~/.local/share/llm-cascade/failed_prompts")
}
102
103impl Default for FailureConfig {
104    fn default() -> Self {
105        Self {
106            dir: default_failure_dir(),
107        }
108    }
109}
110
/// Expands a leading `~` or `~/` to the user's home directory.
///
/// `"~"` alone expands to the home directory itself (the original code
/// missed this case and returned it literally); `"~/rest"` expands to
/// `<home>/rest`. Any other path — absolute, relative, or `~user/...`
/// forms — is returned unchanged, as is everything when `$HOME` is unset.
pub fn expand_tilde(path: &str) -> PathBuf {
    if path == "~" {
        if let Some(home) = dirs_home() {
            return home;
        }
    } else if let Some(rest) = path.strip_prefix("~/") {
        if let Some(home) = dirs_home() {
            return home.join(rest);
        }
    }
    PathBuf::from(path)
}

/// Reads the home directory from the `HOME` environment variable
/// (Unix convention; returns `None` when unset or not valid Unicode).
fn dirs_home() -> Option<PathBuf> {
    std::env::var("HOME").ok().map(PathBuf::from)
}
124
125/// Reads and parses a TOML configuration file.
126pub fn load_config(path: &Path) -> Result<AppConfig, String> {
127    let content = std::fs::read_to_string(path)
128        .map_err(|e| format!("Failed to read config file '{}': {}", path.display(), e))?;
129    let config: AppConfig = toml::from_str(&content)
130        .map_err(|e| format!("Failed to parse config file '{}': {}", path.display(), e))?;
131    Ok(config)
132}
133
134/// Returns the default configuration path (`$XDG_CONFIG_HOME/llm-cascade/config.toml`).
135pub fn default_config_path() -> PathBuf {
136    let config_dir = std::env::var("XDG_CONFIG_HOME")
137        .map(PathBuf::from)
138        .unwrap_or_else(|_| expand_tilde("~/.config"));
139    config_dir.join("llm-cascade").join("config.toml")
140}