Skip to main content

reflex/semantic/
config.rs

1//! Configuration for semantic query feature
2
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::env;
7use std::path::Path;
8
9/// Semantic query configuration
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct SemanticConfig {
12    /// Enable semantic query feature
13    #[serde(default = "default_enabled")]
14    pub enabled: bool,
15
16    /// LLM provider (openai, anthropic, openrouter)
17    #[serde(default = "default_provider")]
18    pub provider: String,
19
20    /// Optional model override (uses provider default if None)
21    #[serde(default)]
22    pub model: Option<String>,
23
24    /// Auto-execute generated commands without confirmation
25    #[serde(default)]
26    pub auto_execute: bool,
27
28    /// Enable agentic mode (multi-step reasoning with context gathering)
29    #[serde(default = "default_agentic_enabled")]
30    pub agentic_enabled: bool,
31
32    /// Maximum iterations for query refinement in agentic mode
33    #[serde(default = "default_max_iterations")]
34    pub max_iterations: usize,
35
36    /// Maximum tool calls per context gathering phase
37    #[serde(default = "default_max_tools")]
38    pub max_tools_per_phase: usize,
39
40    /// Enable result evaluation in agentic mode
41    #[serde(default = "default_evaluation_enabled")]
42    pub evaluation_enabled: bool,
43
44    /// Evaluation strictness (0.0-1.0, higher is stricter)
45    #[serde(default = "default_strictness")]
46    pub evaluation_strictness: f32,
47}
48
/// Serde default for `enabled`: the feature is on unless explicitly disabled.
fn default_enabled() -> bool {
    true
}
52
/// Serde default for `provider`: OpenAI.
fn default_provider() -> String {
    String::from("openai")
}
56
/// Serde default for `agentic_enabled`: agentic mode is experimental, so it is
/// opt-in and starts disabled.
fn default_agentic_enabled() -> bool {
    false
}
60
/// Serde default for `max_iterations`: at most two refinement rounds.
fn default_max_iterations() -> usize {
    2_usize
}
64
/// Serde default for `max_tools_per_phase`: at most five tool calls per phase.
fn default_max_tools() -> usize {
    5_usize
}
68
/// Serde default for `evaluation_enabled`: result evaluation is on.
fn default_evaluation_enabled() -> bool {
    true
}
72
/// Serde default for `evaluation_strictness`: the midpoint of the 0.0–1.0 scale.
fn default_strictness() -> f32 {
    0.5_f32
}
76
77impl Default for SemanticConfig {
78    fn default() -> Self {
79        Self {
80            enabled: true,
81            provider: "openai".to_string(),
82            model: None,
83            auto_execute: false,
84            agentic_enabled: false,
85            max_iterations: 2,
86            max_tools_per_phase: 5,
87            evaluation_enabled: true,
88            evaluation_strictness: 0.5,
89        }
90    }
91}
92
93/// Apply environment variable overrides to a semantic config.
94///
95/// Supports:
96/// - `REFLEX_PROVIDER` — overrides the provider (e.g., "openrouter", "anthropic", "openai")
97/// - `REFLEX_MODEL` — overrides the model
98///
99/// This enables CI/headless usage where there's no ~/.reflex/config.toml.
100fn apply_env_overrides(mut config: SemanticConfig) -> SemanticConfig {
101    if let Ok(provider) = env::var("REFLEX_PROVIDER") {
102        log::debug!("Overriding provider from REFLEX_PROVIDER env var: {}", provider);
103        config.provider = provider;
104    }
105
106    if let Ok(model) = env::var("REFLEX_MODEL") {
107        log::debug!("Overriding model from REFLEX_MODEL env var: {}", model);
108        config.model = Some(model);
109    }
110
111    config
112}
113
114/// Load semantic config from ~/.reflex/config.toml
115///
116/// Semantic configuration is ALWAYS user-level (not project-level).
117/// Falls back to defaults if file doesn't exist or [semantic] section is missing.
118/// Environment variables `REFLEX_PROVIDER` and `REFLEX_MODEL` override config file values.
119///
120/// Note: The cache_dir parameter is ignored - kept for API compatibility but will be removed in future.
121pub fn load_config(_cache_dir: &Path) -> Result<SemanticConfig> {
122    // Semantic config is always in user home directory, not project directory
123    let home = match dirs::home_dir() {
124        Some(h) => h,
125        None => {
126            log::debug!("Could not determine home directory, using defaults");
127            return Ok(apply_env_overrides(SemanticConfig::default()));
128        }
129    };
130
131    let config_path = home.join(".reflex").join("config.toml");
132
133    if !config_path.exists() {
134        log::debug!("No ~/.reflex/config.toml found, using default semantic config");
135        return Ok(apply_env_overrides(SemanticConfig::default()));
136    }
137
138    let config_str = std::fs::read_to_string(&config_path)
139        .context("Failed to read ~/.reflex/config.toml")?;
140
141    let toml_value: toml::Value = toml::from_str(&config_str)
142        .context("Failed to parse ~/.reflex/config.toml")?;
143
144    // Extract [semantic] section
145    if let Some(semantic_table) = toml_value.get("semantic") {
146        let config: SemanticConfig = semantic_table.clone().try_into()
147            .context("Failed to parse [semantic] section in ~/.reflex/config.toml")?;
148        log::debug!("Loaded semantic config from ~/.reflex/config.toml: provider={}", config.provider);
149        Ok(apply_env_overrides(config))
150    } else {
151        log::debug!("No [semantic] section in ~/.reflex/config.toml, using defaults");
152        Ok(apply_env_overrides(SemanticConfig::default()))
153    }
154}
155
156/// User configuration structure for ~/.reflex/config.toml
157#[derive(Debug, Clone, Serialize, Deserialize)]
158struct UserConfig {
159    #[serde(default)]
160    credentials: Option<Credentials>,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
164struct Credentials {
165    #[serde(default)]
166    openai_api_key: Option<String>,
167    #[serde(default)]
168    anthropic_api_key: Option<String>,
169    #[serde(default)]
170    openrouter_api_key: Option<String>,
171    #[serde(default)]
172    openai_model: Option<String>,
173    #[serde(default)]
174    anthropic_model: Option<String>,
175    #[serde(default)]
176    openrouter_model: Option<String>,
177    #[serde(default)]
178    openrouter_sort: Option<String>,
179}
180
181/// Load user configuration from ~/.reflex/config.toml
182fn load_user_config() -> Result<Option<UserConfig>> {
183    let home = match dirs::home_dir() {
184        Some(h) => h,
185        None => {
186            log::debug!("Could not determine home directory");
187            return Ok(None);
188        }
189    };
190
191    let config_path = home.join(".reflex").join("config.toml");
192
193    if !config_path.exists() {
194        log::debug!("No user config found at ~/.reflex/config.toml");
195        return Ok(None);
196    }
197
198    let config_str = std::fs::read_to_string(&config_path)
199        .context("Failed to read ~/.reflex/config.toml")?;
200
201    let config: UserConfig = toml::from_str(&config_str)
202        .context("Failed to parse ~/.reflex/config.toml")?;
203
204    Ok(Some(config))
205}
206
207/// Get API key for a provider
208///
209/// Checks in priority order:
210/// 1. ~/.reflex/config.toml (user config file)
211/// 2. REFLEX_AI_API_KEY environment variable (generic, provider-agnostic)
212/// 3. {PROVIDER}_API_KEY environment variable (e.g., OPENAI_API_KEY)
213/// 4. Error if not found
214pub fn get_api_key(provider: &str) -> Result<String> {
215    // First check user config file
216    if let Ok(Some(user_config)) = load_user_config() {
217        if let Some(credentials) = &user_config.credentials {
218            // Get the appropriate key based on provider
219            let key = match provider.to_lowercase().as_str() {
220                "openai" => credentials.openai_api_key.as_ref(),
221                "anthropic" => credentials.anthropic_api_key.as_ref(),
222                "openrouter" => credentials.openrouter_api_key.as_ref(),
223                _ => None,
224            };
225
226            if let Some(api_key) = key {
227                log::debug!("Using {} API key from ~/.reflex/config.toml", provider);
228                return Ok(api_key.clone());
229            }
230        }
231    }
232
233    // Check generic REFLEX_AI_API_KEY env var (provider-agnostic, useful for CI)
234    if let Ok(key) = env::var("REFLEX_AI_API_KEY") {
235        log::debug!("Using API key from REFLEX_AI_API_KEY env var for provider '{}'", provider);
236        return Ok(key);
237    }
238
239    // Fall back to provider-specific environment variables
240    let env_var = match provider.to_lowercase().as_str() {
241        "openai" => "OPENAI_API_KEY",
242        "anthropic" => "ANTHROPIC_API_KEY",
243        "openrouter" => "OPENROUTER_API_KEY",
244        _ => anyhow::bail!("Unknown provider: {}", provider),
245    };
246
247    env::var(env_var).with_context(|| {
248        format!(
249            "API key not found for provider '{}'.\n\
250             \n\
251             Either:\n\
252             1. Run 'rfx ask --configure' to set up your API key interactively\n\
253             2. Set REFLEX_AI_API_KEY (works with any provider)\n\
254             3. Set the {} environment variable\n\
255             \n\
256             Example: export REFLEX_AI_API_KEY=sk-...",
257            provider, env_var
258        )
259    })
260}
261
262/// Check if any API key is configured for any supported provider
263///
264/// Checks in priority order:
265/// 1. ~/.reflex/config.toml (credentials section)
266/// 2. REFLEX_AI_API_KEY environment variable (generic)
267/// 3. Provider-specific environment variables (OPENAI_API_KEY, ANTHROPIC_API_KEY, OPENROUTER_API_KEY)
268///
269/// Returns true if at least one API key is found for any provider.
270pub fn is_any_api_key_configured() -> bool {
271    let providers = ["openai", "anthropic", "openrouter"];
272
273    // Check user config file first
274    if let Ok(Some(user_config)) = load_user_config() {
275        if let Some(credentials) = &user_config.credentials {
276            // Check if any provider has an API key in the config file
277            if credentials.openai_api_key.is_some()
278                || credentials.anthropic_api_key.is_some()
279                || credentials.openrouter_api_key.is_some()
280            {
281                log::debug!("Found API key in ~/.reflex/config.toml");
282                return true;
283            }
284        }
285    }
286
287    // Check generic REFLEX_AI_API_KEY
288    if env::var("REFLEX_AI_API_KEY").is_ok() {
289        log::debug!("Found REFLEX_AI_API_KEY env var");
290        return true;
291    }
292
293    // Check provider-specific environment variables
294    for provider in &providers {
295        let env_var = match *provider {
296            "openai" => "OPENAI_API_KEY",
297            "anthropic" => "ANTHROPIC_API_KEY",
298            "openrouter" => "OPENROUTER_API_KEY",
299            _ => continue,
300        };
301
302        if env::var(env_var).is_ok() {
303            log::debug!("Found {} environment variable", env_var);
304            return true;
305        }
306    }
307
308    log::debug!("No API keys found in config or environment variables");
309    false
310}
311
312/// Get the preferred model for a provider from user config
313///
314/// Returns None if no model is configured for this provider.
315/// The caller should use provider defaults if None is returned.
316pub fn get_user_model(provider: &str) -> Option<String> {
317    if let Ok(Some(user_config)) = load_user_config() {
318        if let Some(credentials) = &user_config.credentials {
319            let model = match provider.to_lowercase().as_str() {
320                "openai" => credentials.openai_model.as_ref(),
321                "anthropic" => credentials.anthropic_model.as_ref(),
322                "openrouter" => credentials.openrouter_model.as_ref(),
323                _ => None,
324            };
325
326            if let Some(model_name) = model {
327                log::debug!("Using {} model from ~/.reflex/config.toml: {}", provider, model_name);
328                return Some(model_name.clone());
329            }
330        }
331    }
332
333    None
334}
335
336/// Save user's provider/model preference to ~/.reflex/config.toml
337///
338/// Updates the [credentials] section with the new model for the specified provider.
339/// Creates the config file and directory if they don't exist.
340pub fn save_user_provider(provider: &str, model: Option<&str>) -> Result<()> {
341    let home = dirs::home_dir().context("Cannot find home directory")?;
342    let config_dir = home.join(".reflex");
343    let config_path = config_dir.join("config.toml");
344
345    // Create directory if needed
346    std::fs::create_dir_all(&config_dir)
347        .context("Failed to create ~/.reflex directory")?;
348
349    // Read existing config or create empty
350    let mut config: toml::Value = if config_path.exists() {
351        let content = std::fs::read_to_string(&config_path)
352            .context("Failed to read ~/.reflex/config.toml")?;
353        toml::from_str(&content)
354            .context("Failed to parse ~/.reflex/config.toml")?
355    } else {
356        toml::Value::Table(toml::map::Map::new())
357    };
358
359    // Ensure [credentials] section exists
360    let credentials = config
361        .as_table_mut()
362        .context("Config root is not a table")?
363        .entry("credentials")
364        .or_insert(toml::Value::Table(toml::map::Map::new()))
365        .as_table_mut()
366        .context("[credentials] is not a table")?;
367
368    // Set model for this provider (if provided)
369    if let Some(m) = model {
370        let key = format!("{}_model", provider.to_lowercase());
371        credentials.insert(key, toml::Value::String(m.to_string()));
372        log::info!("Saved {} model: {}", provider, m);
373    }
374
375    // Write back to file
376    let toml_str = toml::to_string_pretty(&config)
377        .context("Failed to serialize config to TOML")?;
378    std::fs::write(&config_path, toml_str)
379        .context("Failed to write ~/.reflex/config.toml")?;
380
381    Ok(())
382}
383
384/// Get provider-specific options from user config
385///
386/// Returns `Some(HashMap)` for providers that need extra settings (e.g., OpenRouter sort strategy).
387/// Returns `None` for providers with no additional options.
388pub fn get_provider_options(provider: &str) -> Option<HashMap<String, String>> {
389    if provider.to_lowercase() != "openrouter" {
390        return None;
391    }
392
393    if let Ok(Some(user_config)) = load_user_config() {
394        if let Some(credentials) = &user_config.credentials {
395            if let Some(sort) = &credentials.openrouter_sort {
396                let mut opts = HashMap::new();
397                opts.insert("sort".to_string(), sort.clone());
398                return Some(opts);
399            }
400        }
401    }
402
403    None
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use tempfile::TempDir;

    /// Serializes the tests in this module that read or mutate process-wide
    /// environment variables. `cargo test` runs tests on parallel threads, and
    /// without this lock one test's `REFLEX_*` / `*_API_KEY` / `HOME` changes
    /// leak into another's assertions.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Every variable these tests set, or the code under test consults.
    const MANAGED_VARS: &[&str] = &[
        "HOME",
        "REFLEX_PROVIDER",
        "REFLEX_MODEL",
        "REFLEX_AI_API_KEY",
        "OPENAI_API_KEY",
        "ANTHROPIC_API_KEY",
        "OPENROUTER_API_KEY",
    ];

    /// RAII guard for one test's environment. While alive it:
    /// - holds [`ENV_LOCK`];
    /// - starts from a clean slate, with every [`MANAGED_VARS`] entry cleared
    ///   and `HOME` pointed at the given directory.
    ///
    /// On drop it restores the original values, even if the test panics.
    ///
    /// NOTE(review): isolation relies on `dirs::home_dir()` honouring `HOME`;
    /// confirm that holds on every platform CI runs on (notably Windows).
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn new(home: &Path) -> Self {
            // A failed test poisons the lock; the guarded data is `()`, so
            // continuing is safe.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            let saved = MANAGED_VARS
                .iter()
                .map(|&name| (name, env::var_os(name)))
                .collect();
            let guard = Self { saved, _lock: lock };
            for name in MANAGED_VARS {
                guard.remove(name);
            }
            guard.set("HOME", home);
            guard
        }

        fn set(&self, name: &str, value: impl AsRef<OsStr>) {
            // SAFETY: ENV_LOCK (held by `self`) serializes environment access
            // among this module's tests. NOTE(review): tests elsewhere in the
            // crate that touch the environment are not covered by this lock.
            unsafe { env::set_var(name, value) };
        }

        fn remove(&self, name: &str) {
            // SAFETY: same invariant as `set`.
            unsafe { env::remove_var(name) };
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so restoration is also serialized.
            for (name, value) in &self.saved {
                match value {
                    Some(value) => self.set(name, value),
                    None => self.remove(name),
                }
            }
        }
    }

    /// Writes `contents` to `<home>/.reflex/config.toml`.
    fn write_user_config(home: &Path, contents: &str) {
        let reflex_dir = home.join(".reflex");
        std::fs::create_dir_all(&reflex_dir).unwrap();
        std::fs::write(reflex_dir.join("config.toml"), contents).unwrap();
    }

    #[test]
    fn test_default_config() {
        let config = SemanticConfig::default();
        assert!(config.enabled);
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, None);
        assert!(!config.auto_execute);
    }

    #[test]
    fn test_load_config_no_file() {
        let temp = TempDir::new().unwrap();
        // Named binding (not `_`) so the guard lives until the end of the test.
        let _env = EnvGuard::new(temp.path());

        let config = load_config(temp.path()).unwrap();

        // Should return defaults
        assert_eq!(config.provider, "openai");
        assert!(config.enabled);
    }

    #[test]
    fn test_load_config_with_semantic_section() {
        let temp = TempDir::new().unwrap();
        write_user_config(
            temp.path(),
            r#"
[semantic]
enabled = true
provider = "anthropic"
model = "claude-3-5-sonnet-20241022"
auto_execute = true
            "#,
        );
        let _env = EnvGuard::new(temp.path());

        let config = load_config(temp.path()).unwrap();

        assert!(config.enabled);
        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, Some("claude-3-5-sonnet-20241022".to_string()));
        assert!(config.auto_execute);
    }

    #[test]
    fn test_load_config_without_semantic_section() {
        let temp = TempDir::new().unwrap();
        write_user_config(
            temp.path(),
            r#"
[index]
languages = []
            "#,
        );
        let _env = EnvGuard::new(temp.path());

        let config = load_config(temp.path()).unwrap();

        // Should return defaults
        assert_eq!(config.provider, "openai");
    }

    #[test]
    fn test_get_api_key_env_var() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("OPENAI_API_KEY", "test-key-123");

        let key = get_api_key("openai").unwrap();
        assert_eq!(key, "test-key-123");
    }

    #[test]
    fn test_get_api_key_missing() {
        let temp = TempDir::new().unwrap();
        // EnvGuard clears OPENROUTER_API_KEY and REFLEX_AI_API_KEY.
        let _env = EnvGuard::new(temp.path());

        let result = get_api_key("openrouter");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("OPENROUTER_API_KEY"));
    }

    #[test]
    fn test_get_api_key_unknown_provider() {
        let temp = TempDir::new().unwrap();
        // Isolation matters here: a REFLEX_AI_API_KEY set by another test or
        // the developer's shell would be returned for any provider.
        let _env = EnvGuard::new(temp.path());

        let result = get_api_key("unknown");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown provider"));
    }

    #[test]
    fn test_env_override_provider() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("REFLEX_PROVIDER", "openrouter");

        let config = load_config(temp.path()).unwrap();

        assert_eq!(config.provider, "openrouter");
    }

    #[test]
    fn test_env_override_model() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("REFLEX_MODEL", "google/gemini-2.5-flash");

        let config = load_config(temp.path()).unwrap();

        assert_eq!(config.model, Some("google/gemini-2.5-flash".to_string()));
        // Provider should remain the default since we didn't override it
        assert_eq!(config.provider, "openai");
    }

    #[test]
    fn test_get_api_key_generic_env_var() {
        let temp = TempDir::new().unwrap();
        let env = EnvGuard::new(temp.path());
        env.set("REFLEX_AI_API_KEY", "generic-key-456");

        let key = get_api_key("openrouter").unwrap();
        assert_eq!(key, "generic-key-456");
    }
}