reflex/semantic/config.rs

1//! Configuration for semantic query feature
2
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::env;
6use std::path::Path;
7
/// Semantic query configuration
///
/// Deserialized from the `[semantic]` section of `~/.reflex/config.toml`
/// (see `load_config`); every field carries a serde default so a partial
/// or missing section still parses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConfig {
    /// Enable semantic query feature
    #[serde(default = "default_enabled")]
    pub enabled: bool,

    /// LLM provider (openai, anthropic, groq)
    #[serde(default = "default_provider")]
    pub provider: String,

    /// Optional model override (uses provider default if None)
    #[serde(default)]
    pub model: Option<String>,

    /// Auto-execute generated commands without confirmation
    /// (serde default: false, i.e. confirmation required)
    #[serde(default)]
    pub auto_execute: bool,

    /// Enable agentic mode (multi-step reasoning with context gathering);
    /// experimental and therefore off by default
    #[serde(default = "default_agentic_enabled")]
    pub agentic_enabled: bool,

    /// Maximum iterations for query refinement in agentic mode
    #[serde(default = "default_max_iterations")]
    pub max_iterations: usize,

    /// Maximum tool calls per context gathering phase
    #[serde(default = "default_max_tools")]
    pub max_tools_per_phase: usize,

    /// Enable result evaluation in agentic mode
    #[serde(default = "default_evaluation_enabled")]
    pub evaluation_enabled: bool,

    /// Evaluation strictness (0.0-1.0, higher is stricter)
    #[serde(default = "default_strictness")]
    pub evaluation_strictness: f32,
}
47
// Serde default helpers for the `[semantic]` section. Keep the values in
// sync with `SemanticConfig::default()`.

/// Semantic queries are on unless explicitly disabled.
fn default_enabled() -> bool {
    true
}

/// Default LLM provider name.
fn default_provider() -> String {
    String::from("openai")
}

/// Agentic mode is experimental, so it is opt-in (off by default).
fn default_agentic_enabled() -> bool {
    false
}

/// Query-refinement iteration cap for agentic mode.
fn default_max_iterations() -> usize {
    2
}

/// Tool-call cap per context-gathering phase.
fn default_max_tools() -> usize {
    5
}

/// Result evaluation is on by default.
fn default_evaluation_enabled() -> bool {
    true
}

/// Mid-range evaluation strictness on the 0.0-1.0 scale.
fn default_strictness() -> f32 {
    0.5
}
75
76impl Default for SemanticConfig {
77    fn default() -> Self {
78        Self {
79            enabled: true,
80            provider: "openai".to_string(),
81            model: None,
82            auto_execute: false,
83            agentic_enabled: false,
84            max_iterations: 2,
85            max_tools_per_phase: 5,
86            evaluation_enabled: true,
87            evaluation_strictness: 0.5,
88        }
89    }
90}
91
92/// Load semantic config from ~/.reflex/config.toml
93///
94/// Semantic configuration is ALWAYS user-level (not project-level).
95/// Falls back to defaults if file doesn't exist or [semantic] section is missing.
96///
97/// Note: The cache_dir parameter is ignored - kept for API compatibility but will be removed in future.
98pub fn load_config(_cache_dir: &Path) -> Result<SemanticConfig> {
99    // Semantic config is always in user home directory, not project directory
100    let home = match dirs::home_dir() {
101        Some(h) => h,
102        None => {
103            log::debug!("Could not determine home directory, using defaults");
104            return Ok(SemanticConfig::default());
105        }
106    };
107
108    let config_path = home.join(".reflex").join("config.toml");
109
110    if !config_path.exists() {
111        log::debug!("No ~/.reflex/config.toml found, using default semantic config");
112        return Ok(SemanticConfig::default());
113    }
114
115    let config_str = std::fs::read_to_string(&config_path)
116        .context("Failed to read ~/.reflex/config.toml")?;
117
118    let toml_value: toml::Value = toml::from_str(&config_str)
119        .context("Failed to parse ~/.reflex/config.toml")?;
120
121    // Extract [semantic] section
122    if let Some(semantic_table) = toml_value.get("semantic") {
123        let config: SemanticConfig = semantic_table.clone().try_into()
124            .context("Failed to parse [semantic] section in ~/.reflex/config.toml")?;
125        log::debug!("Loaded semantic config from ~/.reflex/config.toml: provider={}", config.provider);
126        Ok(config)
127    } else {
128        log::debug!("No [semantic] section in ~/.reflex/config.toml, using defaults");
129        Ok(SemanticConfig::default())
130    }
131}
132
/// User configuration structure for ~/.reflex/config.toml
///
/// Only the `[credentials]` section is modeled here; other sections of the
/// same file (e.g. `[semantic]`) are parsed separately by `load_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct UserConfig {
    // None when the [credentials] section is absent from the file.
    #[serde(default)]
    credentials: Option<Credentials>,
}
139
/// The `[credentials]` table: per-provider API keys and model overrides.
/// Every field is optional so a partially filled table still parses.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Credentials {
    // API keys, one per supported provider.
    #[serde(default)]
    openai_api_key: Option<String>,
    #[serde(default)]
    anthropic_api_key: Option<String>,
    #[serde(default)]
    groq_api_key: Option<String>,
    // Preferred model per provider; None means "use the provider default".
    #[serde(default)]
    openai_model: Option<String>,
    #[serde(default)]
    anthropic_model: Option<String>,
    #[serde(default)]
    groq_model: Option<String>,
}
155
156/// Load user configuration from ~/.reflex/config.toml
157fn load_user_config() -> Result<Option<UserConfig>> {
158    let home = match dirs::home_dir() {
159        Some(h) => h,
160        None => {
161            log::debug!("Could not determine home directory");
162            return Ok(None);
163        }
164    };
165
166    let config_path = home.join(".reflex").join("config.toml");
167
168    if !config_path.exists() {
169        log::debug!("No user config found at ~/.reflex/config.toml");
170        return Ok(None);
171    }
172
173    let config_str = std::fs::read_to_string(&config_path)
174        .context("Failed to read ~/.reflex/config.toml")?;
175
176    let config: UserConfig = toml::from_str(&config_str)
177        .context("Failed to parse ~/.reflex/config.toml")?;
178
179    Ok(Some(config))
180}
181
182/// Get API key for a provider
183///
184/// Checks in priority order:
185/// 1. ~/.reflex/config.toml (user config file)
186/// 2. {PROVIDER}_API_KEY environment variable (e.g., OPENAI_API_KEY)
187/// 3. Error if not found
188pub fn get_api_key(provider: &str) -> Result<String> {
189    // First check user config file
190    if let Ok(Some(user_config)) = load_user_config() {
191        if let Some(credentials) = &user_config.credentials {
192            // Get the appropriate key based on provider
193            let key = match provider.to_lowercase().as_str() {
194                "openai" => credentials.openai_api_key.as_ref(),
195                "anthropic" => credentials.anthropic_api_key.as_ref(),
196                "groq" => credentials.groq_api_key.as_ref(),
197                _ => None,
198            };
199
200            if let Some(api_key) = key {
201                log::debug!("Using {} API key from ~/.reflex/config.toml", provider);
202                return Ok(api_key.clone());
203            }
204        }
205    }
206
207    // Fall back to environment variables
208    let env_var = match provider.to_lowercase().as_str() {
209        "openai" => "OPENAI_API_KEY",
210        "anthropic" => "ANTHROPIC_API_KEY",
211        "groq" => "GROQ_API_KEY",
212        _ => anyhow::bail!("Unknown provider: {}", provider),
213    };
214
215    env::var(env_var).with_context(|| {
216        format!(
217            "API key not found for provider '{}'.\n\
218             \n\
219             Either:\n\
220             1. Run 'rfx ask --configure' to set up your API key interactively\n\
221             2. Set the {} environment variable manually\n\
222             \n\
223             Example: export {}=sk-...",
224            provider, env_var, env_var
225        )
226    })
227}
228
229/// Check if any API key is configured for any supported provider
230///
231/// Checks in priority order:
232/// 1. ~/.reflex/config.toml (credentials section)
233/// 2. Environment variables (OPENAI_API_KEY, ANTHROPIC_API_KEY, GROQ_API_KEY)
234///
235/// Returns true if at least one API key is found for any provider.
236pub fn is_any_api_key_configured() -> bool {
237    let providers = ["openai", "anthropic", "groq"];
238
239    // Check user config file first
240    if let Ok(Some(user_config)) = load_user_config() {
241        if let Some(credentials) = &user_config.credentials {
242            // Check if any provider has an API key in the config file
243            if credentials.openai_api_key.is_some()
244                || credentials.anthropic_api_key.is_some()
245                || credentials.groq_api_key.is_some()
246            {
247                log::debug!("Found API key in ~/.reflex/config.toml");
248                return true;
249            }
250        }
251    }
252
253    // Check environment variables
254    for provider in &providers {
255        let env_var = match *provider {
256            "openai" => "OPENAI_API_KEY",
257            "anthropic" => "ANTHROPIC_API_KEY",
258            "groq" => "GROQ_API_KEY",
259            _ => continue,
260        };
261
262        if env::var(env_var).is_ok() {
263            log::debug!("Found {} environment variable", env_var);
264            return true;
265        }
266    }
267
268    log::debug!("No API keys found in config or environment variables");
269    false
270}
271
272/// Get the preferred model for a provider from user config
273///
274/// Returns None if no model is configured for this provider.
275/// The caller should use provider defaults if None is returned.
276pub fn get_user_model(provider: &str) -> Option<String> {
277    if let Ok(Some(user_config)) = load_user_config() {
278        if let Some(credentials) = &user_config.credentials {
279            let model = match provider.to_lowercase().as_str() {
280                "openai" => credentials.openai_model.as_ref(),
281                "anthropic" => credentials.anthropic_model.as_ref(),
282                "groq" => credentials.groq_model.as_ref(),
283                _ => None,
284            };
285
286            if let Some(model_name) = model {
287                log::debug!("Using {} model from ~/.reflex/config.toml: {}", provider, model_name);
288                return Some(model_name.clone());
289            }
290        }
291    }
292
293    None
294}
295
296/// Save user's provider/model preference to ~/.reflex/config.toml
297///
298/// Updates the [credentials] section with the new model for the specified provider.
299/// Creates the config file and directory if they don't exist.
300pub fn save_user_provider(provider: &str, model: Option<&str>) -> Result<()> {
301    let home = dirs::home_dir().context("Cannot find home directory")?;
302    let config_dir = home.join(".reflex");
303    let config_path = config_dir.join("config.toml");
304
305    // Create directory if needed
306    std::fs::create_dir_all(&config_dir)
307        .context("Failed to create ~/.reflex directory")?;
308
309    // Read existing config or create empty
310    let mut config: toml::Value = if config_path.exists() {
311        let content = std::fs::read_to_string(&config_path)
312            .context("Failed to read ~/.reflex/config.toml")?;
313        toml::from_str(&content)
314            .context("Failed to parse ~/.reflex/config.toml")?
315    } else {
316        toml::Value::Table(toml::map::Map::new())
317    };
318
319    // Ensure [credentials] section exists
320    let credentials = config
321        .as_table_mut()
322        .context("Config root is not a table")?
323        .entry("credentials")
324        .or_insert(toml::Value::Table(toml::map::Map::new()))
325        .as_table_mut()
326        .context("[credentials] is not a table")?;
327
328    // Set model for this provider (if provided)
329    if let Some(m) = model {
330        let key = format!("{}_model", provider.to_lowercase());
331        credentials.insert(key, toml::Value::String(m.to_string()));
332        log::info!("Saved {} model: {}", provider, m);
333    }
334
335    // Write back to file
336    let toml_str = toml::to_string_pretty(&config)
337        .context("Failed to serialize config to TOML")?;
338    std::fs::write(&config_path, toml_str)
339        .context("Failed to write ~/.reflex/config.toml")?;
340
341    Ok(())
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Serializes tests that mutate process-global environment variables
    /// (HOME and the *_API_KEY variables). Rust runs tests on multiple
    /// threads by default, so unsynchronized env mutation makes these flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Take the env lock, recovering from poisoning so a single panicking
    /// test does not cascade into unrelated spurious failures.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// RAII override of HOME that restores the previous value on drop,
    /// even if the test body panics.
    ///
    /// NOTE(review): dirs::home_dir() reads HOME on Unix but USERPROFILE on
    /// Windows, so these tests assume a Unix-like environment — confirm if
    /// they must also run on Windows.
    struct HomeOverride {
        previous: Option<String>,
    }

    impl HomeOverride {
        fn set(path: &std::path::Path) -> Self {
            let previous = env::var("HOME").ok();
            // SAFETY: every caller holds ENV_LOCK, so no other thread reads
            // or writes the environment concurrently.
            unsafe {
                env::set_var("HOME", path);
            }
            Self { previous }
        }
    }

    impl Drop for HomeOverride {
        fn drop(&mut self) {
            // SAFETY: ENV_LOCK is still held by the test that created us.
            unsafe {
                match self.previous.as_deref() {
                    Some(v) => env::set_var("HOME", v),
                    None => env::remove_var("HOME"),
                }
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = SemanticConfig::default();
        assert!(config.enabled);
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, None);
        assert!(!config.auto_execute);
    }

    #[test]
    fn test_load_config_no_file() {
        let _lock = env_guard();
        let temp = TempDir::new().unwrap();
        // Point HOME at an empty directory so the user's real config is ignored.
        let _home = HomeOverride::set(temp.path());

        let config = load_config(temp.path()).unwrap();

        // Should return defaults
        assert_eq!(config.provider, "openai");
        assert!(config.enabled);
    }

    #[test]
    fn test_load_config_with_semantic_section() {
        let _lock = env_guard();
        let temp = TempDir::new().unwrap();
        let reflex_dir = temp.path().join(".reflex");
        std::fs::create_dir_all(&reflex_dir).unwrap();
        std::fs::write(
            reflex_dir.join("config.toml"),
            r#"
[semantic]
enabled = true
provider = "anthropic"
model = "claude-3-5-sonnet-20241022"
auto_execute = true
            "#,
        )
        .unwrap();

        let _home = HomeOverride::set(temp.path());
        let config = load_config(temp.path()).unwrap();

        assert!(config.enabled);
        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, Some("claude-3-5-sonnet-20241022".to_string()));
        assert!(config.auto_execute);
    }

    #[test]
    fn test_load_config_without_semantic_section() {
        let _lock = env_guard();
        let temp = TempDir::new().unwrap();
        let reflex_dir = temp.path().join(".reflex");
        std::fs::create_dir_all(&reflex_dir).unwrap();
        std::fs::write(
            reflex_dir.join("config.toml"),
            r#"
[index]
languages = []
            "#,
        )
        .unwrap();

        let _home = HomeOverride::set(temp.path());
        let config = load_config(temp.path()).unwrap();

        // Should return defaults
        assert_eq!(config.provider, "openai");
    }

    #[test]
    fn test_get_api_key_env_var() {
        let _lock = env_guard();
        let temp = TempDir::new().unwrap();
        let _home = HomeOverride::set(temp.path());

        let previous = env::var("OPENAI_API_KEY").ok();
        // SAFETY: ENV_LOCK is held.
        unsafe {
            env::set_var("OPENAI_API_KEY", "test-key-123");
        }

        let key = get_api_key("openai");

        // Restore before asserting so a failed assert doesn't leak the value.
        // SAFETY: ENV_LOCK is held.
        unsafe {
            match previous.as_deref() {
                Some(v) => env::set_var("OPENAI_API_KEY", v),
                None => env::remove_var("OPENAI_API_KEY"),
            }
        }

        assert_eq!(key.unwrap(), "test-key-123");
    }

    #[test]
    fn test_get_api_key_missing() {
        let _lock = env_guard();
        let temp = TempDir::new().unwrap();
        let _home = HomeOverride::set(temp.path());

        let previous = env::var("GROQ_API_KEY").ok();
        // SAFETY: ENV_LOCK is held.
        unsafe {
            env::remove_var("GROQ_API_KEY");
        }

        let result = get_api_key("groq");

        // SAFETY: ENV_LOCK is held.
        unsafe {
            if let Some(v) = previous.as_deref() {
                env::set_var("GROQ_API_KEY", v);
            }
        }

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("GROQ_API_KEY"));
    }

    #[test]
    fn test_get_api_key_unknown_provider() {
        // Locks because get_api_key reads HOME via load_user_config.
        let _lock = env_guard();
        let result = get_api_key("unknown");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown provider"));
    }
}