ai_firewall/
config.rs

1use serde::{Deserialize, Serialize};
2use std::path::Path;
3use tokio::fs;
4use crate::{Error, Result};
5
6#[derive(Debug, Serialize, Deserialize, Clone)]
7pub struct ModelConfig {
8    pub model_name: String,
9    pub provider: String,
10    pub api_base: String,
11}
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct DefaultConfig {
15    pub provider: String,
16    pub api_base: String,
17}
18
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct TelemetryWebhookConfig {
21    pub url: String,
22    pub headers: std::collections::HashMap<String, String>,
23}
24
25#[derive(Debug, Serialize, Deserialize, Clone)]
26pub struct Config {
27    pub models: Option<Vec<ModelConfig>>,
28    pub default: Option<DefaultConfig>,
29    pub telemetry_webhook: Option<TelemetryWebhookConfig>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ResolvedModelConfig {
34    pub provider: String,
35    pub api_base: String,
36    pub model_name: String,
37}
38
39pub struct ConfigManager {
40    config: Option<Config>,
41    config_path: String,
42}
43
44impl ConfigManager {
45    pub fn new() -> Self {
46        let config_path = Self::find_config_file();
47        Self {
48            config: None,
49            config_path,
50        }
51    }
52
53    pub fn new_with_path(config_path: Option<String>) -> Self {
54        let config_path = config_path.unwrap_or_else(|| Self::find_config_file());
55        Self {
56            config: None,
57            config_path,
58        }
59    }
60
61    fn find_config_file() -> String {
62        // Try multiple locations in order of preference
63        let possible_paths = vec![
64            // 1. Environment variable
65            std::env::var("SUPERAGENT_CONFIG").ok(),
66            // 2. Current working directory
67            Some("superagent.yaml".to_string()),
68            // 3. Parent directory (for rust/ subdirectory setup)
69            Some("../superagent.yaml".to_string()),
70            // 4. Home directory
71            dirs::home_dir().map(|home| home.join(".superagent").join("superagent.yaml").to_string_lossy().to_string()),
72            // 5. System config directory
73            Some("/etc/superagent/superagent.yaml".to_string()),
74        ];
75
76        for path_option in possible_paths {
77            if let Some(path) = path_option {
78                if Path::new(&path).exists() {
79                    return path;
80                }
81            }
82        }
83
84        // Default fallback
85        "superagent.yaml".to_string()
86    }
87
88    pub async fn load_config(&mut self) -> Result<()> {
89        if !Path::new(&self.config_path).exists() {
90            return Err(Error::Config(format!("Config file not found at {}", self.config_path)));
91        }
92
93        let config_data = fs::read_to_string(&self.config_path).await?;
94        self.config = Some(serde_yaml::from_str(&config_data)?);
95        Ok(())
96    }
97
98    pub fn get_model_config(&self, model_name: &str) -> ResolvedModelConfig {
99        if let Some(config) = &self.config {
100            // Find model in config
101            if let Some(models) = &config.models {
102                if let Some(model_config) = models.iter().find(|m| m.model_name == model_name) {
103                    return ResolvedModelConfig {
104                        provider: model_config.provider.clone(),
105                        api_base: model_config.api_base.clone(),
106                        model_name: model_config.model_name.clone(),
107                    };
108                }
109            }
110
111            // Return default if model not found
112            if let Some(default) = &config.default {
113                return ResolvedModelConfig {
114                    provider: default.provider.clone(),
115                    api_base: default.api_base.clone(),
116                    model_name: model_name.to_string(), // Keep original model name
117                };
118            }
119        }
120
121        // Ultimate fallback
122        ResolvedModelConfig {
123            provider: "anthropic".to_string(),
124            api_base: "https://api.anthropic.com/".to_string(),
125            model_name: model_name.to_string(),
126        }
127    }
128
129    pub fn get_api_base_for_model(&self, model_name: &str) -> String {
130        let config = self.get_model_config(model_name);
131        config.api_base
132    }
133
134    pub fn is_provider(&self, model_name: &str, provider: &str) -> bool {
135        let config = self.get_model_config(model_name);
136        config.provider == provider
137    }
138
139    pub fn get_telemetry_webhook_config(&self) -> Option<&TelemetryWebhookConfig> {
140        if let Some(config) = &self.config {
141            return config.telemetry_webhook.as_ref();
142        }
143        None
144    }
145}
146
147impl Default for ConfigManager {
148    fn default() -> Self {
149        Self::new()
150    }
151}