//! sayr_engine configuration (`config.rs`): typed app config loaded from TOML
//! with `AGNO_*` environment-variable overrides.

1use std::env;
2use std::fs;
3use std::path::Path;
4
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AgnoError, Result};
8
/// Network binding settings for the HTTP server.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServerConfig {
    /// Interface to bind, e.g. "0.0.0.0" or "127.0.0.1". Required in the file.
    pub host: String,
    /// TCP port to listen on. Required in the file.
    pub port: u16,
    /// TLS toggle; defaults to `false` when the key is absent from the config.
    #[serde(default = "default_tls")]
    pub tls_enabled: bool,
}
16
/// Serde default for `ServerConfig::tls_enabled`: TLS stays off unless configured.
fn default_tls() -> bool {
    false
}
20
/// Access-control settings. Every field has a serde default, so the whole
/// `[security]` table may be omitted from the config file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SecurityConfig {
    /// Allowed origins; empty by default. NOTE(review): presumably consumed by
    /// request filtering elsewhere — confirm at call sites.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Allowed tenant identifiers; empty by default.
    #[serde(default)]
    pub allowed_tenants: Vec<String>,
    /// Whether encryption is mandatory; defaults to `true` when omitted.
    #[serde(default = "default_encryption_required")]
    pub encryption_required: bool,
}
30
31impl Default for SecurityConfig {
32    fn default() -> Self {
33        Self {
34            allowed_origins: Vec::new(),
35            allowed_tenants: Vec::new(),
36            encryption_required: default_encryption_required(),
37        }
38    }
39}
40
/// Serde default for `SecurityConfig::encryption_required`: secure by default.
fn default_encryption_required() -> bool {
    true
}
44
/// Telemetry/export settings. Every field has a serde default, so the
/// `[telemetry]` table may be omitted entirely.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TelemetryConfig {
    /// Sampling fraction; defaults to 1.0 (sample everything).
    #[serde(default = "default_sample_rate")]
    pub sample_rate: f32,
    /// Optional export endpoint; `None` by default.
    #[serde(default)]
    pub endpoint: Option<String>,
    /// Retention window in hours; defaults to 72 (three days).
    #[serde(default = "default_retention_hours")]
    pub retention_hours: u32,
}
54
55impl Default for TelemetryConfig {
56    fn default() -> Self {
57        Self {
58            sample_rate: default_sample_rate(),
59            endpoint: None,
60            retention_hours: default_retention_hours(),
61        }
62    }
63}
64
/// Serde default for `TelemetryConfig::sample_rate`: sample every event.
fn default_sample_rate() -> f32 {
    1.0
}
68
/// Serde default for `TelemetryConfig::retention_hours`: three days.
fn default_retention_hours() -> u32 {
    24 * 3
}
72
/// Deployment/runtime sizing settings. Every field has a serde default, so
/// the `[deployment]` table may be omitted entirely.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct DeploymentConfig {
    /// Replica count; defaults to 1.
    #[serde(default = "default_replicas")]
    pub replicas: u16,
    /// Maximum concurrent work units; defaults to 32.
    #[serde(default = "default_max_concurrency")]
    pub max_concurrency: u32,
    /// Autoscaling toggle; `false` by default.
    #[serde(default)]
    pub autoscale: bool,
    /// Optional container image reference; `None` by default.
    #[serde(default)]
    pub container_image: Option<String>,
}
84
85impl Default for DeploymentConfig {
86    fn default() -> Self {
87        Self {
88            replicas: default_replicas(),
89            max_concurrency: default_max_concurrency(),
90            autoscale: false,
91            container_image: None,
92        }
93    }
94}
95
/// Serde default for `DeploymentConfig::replicas`: a single instance.
fn default_replicas() -> u16 {
    1
}
99
/// Serde default for `DeploymentConfig::max_concurrency`.
fn default_max_concurrency() -> u32 {
    32
}
103
/// Model/provider selection plus per-provider credential overrides.
///
/// `provider` and `model` are required; everything else defaults. The generic
/// `api_key`/`base_url`/`organization` fields sit alongside the per-provider
/// sub-tables (`openai`, `anthropic`, …) — NOTE(review): precedence between
/// the generic and per-provider values is decided by consumers, not here.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelConfig {
    /// Provider identifier, e.g. "openai". Required in the file.
    pub provider: String,
    /// Model name, e.g. "gpt-4". Required in the file.
    pub model: String,
    /// Generic API key; `None` by default. Overridable via `AGNO_API_KEY`.
    #[serde(default)]
    pub api_key: Option<String>,
    /// Generic base URL override; `None` by default.
    #[serde(default)]
    pub base_url: Option<String>,
    /// Generic organization identifier; `None` by default.
    #[serde(default)]
    pub organization: Option<String>,
    /// Streaming toggle; `false` by default. Overridable via `AGNO_STREAMING`.
    #[serde(default)]
    pub stream: bool,
    /// OpenAI-specific overrides (`[model.openai]` table).
    #[serde(default)]
    pub openai: ProviderConfig,
    /// Anthropic-specific overrides (`[model.anthropic]` table).
    #[serde(default)]
    pub anthropic: ProviderConfig,
    /// Gemini-specific overrides (`[model.gemini]` table).
    #[serde(default)]
    pub gemini: ProviderConfig,
    /// Cohere-specific overrides (`[model.cohere]` table).
    #[serde(default)]
    pub cohere: ProviderConfig,
}
125
/// Per-provider credential/endpoint overrides; all fields optional.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ProviderConfig {
    /// Provider-specific API key; `None` by default.
    #[serde(default)]
    pub api_key: Option<String>,
    /// Provider-specific endpoint URL; `None` by default.
    #[serde(default)]
    pub endpoint: Option<String>,
    /// Provider-specific organization identifier; `None` by default.
    #[serde(default)]
    pub organization: Option<String>,
}
135
/// Persistence backend selector; serialized in lowercase ("file" / "sqlite").
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum StorageBackend {
    /// Append-to-file storage (see `StorageConfig::file_path`).
    File,
    /// SQLite storage (see `StorageConfig::database_url`).
    Sqlite,
}
142
143impl Default for StorageBackend {
144    fn default() -> Self {
145        StorageBackend::File
146    }
147}
148
/// Persistence settings. Every field has a serde default, so the `[storage]`
/// table may be omitted entirely.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct StorageConfig {
    /// Which backend to use; `File` by default.
    #[serde(default)]
    pub backend: StorageBackend,
    /// Path used by the file backend; defaults to "conversation.jsonl".
    #[serde(default = "default_storage_path")]
    pub file_path: String,
    /// Connection URL used by the sqlite backend; `None` by default.
    #[serde(default)]
    pub database_url: Option<String>,
}
158
159impl Default for StorageConfig {
160    fn default() -> Self {
161        Self {
162            backend: StorageBackend::default(),
163            file_path: default_storage_path(),
164            database_url: None,
165        }
166    }
167}
168
/// Serde default for `StorageConfig::file_path`.
fn default_storage_path() -> String {
    String::from("conversation.jsonl")
}
172
/// Top-level application configuration, deserialized from a TOML file.
///
/// Only `[server]` and `[model]` are required; every other section falls back
/// to its `Default`/serde defaults when omitted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AppConfig {
    /// Required `[server]` section.
    pub server: ServerConfig,
    /// Optional `[security]` section.
    #[serde(default)]
    pub security: SecurityConfig,
    /// Optional `[telemetry]` section.
    #[serde(default)]
    pub telemetry: TelemetryConfig,
    /// Optional `[deployment]` section.
    #[serde(default)]
    pub deployment: DeploymentConfig,
    /// Required `[model]` section.
    pub model: ModelConfig,
    /// Optional `[storage]` section.
    #[serde(default)]
    pub storage: StorageConfig,
}
186
187impl Default for AppConfig {
188    fn default() -> Self {
189        Self {
190            server: ServerConfig {
191                host: "0.0.0.0".into(),
192                port: 8080,
193                tls_enabled: default_tls(),
194            },
195            security: SecurityConfig {
196                allowed_origins: vec![],
197                allowed_tenants: vec![],
198                encryption_required: default_encryption_required(),
199            },
200            telemetry: TelemetryConfig {
201                sample_rate: default_sample_rate(),
202                endpoint: None,
203                retention_hours: default_retention_hours(),
204            },
205            deployment: DeploymentConfig {
206                replicas: default_replicas(),
207                max_concurrency: default_max_concurrency(),
208                autoscale: false,
209                container_image: None,
210            },
211            model: ModelConfig {
212                provider: "stub".into(),
213                model: "stub-model".into(),
214                api_key: None,
215                base_url: None,
216                organization: None,
217                stream: false,
218                openai: ProviderConfig::default(),
219                anthropic: ProviderConfig::default(),
220                gemini: ProviderConfig::default(),
221                cohere: ProviderConfig::default(),
222            },
223            storage: StorageConfig::default(),
224        }
225    }
226}
227
228impl AppConfig {
229    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
230        let raw = fs::read_to_string(path)?;
231        let cfg: Self = toml::from_str(&raw)
232            .map_err(|err| AgnoError::Protocol(format!("Failed to parse configuration: {err}")))?;
233        Ok(cfg)
234    }
235
236    pub fn from_env_or_file(path: impl AsRef<Path>) -> Result<Self> {
237        let mut cfg = Self::from_file(path)?;
238        if let Ok(host) = env::var("AGNO_HOST") {
239            cfg.server.host = host;
240        }
241        if let Ok(port) = env::var("AGNO_PORT") {
242            if let Ok(parsed) = port.parse::<u16>() {
243                cfg.server.port = parsed;
244            }
245        }
246        if let Ok(key) = env::var("AGNO_API_KEY") {
247            cfg.model.api_key = Some(key);
248        }
249        if let Ok(openai_key) = env::var("AGNO_OPENAI_API_KEY") {
250            cfg.model.openai.api_key = Some(openai_key);
251        }
252        if let Ok(openai_endpoint) = env::var("AGNO_OPENAI_ENDPOINT") {
253            cfg.model.openai.endpoint = Some(openai_endpoint);
254        }
255        if let Ok(openai_org) = env::var("AGNO_OPENAI_ORG") {
256            cfg.model.openai.organization = Some(openai_org);
257        }
258        if let Ok(anthropic_key) = env::var("AGNO_ANTHROPIC_API_KEY") {
259            cfg.model.anthropic.api_key = Some(anthropic_key);
260        }
261        if let Ok(anthropic_endpoint) = env::var("AGNO_ANTHROPIC_ENDPOINT") {
262            cfg.model.anthropic.endpoint = Some(anthropic_endpoint);
263        }
264        if let Ok(gemini_key) = env::var("AGNO_GEMINI_API_KEY") {
265            cfg.model.gemini.api_key = Some(gemini_key);
266        }
267        if let Ok(gemini_endpoint) = env::var("AGNO_GEMINI_ENDPOINT") {
268            cfg.model.gemini.endpoint = Some(gemini_endpoint);
269        }
270        if let Ok(cohere_key) = env::var("AGNO_COHERE_API_KEY") {
271            cfg.model.cohere.api_key = Some(cohere_key);
272        }
273        if let Ok(cohere_endpoint) = env::var("AGNO_COHERE_ENDPOINT") {
274            cfg.model.cohere.endpoint = Some(cohere_endpoint);
275        }
276        if let Ok(stream) = env::var("AGNO_STREAMING") {
277            if let Ok(parsed) = stream.parse::<bool>() {
278                cfg.model.stream = parsed;
279            }
280        }
281        if let Ok(sample) = env::var("AGNO_TELEMETRY_SAMPLE") {
282            if let Ok(parsed) = sample.parse::<f32>() {
283                cfg.telemetry.sample_rate = parsed.clamp(0.01, 1.0);
284            }
285        }
286        if let Ok(backend) = env::var("AGNO_STORAGE_BACKEND") {
287            cfg.storage.backend = match backend.to_ascii_lowercase().as_str() {
288                "sqlite" => StorageBackend::Sqlite,
289                _ => StorageBackend::File,
290            };
291        }
292        if let Ok(path) = env::var("AGNO_STORAGE_PATH") {
293            cfg.storage.file_path = path;
294        }
295        if let Ok(url) = env::var("AGNO_DATABASE_URL") {
296            cfg.storage.database_url = Some(url);
297        }
298        Ok(cfg)
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::io::Write;
    use tempfile::NamedTempFile;

    // NOTE(review): these tests mutate process-wide environment variables.
    // Cargo runs tests in parallel threads, so any future test touching the
    // same AGNO_* variables could race with these — consider serializing.

    /// File values load correctly and AGNO_PORT overrides the file's port.
    #[test]
    fn loads_and_overrides() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(
            file,
            "[server]\nhost='127.0.0.1'\nport=9000\n[model]\nprovider='openai'\nmodel='gpt-4'"
        )
        .unwrap();

        env::set_var("AGNO_PORT", "9100");
        let cfg = AppConfig::from_env_or_file(file.path()).unwrap();

        // Env override wins over the file's port=9000; host stays from file.
        assert_eq!(cfg.server.port, 9100);
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.model.provider, "openai");
        env::remove_var("AGNO_PORT");
    }

    /// Env vars can switch the storage backend and supply a database URL.
    #[test]
    fn overrides_storage_backend() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(
            file,
            "[server]\nhost='127.0.0.1'\nport=9000\n[model]\nprovider='openai'\nmodel='gpt-4'\n[storage]\nbackend='file'\nfile_path='transcript.jsonl'"
        )
        .unwrap();

        env::set_var("AGNO_STORAGE_BACKEND", "sqlite");
        env::set_var("AGNO_DATABASE_URL", "sqlite::memory:");
        let cfg = AppConfig::from_env_or_file(file.path()).unwrap();

        // Env override wins over the file's backend='file'.
        assert_eq!(cfg.storage.backend, StorageBackend::Sqlite);
        assert_eq!(
            cfg.storage.database_url,
            Some("sqlite::memory:".to_string())
        );

        env::remove_var("AGNO_STORAGE_BACKEND");
        env::remove_var("AGNO_DATABASE_URL");
    }
}
349}