1use serde::{Deserialize, Serialize};
2use std::path::Path;
3use tokio::fs;
4use crate::{Error, Result};
5
6#[derive(Debug, Serialize, Deserialize, Clone)]
7pub struct ModelConfig {
8 pub model_name: String,
9 pub provider: String,
10 pub api_base: String,
11}
12
13#[derive(Debug, Serialize, Deserialize, Clone)]
14pub struct DefaultConfig {
15 pub provider: String,
16 pub api_base: String,
17}
18
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct TelemetryWebhookConfig {
21 pub url: String,
22 pub headers: std::collections::HashMap<String, String>,
23}
24
25#[derive(Debug, Serialize, Deserialize, Clone)]
26pub struct Config {
27 pub models: Option<Vec<ModelConfig>>,
28 pub default: Option<DefaultConfig>,
29 pub telemetry_webhook: Option<TelemetryWebhookConfig>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ResolvedModelConfig {
34 pub provider: String,
35 pub api_base: String,
36 pub model_name: String,
37}
38
39pub struct ConfigManager {
40 config: Option<Config>,
41 config_path: String,
42}
43
44impl ConfigManager {
45 pub fn new() -> Self {
46 let config_path = Self::find_config_file();
47 Self {
48 config: None,
49 config_path,
50 }
51 }
52
53 pub fn new_with_path(config_path: Option<String>) -> Self {
54 let config_path = config_path.unwrap_or_else(|| Self::find_config_file());
55 Self {
56 config: None,
57 config_path,
58 }
59 }
60
61 fn find_config_file() -> String {
62 let possible_paths = vec![
64 std::env::var("SUPERAGENT_CONFIG").ok(),
66 Some("superagent.yaml".to_string()),
68 Some("../superagent.yaml".to_string()),
70 dirs::home_dir().map(|home| home.join(".superagent").join("superagent.yaml").to_string_lossy().to_string()),
72 Some("/etc/superagent/superagent.yaml".to_string()),
74 ];
75
76 for path_option in possible_paths {
77 if let Some(path) = path_option {
78 if Path::new(&path).exists() {
79 return path;
80 }
81 }
82 }
83
84 "superagent.yaml".to_string()
86 }
87
88 pub async fn load_config(&mut self) -> Result<()> {
89 if !Path::new(&self.config_path).exists() {
90 return Err(Error::Config(format!("Config file not found at {}", self.config_path)));
91 }
92
93 let config_data = fs::read_to_string(&self.config_path).await?;
94 self.config = Some(serde_yaml::from_str(&config_data)?);
95 Ok(())
96 }
97
98 pub fn get_model_config(&self, model_name: &str) -> ResolvedModelConfig {
99 if let Some(config) = &self.config {
100 if let Some(models) = &config.models {
102 if let Some(model_config) = models.iter().find(|m| m.model_name == model_name) {
103 return ResolvedModelConfig {
104 provider: model_config.provider.clone(),
105 api_base: model_config.api_base.clone(),
106 model_name: model_config.model_name.clone(),
107 };
108 }
109 }
110
111 if let Some(default) = &config.default {
113 return ResolvedModelConfig {
114 provider: default.provider.clone(),
115 api_base: default.api_base.clone(),
116 model_name: model_name.to_string(), };
118 }
119 }
120
121 ResolvedModelConfig {
123 provider: "anthropic".to_string(),
124 api_base: "https://api.anthropic.com/".to_string(),
125 model_name: model_name.to_string(),
126 }
127 }
128
129 pub fn get_api_base_for_model(&self, model_name: &str) -> String {
130 let config = self.get_model_config(model_name);
131 config.api_base
132 }
133
134 pub fn is_provider(&self, model_name: &str, provider: &str) -> bool {
135 let config = self.get_model_config(model_name);
136 config.provider == provider
137 }
138
139 pub fn get_telemetry_webhook_config(&self) -> Option<&TelemetryWebhookConfig> {
140 if let Some(config) = &self.config {
141 return config.telemetry_webhook.as_ref();
142 }
143 None
144 }
145}
146
147impl Default for ConfigManager {
148 fn default() -> Self {
149 Self::new()
150 }
151}