ai_proxy/
config.rs

1use serde::{Deserialize, Serialize};
2use std::path::Path;
3use tokio::fs;
4use crate::{Error, Result};
5
6#[derive(Debug, Serialize, Deserialize, Clone)]
7pub struct ModelConfig {
8    pub model_name: String,
9    pub provider: String,
10    pub api_base: String,
11}
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct DefaultConfig {
15    pub provider: String,
16    pub api_base: String,
17}
18
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct Config {
21    pub models: Option<Vec<ModelConfig>>,
22    pub default: Option<DefaultConfig>,
23}
24
/// Outcome of resolving a requested model name against the configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedModelConfig {
    /// Provider identifier the model resolved to.
    pub provider: String,
    /// API base URL the model resolved to.
    pub api_base: String,
    /// Model name; when resolved via a default or fallback this is the requested
    /// name unchanged.
    pub model_name: String,
}
31
32pub struct ConfigManager {
33    config: Option<Config>,
34    config_path: String,
35}
36
37impl ConfigManager {
38    pub fn new() -> Self {
39        let config_path = Self::find_config_file();
40        Self {
41            config: None,
42            config_path,
43        }
44    }
45
46    pub fn new_with_path(config_path: Option<String>) -> Self {
47        let config_path = config_path.unwrap_or_else(|| Self::find_config_file());
48        Self {
49            config: None,
50            config_path,
51        }
52    }
53
54    fn find_config_file() -> String {
55        // Try multiple locations in order of preference
56        let possible_paths = vec![
57            // 1. Environment variable
58            std::env::var("VIBEKIT_CONFIG").ok(),
59            // 2. Current working directory
60            Some("vibekit.yaml".to_string()),
61            // 3. Parent directory (for rust/ subdirectory setup)
62            Some("../vibekit.yaml".to_string()),
63            // 4. Home directory
64            dirs::home_dir().map(|home| home.join(".vibekit").join("vibekit.yaml").to_string_lossy().to_string()),
65            // 5. System config directory
66            Some("/etc/vibekit/vibekit.yaml".to_string()),
67        ];
68
69        for path_option in possible_paths {
70            if let Some(path) = path_option {
71                if Path::new(&path).exists() {
72                    return path;
73                }
74            }
75        }
76
77        // Default fallback
78        "vibekit.yaml".to_string()
79    }
80
81    pub async fn load_config(&mut self) -> Result<()> {
82        if !Path::new(&self.config_path).exists() {
83            return Err(Error::Config(format!("Config file not found at {}", self.config_path)));
84        }
85
86        let config_data = fs::read_to_string(&self.config_path).await?;
87        self.config = Some(serde_yaml::from_str(&config_data)?);
88        Ok(())
89    }
90
91    pub fn get_model_config(&self, model_name: &str) -> ResolvedModelConfig {
92        if let Some(config) = &self.config {
93            // Find model in config
94            if let Some(models) = &config.models {
95                if let Some(model_config) = models.iter().find(|m| m.model_name == model_name) {
96                    return ResolvedModelConfig {
97                        provider: model_config.provider.clone(),
98                        api_base: model_config.api_base.clone(),
99                        model_name: model_config.model_name.clone(),
100                    };
101                }
102            }
103
104            // Return default if model not found
105            if let Some(default) = &config.default {
106                return ResolvedModelConfig {
107                    provider: default.provider.clone(),
108                    api_base: default.api_base.clone(),
109                    model_name: model_name.to_string(), // Keep original model name
110                };
111            }
112        }
113
114        // Ultimate fallback
115        ResolvedModelConfig {
116            provider: "anthropic".to_string(),
117            api_base: "https://api.anthropic.com/".to_string(),
118            model_name: model_name.to_string(),
119        }
120    }
121
122    pub fn get_api_base_for_model(&self, model_name: &str) -> String {
123        let config = self.get_model_config(model_name);
124        config.api_base
125    }
126
127    pub fn is_provider(&self, model_name: &str, provider: &str) -> bool {
128        let config = self.get_model_config(model_name);
129        config.provider == provider
130    }
131}
132
133impl Default for ConfigManager {
134    fn default() -> Self {
135        Self::new()
136    }
137}