//! baml_agent/config.rs — provider and agent configuration types.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Errors produced while building or querying an [`AgentConfig`].
#[derive(Debug, Error)]
pub enum AgentConfigError {
    /// A required environment variable was unset or blank; payload is the
    /// variable name.
    #[error("Missing env var: {0}")]
    MissingEnvVar(String),
    /// No provider is registered under the requested name; payload is the
    /// name that was looked up.
    #[error("Provider not found: {0}")]
    ProviderNotFound(String),
}
12
/// Connection settings for a single LLM backend.
///
/// Which optional fields are meaningful depends on `provider_type`:
/// `vertex-ai` entries use `location`/`project_id`, while `openai-generic`
/// entries use `base_url` (see `AgentConfig::vertex_from_env`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    // Stringly-typed discriminator; known values used in this file are
    // "vertex-ai" and "openai-generic" ("google-ai" listed but not
    // constructed here).
    pub provider_type: String, // "vertex-ai", "google-ai", "openai-generic"
    // Model identifier passed to the backend.
    pub model: String,
    // Name of the env var holding the API key, if the backend needs one.
    pub api_key_env_var: Option<String>,
    // Endpoint override; used here for the local OpenAI-compatible server.
    pub base_url: Option<String>,
    // Vertex AI region/location (defaulted to "global" elsewhere in file).
    pub location: Option<String>,
    // GCP project id for Vertex AI providers.
    pub project_id: Option<String>,
}
22
/// Top-level agent configuration: a named set of providers plus the name of
/// the one to use by default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    // Key into `providers`; not validated to exist at construction time.
    pub default_provider: String,
    // Provider registry keyed by caller-chosen name (e.g. "vertex", "local").
    pub providers: HashMap<String, ProviderConfig>,
}
28
29impl AgentConfig {
30    /// Create config with Vertex AI defaults from environment.
31    ///
32    /// Reads `GOOGLE_CLOUD_PROJECT` env var. Returns error if missing.
33    pub fn vertex_from_env() -> Result<Self, AgentConfigError> {
34        let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
35            .ok()
36            .filter(|v| !v.trim().is_empty())
37            .ok_or_else(|| AgentConfigError::MissingEnvVar("GOOGLE_CLOUD_PROJECT".into()))?;
38
39        let mut providers = HashMap::new();
40
41        providers.insert(
42            "vertex".into(),
43            ProviderConfig {
44                provider_type: "vertex-ai".into(),
45                model: "gemini-3.1-flash-lite-preview".into(),
46                api_key_env_var: None,
47                base_url: None,
48                location: Some("global".into()),
49                project_id: Some(project_id.clone()),
50            },
51        );
52
53        providers.insert(
54            "vertex_fallback".into(),
55            ProviderConfig {
56                provider_type: "vertex-ai".into(),
57                model: "gemini-3-flash-preview".into(),
58                api_key_env_var: None,
59                base_url: None,
60                location: Some("global".into()),
61                project_id: Some(project_id),
62            },
63        );
64
65        providers.insert(
66            "local".into(),
67            ProviderConfig {
68                provider_type: "openai-generic".into(),
69                model: "llama3.2".into(),
70                api_key_env_var: None,
71                base_url: Some("http://localhost:11434/v1".into()),
72                location: None,
73                project_id: None,
74            },
75        );
76
77        Ok(Self {
78            default_provider: "vertex".into(),
79            providers,
80        })
81    }
82
83    /// Add or replace a provider.
84    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
85        self.providers.insert(name.into(), config);
86    }
87
88    /// Set Vertex project_id on all vertex-ai providers.
89    pub fn set_vertex_project(&mut self, project_id: &str) {
90        for p in self.providers.values_mut() {
91            if p.provider_type == "vertex-ai" {
92                p.project_id = Some(project_id.into());
93                if p.location.is_none() {
94                    p.location = Some("global".into());
95                }
96            }
97        }
98    }
99}