//! Provider configuration for the `baml_agent` crate (`config.rs`).

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

5#[derive(Debug, Error)]
6pub enum AgentConfigError {
7    #[error("Missing env var: {0}")]
8    MissingEnvVar(String),
9    #[error("Provider not found: {0}")]
10    ProviderNotFound(String),
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ProviderConfig {
15    pub provider_type: String, // "vertex-ai", "google-ai", "openai-generic"
16    pub model: String,
17    pub api_key_env_var: Option<String>,
18    pub base_url: Option<String>,
19    pub location: Option<String>,
20    pub project_id: Option<String>,
21}
22
/// Top-level agent configuration: a set of named providers plus the name
/// of the one to use when the caller does not pick explicitly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Key into `providers` naming the default provider.
    pub default_provider: String,
    /// Provider configurations, keyed by a caller-chosen name.
    pub providers: HashMap<String, ProviderConfig>,
}

29impl AgentConfig {
30    /// Create config with Vertex AI defaults from environment.
31    ///
32    /// Reads `GOOGLE_CLOUD_PROJECT` env var. Returns error if missing.
33    pub fn vertex_from_env() -> Result<Self, AgentConfigError> {
34        let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
35            .ok()
36            .filter(|v| !v.trim().is_empty())
37            .ok_or_else(|| AgentConfigError::MissingEnvVar("GOOGLE_CLOUD_PROJECT".into()))?;
38
39        let mut providers = HashMap::new();
40
41        providers.insert("vertex".into(), ProviderConfig {
42            provider_type: "vertex-ai".into(),
43            model: "gemini-3.1-flash-lite-preview".into(),
44            api_key_env_var: None,
45            base_url: None,
46            location: Some("global".into()),
47            project_id: Some(project_id.clone()),
48        });
49
50        providers.insert("vertex_fallback".into(), ProviderConfig {
51            provider_type: "vertex-ai".into(),
52            model: "gemini-3-flash-preview".into(),
53            api_key_env_var: None,
54            base_url: None,
55            location: Some("global".into()),
56            project_id: Some(project_id),
57        });
58
59        providers.insert("local".into(), ProviderConfig {
60            provider_type: "openai-generic".into(),
61            model: "llama3.2".into(),
62            api_key_env_var: None,
63            base_url: Some("http://localhost:11434/v1".into()),
64            location: None,
65            project_id: None,
66        });
67
68        Ok(Self {
69            default_provider: "vertex".into(),
70            providers,
71        })
72    }
73
74    /// Add or replace a provider.
75    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
76        self.providers.insert(name.into(), config);
77    }
78
79    /// Set Vertex project_id on all vertex-ai providers.
80    pub fn set_vertex_project(&mut self, project_id: &str) {
81        for p in self.providers.values_mut() {
82            if p.provider_type == "vertex-ai" {
83                p.project_id = Some(project_id.into());
84                if p.location.is_none() {
85                    p.location = Some("global".into());
86                }
87            }
88        }
89    }
90}