Skip to main content

reflex/semantic/
config.rs

1//! Configuration for semantic query feature
2
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::env;
7use std::path::Path;
8
9/// Semantic query configuration
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct SemanticConfig {
12    /// Enable semantic query feature
13    #[serde(default = "default_enabled")]
14    pub enabled: bool,
15
16    /// LLM provider (openai, anthropic, openrouter)
17    #[serde(default = "default_provider")]
18    pub provider: String,
19
20    /// Optional model override (uses provider default if None)
21    #[serde(default)]
22    pub model: Option<String>,
23
24    /// Auto-execute generated commands without confirmation
25    #[serde(default)]
26    pub auto_execute: bool,
27
28    /// Enable agentic mode (multi-step reasoning with context gathering)
29    #[serde(default = "default_agentic_enabled")]
30    pub agentic_enabled: bool,
31
32    /// Maximum iterations for query refinement in agentic mode
33    #[serde(default = "default_max_iterations")]
34    pub max_iterations: usize,
35
36    /// Maximum tool calls per context gathering phase
37    #[serde(default = "default_max_tools")]
38    pub max_tools_per_phase: usize,
39
40    /// Enable result evaluation in agentic mode
41    #[serde(default = "default_evaluation_enabled")]
42    pub evaluation_enabled: bool,
43
44    /// Evaluation strictness (0.0-1.0, higher is stricter)
45    #[serde(default = "default_strictness")]
46    pub evaluation_strictness: f32,
47}
48
/// Serde default for `SemanticConfig::enabled`: the feature is on unless turned off.
fn default_enabled() -> bool {
    true
}

/// Serde default for `SemanticConfig::provider`.
fn default_provider() -> String {
    String::from("openai")
}

/// Serde default for `SemanticConfig::agentic_enabled`.
///
/// Agentic mode is experimental, so users must opt in explicitly.
fn default_agentic_enabled() -> bool {
    false
}

/// Serde default for `SemanticConfig::max_iterations`.
fn default_max_iterations() -> usize {
    2
}

/// Serde default for `SemanticConfig::max_tools_per_phase`.
fn default_max_tools() -> usize {
    5
}

/// Serde default for `SemanticConfig::evaluation_enabled`.
fn default_evaluation_enabled() -> bool {
    true
}

/// Serde default for `SemanticConfig::evaluation_strictness` (middle of the 0.0–1.0 scale).
fn default_strictness() -> f32 {
    0.5
}
76
77impl Default for SemanticConfig {
78    fn default() -> Self {
79        Self {
80            enabled: true,
81            provider: "openai".to_string(),
82            model: None,
83            auto_execute: false,
84            agentic_enabled: false,
85            max_iterations: 2,
86            max_tools_per_phase: 5,
87            evaluation_enabled: true,
88            evaluation_strictness: 0.5,
89        }
90    }
91}
92
93/// Apply environment variable overrides to a semantic config.
94///
95/// Supports:
96/// - `REFLEX_PROVIDER` — overrides the provider (e.g., "openrouter", "anthropic", "openai")
97/// - `REFLEX_MODEL` — overrides the model
98///
99/// This enables CI/headless usage where there's no ~/.reflex/config.toml.
100fn apply_env_overrides(mut config: SemanticConfig) -> SemanticConfig {
101    if let Ok(provider) = env::var("REFLEX_PROVIDER") {
102        if !provider.is_empty() {
103            log::debug!("Overriding provider from REFLEX_PROVIDER env var: {}", provider);
104            config.provider = provider;
105        }
106    }
107
108    if let Ok(model) = env::var("REFLEX_MODEL") {
109        if !model.is_empty() {
110            log::debug!("Overriding model from REFLEX_MODEL env var: {}", model);
111            config.model = Some(model);
112        }
113    }
114
115    config
116}
117
118/// Load semantic config from ~/.reflex/config.toml
119///
120/// Semantic configuration is ALWAYS user-level (not project-level).
121/// Falls back to defaults if file doesn't exist or [semantic] section is missing.
122/// Environment variables `REFLEX_PROVIDER` and `REFLEX_MODEL` override config file values.
123///
124/// Note: The cache_dir parameter is ignored - kept for API compatibility but will be removed in future.
125pub fn load_config(_cache_dir: &Path) -> Result<SemanticConfig> {
126    // Semantic config is always in user home directory, not project directory
127    let home = match dirs::home_dir() {
128        Some(h) => h,
129        None => {
130            log::debug!("Could not determine home directory, using defaults");
131            return Ok(apply_env_overrides(SemanticConfig::default()));
132        }
133    };
134
135    let config_path = home.join(".reflex").join("config.toml");
136
137    if !config_path.exists() {
138        log::debug!("No ~/.reflex/config.toml found, using default semantic config");
139        return Ok(apply_env_overrides(SemanticConfig::default()));
140    }
141
142    let config_str = std::fs::read_to_string(&config_path)
143        .context("Failed to read ~/.reflex/config.toml")?;
144
145    let toml_value: toml::Value = toml::from_str(&config_str)
146        .context("Failed to parse ~/.reflex/config.toml")?;
147
148    // Extract [semantic] section
149    if let Some(semantic_table) = toml_value.get("semantic") {
150        let config: SemanticConfig = semantic_table.clone().try_into()
151            .context("Failed to parse [semantic] section in ~/.reflex/config.toml")?;
152        log::debug!("Loaded semantic config from ~/.reflex/config.toml: provider={}", config.provider);
153        Ok(apply_env_overrides(config))
154    } else {
155        log::debug!("No [semantic] section in ~/.reflex/config.toml, using defaults");
156        Ok(apply_env_overrides(SemanticConfig::default()))
157    }
158}
159
160/// User configuration structure for ~/.reflex/config.toml
161#[derive(Debug, Clone, Serialize, Deserialize)]
162struct UserConfig {
163    #[serde(default)]
164    credentials: Option<Credentials>,
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
168struct Credentials {
169    #[serde(default)]
170    openai_api_key: Option<String>,
171    #[serde(default)]
172    anthropic_api_key: Option<String>,
173    #[serde(default)]
174    openrouter_api_key: Option<String>,
175    #[serde(default)]
176    openai_model: Option<String>,
177    #[serde(default)]
178    anthropic_model: Option<String>,
179    #[serde(default)]
180    openrouter_model: Option<String>,
181    #[serde(default)]
182    openrouter_sort: Option<String>,
183}
184
185/// Load user configuration from ~/.reflex/config.toml
186fn load_user_config() -> Result<Option<UserConfig>> {
187    let home = match dirs::home_dir() {
188        Some(h) => h,
189        None => {
190            log::debug!("Could not determine home directory");
191            return Ok(None);
192        }
193    };
194
195    let config_path = home.join(".reflex").join("config.toml");
196
197    if !config_path.exists() {
198        log::debug!("No user config found at ~/.reflex/config.toml");
199        return Ok(None);
200    }
201
202    let config_str = std::fs::read_to_string(&config_path)
203        .context("Failed to read ~/.reflex/config.toml")?;
204
205    let config: UserConfig = toml::from_str(&config_str)
206        .context("Failed to parse ~/.reflex/config.toml")?;
207
208    Ok(Some(config))
209}
210
211/// Get API key for a provider
212///
213/// Checks in priority order:
214/// 1. ~/.reflex/config.toml (user config file)
215/// 2. REFLEX_AI_API_KEY environment variable (generic, provider-agnostic)
216/// 3. {PROVIDER}_API_KEY environment variable (e.g., OPENAI_API_KEY)
217/// 4. Error if not found
218pub fn get_api_key(provider: &str) -> Result<String> {
219    // First check user config file
220    if let Ok(Some(user_config)) = load_user_config() {
221        if let Some(credentials) = &user_config.credentials {
222            // Get the appropriate key based on provider
223            let key = match provider.to_lowercase().as_str() {
224                "openai" => credentials.openai_api_key.as_ref(),
225                "anthropic" => credentials.anthropic_api_key.as_ref(),
226                "openrouter" => credentials.openrouter_api_key.as_ref(),
227                _ => None,
228            };
229
230            if let Some(api_key) = key {
231                log::debug!("Using {} API key from ~/.reflex/config.toml", provider);
232                return Ok(api_key.clone());
233            }
234        }
235    }
236
237    // Check generic REFLEX_AI_API_KEY env var (provider-agnostic, useful for CI)
238    if let Ok(key) = env::var("REFLEX_AI_API_KEY") {
239        if !key.is_empty() {
240            log::debug!("Using API key from REFLEX_AI_API_KEY env var for provider '{}'", provider);
241            return Ok(key);
242        }
243    }
244
245    // Fall back to provider-specific environment variables
246    let env_var = match provider.to_lowercase().as_str() {
247        "openai" => "OPENAI_API_KEY",
248        "anthropic" => "ANTHROPIC_API_KEY",
249        "openrouter" => "OPENROUTER_API_KEY",
250        _ => anyhow::bail!("Unknown provider: {}", provider),
251    };
252
253    env::var(env_var).with_context(|| {
254        format!(
255            "API key not found for provider '{}'.\n\
256             \n\
257             Either:\n\
258             1. Run 'rfx ask --configure' to set up your API key interactively\n\
259             2. Set REFLEX_AI_API_KEY (works with any provider)\n\
260             3. Set the {} environment variable\n\
261             \n\
262             Example: export REFLEX_AI_API_KEY=sk-...",
263            provider, env_var
264        )
265    })
266}
267
268/// Check if any API key is configured for any supported provider
269///
270/// Checks in priority order:
271/// 1. ~/.reflex/config.toml (credentials section)
272/// 2. REFLEX_AI_API_KEY environment variable (generic)
273/// 3. Provider-specific environment variables (OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY)
274///
275/// Returns true if at least one API key is found for any provider.
276pub fn is_any_api_key_configured() -> bool {
277    let providers = ["openai", "anthropic", "openrouter"];
278
279    // Check user config file first
280    if let Ok(Some(user_config)) = load_user_config() {
281        if let Some(credentials) = &user_config.credentials {
282            // Check if any provider has an API key in the config file
283            if credentials.openai_api_key.is_some()
284                || credentials.anthropic_api_key.is_some()
285                || credentials.openrouter_api_key.is_some()
286            {
287                log::debug!("Found API key in ~/.reflex/config.toml");
288                return true;
289            }
290        }
291    }
292
293    // Check generic REFLEX_AI_API_KEY
294    if let Ok(key) = env::var("REFLEX_AI_API_KEY") {
295        if !key.is_empty() {
296            log::debug!("Found REFLEX_AI_API_KEY env var");
297            return true;
298        }
299    }
300
301    // Check provider-specific environment variables
302    for provider in &providers {
303        let env_var = match *provider {
304            "openai" => "OPENAI_API_KEY",
305            "anthropic" => "ANTHROPIC_API_KEY",
306            "openrouter" => "OPENROUTER_API_KEY",
307            _ => continue,
308        };
309
310        if env::var(env_var).is_ok() {
311            log::debug!("Found {} environment variable", env_var);
312            return true;
313        }
314    }
315
316    log::debug!("No API keys found in config or environment variables");
317    false
318}
319
320/// Get the preferred model for a provider from user config
321///
322/// Returns None if no model is configured for this provider.
323/// The caller should use provider defaults if None is returned.
324pub fn get_user_model(provider: &str) -> Option<String> {
325    if let Ok(Some(user_config)) = load_user_config() {
326        if let Some(credentials) = &user_config.credentials {
327            let model = match provider.to_lowercase().as_str() {
328                "openai" => credentials.openai_model.as_ref(),
329                "anthropic" => credentials.anthropic_model.as_ref(),
330                "openrouter" => credentials.openrouter_model.as_ref(),
331                _ => None,
332            };
333
334            if let Some(model_name) = model {
335                log::debug!("Using {} model from ~/.reflex/config.toml: {}", provider, model_name);
336                return Some(model_name.clone());
337            }
338        }
339    }
340
341    None
342}
343
344/// Save user's provider/model preference to ~/.reflex/config.toml
345///
346/// Updates the [credentials] section with the new model for the specified provider.
347/// Creates the config file and directory if they don't exist.
348pub fn save_user_provider(provider: &str, model: Option<&str>) -> Result<()> {
349    let home = dirs::home_dir().context("Cannot find home directory")?;
350    let config_dir = home.join(".reflex");
351    let config_path = config_dir.join("config.toml");
352
353    // Create directory if needed
354    std::fs::create_dir_all(&config_dir)
355        .context("Failed to create ~/.reflex directory")?;
356
357    // Read existing config or create empty
358    let mut config: toml::Value = if config_path.exists() {
359        let content = std::fs::read_to_string(&config_path)
360            .context("Failed to read ~/.reflex/config.toml")?;
361        toml::from_str(&content)
362            .context("Failed to parse ~/.reflex/config.toml")?
363    } else {
364        toml::Value::Table(toml::map::Map::new())
365    };
366
367    // Ensure [credentials] section exists
368    let credentials = config
369        .as_table_mut()
370        .context("Config root is not a table")?
371        .entry("credentials")
372        .or_insert(toml::Value::Table(toml::map::Map::new()))
373        .as_table_mut()
374        .context("[credentials] is not a table")?;
375
376    // Set model for this provider (if provided)
377    if let Some(m) = model {
378        let key = format!("{}_model", provider.to_lowercase());
379        credentials.insert(key, toml::Value::String(m.to_string()));
380        log::info!("Saved {} model: {}", provider, m);
381    }
382
383    // Write back to file
384    let toml_str = toml::to_string_pretty(&config)
385        .context("Failed to serialize config to TOML")?;
386    std::fs::write(&config_path, toml_str)
387        .context("Failed to write ~/.reflex/config.toml")?;
388
389    Ok(())
390}
391
392/// Get provider-specific options from user config
393///
394/// Returns `Some(HashMap)` for providers that need extra settings (e.g., OpenRouter sort strategy).
395/// Returns `None` for providers with no additional options.
396pub fn get_provider_options(provider: &str) -> Option<HashMap<String, String>> {
397    if provider.to_lowercase() != "openrouter" {
398        return None;
399    }
400
401    if let Ok(Some(user_config)) = load_user_config() {
402        if let Some(credentials) = &user_config.credentials {
403            if let Some(sort) = &credentials.openrouter_sort {
404                let mut opts = HashMap::new();
405                opts.insert("sort".to_string(), sort.clone());
406                return Some(opts);
407            }
408        }
409    }
410
411    None
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Serializes every test in this module that touches process environment variables.
    ///
    /// The test harness runs tests on parallel threads while the environment is
    /// process-global. Without this lock, one test's `REFLEX_PROVIDER` or
    /// `REFLEX_AI_API_KEY` leaks into another test's assertions.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Variables the code under test consults. Each test starts with all of them cleared
    /// (except `HOME`, which points at a temp dir), so results don't depend on the
    /// developer's or CI's shell.
    const ISOLATED_VARS: &[&str] = &[
        "HOME",
        "REFLEX_PROVIDER",
        "REFLEX_MODEL",
        "REFLEX_AI_API_KEY",
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "OPENROUTER_API_KEY",
    ];

    /// RAII guard: holds `ENV_LOCK`, isolates the environment, and restores the previous
    /// values on drop, including when an assertion panics.
    ///
    /// NOTE(review): isolation relies on `dirs::home_dir()` honoring `HOME`; confirm on
    /// platforms where it may resolve the home directory another way (e.g. Windows).
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn new(home: &std::path::Path) -> Self {
            // A test that panics poisons the lock, but its guard has already restored the
            // environment during unwinding, so the poison carries no broken state.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved: Vec<_> = ISOLATED_VARS
                .iter()
                .map(|&name| (name, env::var_os(name)))
                .collect();
            // SAFETY: all environment mutation in this module happens while ENV_LOCK is
            // held. Env access from other test modules is not covered by this lock.
            unsafe {
                for name in ISOLATED_VARS {
                    env::remove_var(name);
                }
                env::set_var("HOME", home);
            }
            Self { saved, _lock: lock }
        }

        fn set(&self, name: &str, value: &str) {
            // SAFETY: holding `self` means ENV_LOCK is held (see `new`).
            unsafe {
                env::set_var(name, value);
            }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // SAFETY: ENV_LOCK is still held; the `_lock` field is dropped after this body.
            unsafe {
                for (name, value) in &self.saved {
                    match value {
                        Some(v) => env::set_var(name, v),
                        None => env::remove_var(name),
                    }
                }
            }
        }
    }

    /// Writes `contents` to `<home>/.reflex/config.toml`, creating the directory.
    fn write_user_config(home: &std::path::Path, contents: &str) {
        let reflex_dir = home.join(".reflex");
        std::fs::create_dir_all(&reflex_dir).unwrap();
        std::fs::write(reflex_dir.join("config.toml"), contents).unwrap();
    }

    #[test]
    fn test_default_config() {
        let config = SemanticConfig::default();
        assert!(config.enabled);
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, None);
        assert!(!config.auto_execute);
    }

    #[test]
    fn test_load_config_no_file() {
        let temp = TempDir::new().unwrap();
        let _env = EnvGuard::new(temp.path());

        let config = load_config(temp.path()).unwrap();

        // Should return defaults
        assert_eq!(config.provider, "openai");
        assert!(config.enabled);
    }

    #[test]
    fn test_load_config_with_semantic_section() {
        let temp = TempDir::new().unwrap();
        write_user_config(
            temp.path(),
            r#"
[semantic]
enabled = true
provider = "anthropic"
model = "claude-3-5-sonnet-20241022"
auto_execute = true
"#,
        );
        let _env = EnvGuard::new(temp.path());

        let config = load_config(temp.path()).unwrap();

        assert!(config.enabled);
        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, Some("claude-3-5-sonnet-20241022".to_string()));
        assert!(config.auto_execute);
    }

    #[test]
    fn test_load_config_without_semantic_section() {
        let temp = TempDir::new().unwrap();
        write_user_config(
            temp.path(),
            r#"
[index]
languages = []
"#,
        );
        let _env = EnvGuard::new(temp.path());

        let config = load_config(temp.path()).unwrap();

        // Should return defaults
        assert_eq!(config.provider, "openai");
    }

    #[test]
    fn test_get_api_key_env_var() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("OPENAI_API_KEY", "test-key-123");

        let key = get_api_key("openai").unwrap();
        assert_eq!(key, "test-key-123");
    }

    #[test]
    fn test_get_api_key_missing() {
        let temp = TempDir::new().unwrap();
        let _env = EnvGuard::new(temp.path());

        let result = get_api_key("openrouter");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("OPENROUTER_API_KEY"));
    }

    #[test]
    fn test_get_api_key_unknown_provider() {
        let temp = TempDir::new().unwrap();
        let _env = EnvGuard::new(temp.path());

        let result = get_api_key("unknown");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown provider"));
    }

    #[test]
    fn test_get_api_key_config_file_takes_priority() {
        let temp = TempDir::new().unwrap();
        write_user_config(
            temp.path(),
            r#"
[credentials]
anthropic_api_key = "config-key-789"
"#,
        );
        let env = EnvGuard::new(temp.path());
        env.set("REFLEX_AI_API_KEY", "generic-key-456");
        env.set("ANTHROPIC_API_KEY", "env-key-123");

        assert_eq!(get_api_key("anthropic").unwrap(), "config-key-789");
    }

    #[test]
    fn test_env_override_provider() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("REFLEX_PROVIDER", "openrouter");

        let config = load_config(temp.path()).unwrap();

        assert_eq!(config.provider, "openrouter");
    }

    #[test]
    fn test_env_override_model() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("REFLEX_MODEL", "google/gemini-2.5-flash");

        let config = load_config(temp.path()).unwrap();

        assert_eq!(config.model, Some("google/gemini-2.5-flash".to_string()));
        // Provider should remain the default since we didn't override it
        assert_eq!(config.provider, "openai");
    }

    #[test]
    fn test_get_api_key_generic_env_var() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("REFLEX_AI_API_KEY", "generic-key-456");

        let key = get_api_key("openrouter").unwrap();
        assert_eq!(key, "generic-key-456");
    }

    #[test]
    fn test_save_user_provider_round_trip() {
        let temp = TempDir::new().unwrap();
        let _env = EnvGuard::new(temp.path());

        save_user_provider("OpenRouter", Some("google/gemini-2.5-flash")).unwrap();

        assert_eq!(
            get_user_model("openrouter"),
            Some("google/gemini-2.5-flash".to_string())
        );
        assert_eq!(get_user_model("openai"), None);
    }
}