Skip to main content

baml_agent/
engine.rs

1use crate::config::AgentConfig;
2use std::collections::HashMap;
3
4/// Trait abstracting BAML's generated `ClientRegistry`.
5///
6/// Each BAML project generates its own `ClientRegistry` type, but the API
7/// is identical. Implement this trait on a newtype wrapper around your
8/// project's `ClientRegistry`.
9///
10/// ```ignore
11/// struct MyRegistry(baml_client::ClientRegistry);
12///
13/// impl BamlRegistry for MyRegistry {
14///     fn new() -> Self { Self(baml_client::ClientRegistry::new()) }
15///     fn add_llm_client(&mut self, name: &str, provider_type: &str, options: HashMap<String, serde_json::Value>) {
16///         self.0.add_llm_client(name, provider_type, options);
17///     }
18///     fn set_primary_client(&mut self, name: &str) { self.0.set_primary_client(name); }
19///     fn into_inner(self) -> baml_client::ClientRegistry { self.0 }
20/// }
21/// ```
22pub trait BamlRegistry: Sized {
23    fn new() -> Self;
24    fn add_llm_client(
25        &mut self,
26        name: &str,
27        provider_type: &str,
28        options: HashMap<String, serde_json::Value>,
29    );
30    fn set_primary_client(&mut self, name: &str);
31}
32
33/// Generic engine that builds a `BamlRegistry` from `AgentConfig`.
34pub struct AgentEngine {
35    config: AgentConfig,
36}
37
38impl AgentEngine {
39    pub fn new(config: AgentConfig) -> Self {
40        Self { config }
41    }
42
43    /// Build a BAML ClientRegistry from the agent config.
44    ///
45    /// Iterates all providers, sets options (model, base_url, location,
46    /// project_id, api_key), and sets the primary client.
47    pub fn build_registry<R: BamlRegistry>(&self) -> Result<R, String> {
48        if !self
49            .config
50            .providers
51            .contains_key(&self.config.default_provider)
52        {
53            return Err(format!(
54                "default provider '{}' is not configured",
55                self.config.default_provider
56            ));
57        }
58
59        let mut registry = R::new();
60
61        for (name, conf) in &self.config.providers {
62            let mut options: HashMap<String, serde_json::Value> = HashMap::new();
63            options.insert("model".into(), serde_json::json!(conf.model));
64
65            if let Some(url) = &conf.base_url {
66                options.insert("base_url".into(), serde_json::json!(url));
67            }
68            if let Some(loc) = &conf.location {
69                options.insert("location".into(), serde_json::json!(loc));
70            }
71            if let Some(pid) = &conf.project_id {
72                options.insert("project_id".into(), serde_json::json!(pid));
73            }
74            if let Some(env_var) = &conf.api_key_env_var {
75                options.insert(
76                    "api_key".into(),
77                    serde_json::json!(format!("env.{}", env_var)),
78                );
79            }
80
81            registry.add_llm_client(name, &conf.provider_type, options);
82        }
83
84        registry.set_primary_client(&self.config.default_provider);
85        Ok(registry)
86    }
87
88    pub fn config(&self) -> &AgentConfig {
89        &self.config
90    }
91}