nabla_cli/cli/
config.rs

use anyhow::{Context, Result, anyhow};
use clap::Subcommand;
use home;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
7
/// A configured LLM backend.
///
/// Instances are stored in `LLMProvidersConfig::providers` and persisted as JSON in
/// `~/.nabla/llm_providers.json`.
///
/// NOTE(review): there is no `Debug` derive. If one is added, redact `api_key` so the
/// secret cannot leak into logs.
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct LLMProvider {
    /// Provider name. `LLMProvidersConfig::add_provider` also uses it as the map key.
    pub name: String,
    /// Provider kind: "openai", "groq", "together", "local".
    pub provider_type: String,
    /// Optional API key.
    /// NOTE(review): it is serialized to disk as plain JSON together with the rest
    /// of the provider.
    pub api_key: Option<String>,
    /// Base URL of the provider's API.
    pub base_url: String,
    /// Optional model name.
    /// Presumably a provider-side default applies when `None`; TODO confirm.
    pub model: Option<String>,
    /// Whether this is the default provider.
    /// `LLMProvidersConfig::add_provider` and `set_default_provider` clear this flag on
    /// every other provider when setting it.
    pub default: bool,
}
17
18#[derive(Serialize, Deserialize, Default)]
19pub struct LLMProvidersConfig {
20    pub providers: HashMap<String, LLMProvider>,
21}
22
/// Generic key/value CLI settings, persisted as JSON in `~/.nabla/config.json`.
#[derive(Serialize, Deserialize, Default)]
pub struct ConfigStore {
    /// Maps setting name to value (e.g. `"base_url"`).
    /// `#[serde(default)]` lets a file without a `settings` key load as empty.
    #[serde(default)]
    settings: HashMap<String, String>,
}
28
29impl ConfigStore {
30    pub fn new() -> Result<Self> {
31        let config_path = Self::get_config_path()?;
32        if config_path.exists() {
33            let content = fs::read_to_string(&config_path)?;
34            let config: ConfigStore = serde_json::from_str(&content)?;
35            Ok(config)
36        } else {
37            Ok(ConfigStore::default())
38        }
39    }
40
41    pub fn get_base_url(&self) -> Result<String> {
42        // Try config setting first, then environment variable, then error
43        if let Some(url) = self.get_setting("base_url")? {
44            return Ok(url);
45        }
46
47        if let Ok(url) = std::env::var("NABLA_BASE_URL") {
48            return Ok(url);
49        }
50
51        Err(anyhow!(
52            "No base URL configured. Set with 'nabla config set-base-url <url>' or NABLA_BASE_URL env var"
53        ))
54    }
55
56    pub fn get_setting(&self, key: &str) -> Result<Option<String>> {
57        Ok(self.settings.get(key).cloned())
58    }
59
60    pub fn set_setting(&mut self, key: &str, value: &str) -> Result<()> {
61        self.settings.insert(key.to_string(), value.to_string());
62        self.save()
63    }
64
65    pub fn set_base_url(&mut self, url: &str) -> Result<()> {
66        // Validate URL format
67        if !url.starts_with("http://") && !url.starts_with("https://") {
68            return Err(anyhow!("Base URL must start with http:// or https://"));
69        }
70        self.set_setting("base_url", url)?;
71        println!("✅ Base URL set to: {}", url);
72        Ok(())
73    }
74
75    pub fn list_settings(&self) -> Result<Vec<(String, String)>> {
76        let mut settings: Vec<(String, String)> = self
77            .settings
78            .iter()
79            .map(|(k, v)| (k.clone(), v.clone()))
80            .collect();
81        settings.sort_by(|a, b| a.0.cmp(&b.0));
82        Ok(settings)
83    }
84
85    fn get_config_path() -> Result<std::path::PathBuf> {
86        let home = home::home_dir().ok_or_else(|| anyhow!("Could not determine home directory"))?;
87        let config_dir = home.join(".nabla");
88        fs::create_dir_all(&config_dir)?;
89        Ok(config_dir.join("config.json"))
90    }
91
92    fn save(&self) -> Result<()> {
93        let config_path = Self::get_config_path()?;
94        let content = serde_json::to_string_pretty(&self)?;
95        fs::write(&config_path, content)?;
96        Ok(())
97    }
98}
99
100impl LLMProvidersConfig {
101    pub fn new() -> Result<Self> {
102        let providers_path = Self::get_providers_path()?;
103        if providers_path.exists() {
104            let content = fs::read_to_string(&providers_path)?;
105            let config: LLMProvidersConfig = serde_json::from_str(&content)?;
106            Ok(config)
107        } else {
108            Ok(LLMProvidersConfig::default())
109        }
110    }
111
112    pub fn add_provider(&mut self, provider: LLMProvider) -> Result<()> {
113        // If this is marked as default, unset other defaults
114        if provider.default {
115            for (_, existing_provider) in self.providers.iter_mut() {
116                existing_provider.default = false;
117            }
118        }
119
120        self.providers.insert(provider.name.clone(), provider);
121        self.save()
122    }
123
124    pub fn remove_provider(&mut self, name: &str) -> Result<()> {
125        self.providers.remove(name);
126        self.save()
127    }
128
129    pub fn get_provider(&self, name: &str) -> Option<&LLMProvider> {
130        self.providers.get(name)
131    }
132
133    pub fn get_default_provider(&self) -> Option<&LLMProvider> {
134        self.providers.values().find(|p| p.default)
135    }
136
137    pub fn list_providers(&self) -> Vec<&LLMProvider> {
138        let mut providers: Vec<&LLMProvider> = self.providers.values().collect();
139        providers.sort_by(|a, b| a.name.cmp(&b.name));
140        providers
141    }
142
143    pub fn set_default_provider(&mut self, name: &str) -> Result<()> {
144        // Unset all defaults first
145        for (_, provider) in self.providers.iter_mut() {
146            provider.default = false;
147        }
148
149        // Set the new default
150        if let Some(provider) = self.providers.get_mut(name) {
151            provider.default = true;
152            self.save()
153        } else {
154            Err(anyhow!("Provider '{}' not found", name))
155        }
156    }
157
158    fn get_providers_path() -> Result<std::path::PathBuf> {
159        let home = home::home_dir().ok_or_else(|| anyhow!("Could not determine home directory"))?;
160        let config_dir = home.join(".nabla");
161        fs::create_dir_all(&config_dir)?;
162        Ok(config_dir.join("llm_providers.json"))
163    }
164
165    fn save(&self) -> Result<()> {
166        let providers_path = Self::get_providers_path()?;
167        let content = serde_json::to_string_pretty(&self)?;
168        fs::write(&providers_path, content)?;
169        Ok(())
170    }
171}
172
173#[derive(Subcommand)]
174pub enum ConfigCommands {
175    Get {
176        key: String,
177    },
178    Set {
179        key: String,
180        value: String,
181    },
182    SetBaseUrl {
183        url: String,
184    },
185    List,
186
187    // LLM Provider management
188    AddProvider {
189        name: String,
190        #[arg(long)]
191        provider_type: String, // openai, groq, together, local
192        #[arg(long)]
193        api_key: Option<String>,
194        #[arg(long)]
195        base_url: String,
196        #[arg(long)]
197        model: Option<String>,
198        #[arg(long)]
199        default: bool,
200    },
201    RemoveProvider {
202        name: String,
203    },
204    ListProviders,
205    SetDefaultProvider {
206        name: String,
207    },
208}