use std::collections::HashMap;
use serde_json::Value;
use crate::{
codec::BamlEncode,
proto::baml_cffi_v1::{host_map_entry, HostClientProperty, HostClientRegistry, HostMapEntry},
};
/// One LLM client definition held by a `ClientRegistry`.
#[derive(Clone, Debug)]
struct ClientProperty {
    /// Name the client is registered under.
    name: String,
    /// Provider string, forwarded unchanged when encoding.
    provider: String,
    /// Optional retry policy, forwarded unchanged when encoding.
    retry_policy: Option<String>,
    /// Client options; each entry is encoded as a string-keyed map entry.
    options: HashMap<String, Value>,
}
/// A collection of LLM client definitions plus an optional primary client name.
///
/// `Default` yields an empty registry: no primary client and no clients.
#[derive(Default, Clone, Debug)]
pub struct ClientRegistry {
    /// Name of the client to use as primary, if one has been set.
    primary: Option<String>,
    /// Registered client definitions, kept in insertion order.
    clients: Vec<ClientProperty>,
}
impl ClientRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn add_llm_client(
&mut self,
name: impl Into<String>,
provider: impl Into<String>,
options: HashMap<String, Value>,
) {
self.clients.push(ClientProperty {
name: name.into(),
provider: provider.into(),
retry_policy: None,
options,
});
}
pub fn set_primary_client(&mut self, name: impl Into<String>) {
self.primary = Some(name.into());
}
pub fn is_empty(&self) -> bool {
self.primary.is_none() && self.clients.is_empty()
}
pub(crate) fn encode(&self) -> HostClientRegistry {
let clients = self
.clients
.iter()
.map(|c| {
let options = c
.options
.iter()
.map(|(k, v)| HostMapEntry {
key: Some(host_map_entry::Key::StringKey(k.clone())),
value: Some(v.baml_encode()), })
.collect();
HostClientProperty {
name: c.name.clone(),
provider: c.provider.clone(),
retry_policy: c.retry_policy.clone(),
options,
}
})
.collect();
HostClientRegistry {
primary: self.primary.clone(),
clients,
}
}
}