// baml_agent/engine.rs

use crate::config::AgentConfig;
use std::collections::HashMap;
3
4/// Trait abstracting BAML's generated `ClientRegistry`.
5///
6/// Each BAML project generates its own `ClientRegistry` type, but the API
7/// is identical. Implement this trait on a newtype wrapper around your
8/// project's `ClientRegistry`.
9///
10/// ```ignore
11/// struct MyRegistry(baml_client::ClientRegistry);
12///
13/// impl BamlRegistry for MyRegistry {
14///     fn new() -> Self { Self(baml_client::ClientRegistry::new()) }
15///     fn add_llm_client(&mut self, name: &str, provider_type: &str, options: HashMap<String, serde_json::Value>) {
16///         self.0.add_llm_client(name, provider_type, options);
17///     }
18///     fn set_primary_client(&mut self, name: &str) { self.0.set_primary_client(name); }
19///     fn into_inner(self) -> baml_client::ClientRegistry { self.0 }
20/// }
21/// ```
22pub trait BamlRegistry: Sized {
23    fn new() -> Self;
24    fn add_llm_client(
25        &mut self,
26        name: &str,
27        provider_type: &str,
28        options: HashMap<String, serde_json::Value>,
29    );
30    fn set_primary_client(&mut self, name: &str);
31}
32
33/// Generic engine that builds a `BamlRegistry` from `AgentConfig`.
34pub struct AgentEngine {
35    config: AgentConfig,
36}
37
38impl AgentEngine {
39    pub fn new(config: AgentConfig) -> Self {
40        Self { config }
41    }
42
43    /// Build a BAML ClientRegistry from the agent config.
44    ///
45    /// Iterates all providers, sets options (model, base_url, location,
46    /// project_id, api_key), and sets the primary client.
47    pub fn build_registry<R: BamlRegistry>(&self) -> Result<R, String> {
48        if !self.config.providers.contains_key(&self.config.default_provider) {
49            return Err(format!(
50                "default provider '{}' is not configured",
51                self.config.default_provider
52            ));
53        }
54
55        let mut registry = R::new();
56
57        for (name, conf) in &self.config.providers {
58            let mut options: HashMap<String, serde_json::Value> = HashMap::new();
59            options.insert("model".into(), serde_json::json!(conf.model));
60
61            if let Some(url) = &conf.base_url {
62                options.insert("base_url".into(), serde_json::json!(url));
63            }
64            if let Some(loc) = &conf.location {
65                options.insert("location".into(), serde_json::json!(loc));
66            }
67            if let Some(pid) = &conf.project_id {
68                options.insert("project_id".into(), serde_json::json!(pid));
69            }
70            if let Some(env_var) = &conf.api_key_env_var {
71                options.insert("api_key".into(), serde_json::json!(format!("env.{}", env_var)));
72            }
73
74            registry.add_llm_client(name, &conf.provider_type, options);
75        }
76
77        registry.set_primary_client(&self.config.default_provider);
78        Ok(registry)
79    }
80
81    pub fn config(&self) -> &AgentConfig {
82        &self.config
83    }
84}