use std::sync::Arc;
use llmkit_core::{
ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
LlmResult,
};
use llmkit_tower::{FallbackProvider, LlmLayer, SessionCost};
use crate::alias::ModelAliases;
use crate::tool_loop::ChatBuilder;
/// A deferred wrapping step: given the provider stack built so far, returns it
/// wrapped in one more layer. Stored type-erased so layers of different concrete
/// types can live in one `Vec`.
type BoxedLayer = Box<dyn FnOnce(Arc<dyn LlmProvider>) -> Arc<dyn LlmProvider>>;
/// Builder that assembles an [`LlmClient`] from providers, layers, model aliases
/// and an optional cost-tracking handle.
///
/// Note: the derived `Default` initialises `aliases` with `ModelAliases::default()`,
/// whereas `LlmClientBuilder::new()` uses `ModelAliases::with_defaults()`; the two
/// constructors may therefore start with different alias tables.
#[derive(Default)]
pub struct LlmClientBuilder {
// Providers in priority order; index 0 is the primary, the rest are fallbacks.
providers: Vec<Arc<dyn LlmProvider>>,
// Layers in registration order; `build` applies them in reverse.
layers: Vec<BoxedLayer>,
// Alias table copied into the built client for model-name rewriting.
aliases: ModelAliases,
// Optional cost handle passed through unchanged to the built client.
session_cost: Option<SessionCost>,
}
impl LlmClientBuilder {
    /// Creates an empty builder whose alias table is preloaded with
    /// `ModelAliases::with_defaults()`.
    #[must_use]
    pub fn new() -> Self {
        Self { aliases: ModelAliases::with_defaults(), ..Default::default() }
    }

    /// Registers `provider` as the primary provider.
    ///
    /// The provider is inserted at the front of the provider list, so any provider
    /// registered earlier — including a previous primary — is demoted behind it.
    #[must_use]
    pub fn provider<P: LlmProvider>(mut self, provider: P) -> Self {
        self.providers.insert(0, Arc::new(provider));
        self
    }

    /// Appends `provider` to the end of the provider list.
    ///
    /// When more than one provider is registered, [`Self::build`] wraps them all in
    /// a `FallbackProvider` in list order.
    #[must_use]
    pub fn fallback<P: LlmProvider>(mut self, provider: P) -> Self {
        self.providers.push(Arc::new(provider));
        self
    }

    /// Adds a middleware layer around the provider stack.
    ///
    /// Layers are applied in reverse registration order by [`Self::build`], so the
    /// first layer registered ends up outermost (the client calls into it first).
    #[must_use]
    pub fn layer<L>(mut self, layer: L) -> Self
    where
        L: LlmLayer + 'static,
        L::Provider: 'static,
    {
        // Defer construction: the inner provider does not exist until `build`.
        self.layers
            .push(Box::new(move |inner| Arc::new(layer.layer(inner)) as Arc<dyn LlmProvider>));
        self
    }

    /// Maps the model name `alias` to `model` in the client's alias table.
    #[must_use]
    pub fn alias(mut self, alias: impl Into<String>, model: impl Into<String>) -> Self {
        self.aliases.set(alias, model);
        self
    }

    /// Attaches a cost-tracking handle, readable later via
    /// `LlmClient::session_cost_usd`.
    #[must_use]
    pub fn track_cost(mut self, handle: SessionCost) -> Self {
        self.session_cost = Some(handle);
        self
    }

    /// Assembles the configured [`LlmClient`].
    ///
    /// A single provider is used directly; several are wrapped in a
    /// `FallbackProvider` in list order. Registered layers are then applied
    /// innermost-last-registered, outermost-first-registered.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::invalid` if no provider was registered.
    pub fn build(self) -> LlmResult<LlmClient> {
        let Self { mut providers, layers, aliases, session_cost } = self;

        let base: Arc<dyn LlmProvider> = match providers.len() {
            0 => return Err(LlmError::invalid("LlmClientBuilder requires a provider")),
            1 => providers.pop().expect("provider list has exactly one element"),
            _ => Arc::new(FallbackProvider::new(providers)),
        };

        // Reverse so the first-registered layer wraps everything else.
        let provider = layers.into_iter().rev().fold(base, |inner, wrap| wrap(inner));

        Ok(LlmClient { provider, aliases, session_cost })
    }
}
/// High-level client: resolves model aliases on each request and forwards it to
/// the provider stack assembled by `LlmClientBuilder::build`.
///
/// Cloning is shallow for the provider stack (an `Arc` refcount bump).
#[derive(Clone)]
pub struct LlmClient {
// Fully layered provider stack (possibly a `FallbackProvider` at the core).
provider: Arc<dyn LlmProvider>,
// Alias table applied to `model` fields before dispatch.
aliases: ModelAliases,
// Optional cost handle; `None` makes `session_cost_usd` report 0.0.
session_cost: Option<SessionCost>,
}
impl LlmClient {
    /// Maps an optional model name through the alias table; `None` stays `None`.
    ///
    /// Shared by chat and embedding paths so both resolve aliases identically.
    fn resolve_model(&self, model: Option<String>) -> Option<String> {
        model.map(|name| self.aliases.resolve(&name).to_string())
    }

    /// Returns `req` with its `model` field rewritten through the alias table.
    fn resolve(&self, mut req: ChatRequest) -> ChatRequest {
        req.model = self.resolve_model(req.model.take());
        req
    }

    /// Starts a [`ChatBuilder`] over this client's provider stack for the
    /// alias-resolved request.
    #[must_use]
    pub fn chat(&self, req: ChatRequest) -> ChatBuilder {
        ChatBuilder::new(Arc::clone(&self.provider), self.resolve(req))
    }

    /// Sends the alias-resolved request directly to the provider stack,
    /// bypassing [`ChatBuilder`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the provider stack.
    pub async fn chat_once(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        self.provider.chat(self.resolve(req)).await
    }

    /// Opens a streaming chat for the alias-resolved request.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the provider stack.
    pub async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        self.provider.chat_stream(self.resolve(req)).await
    }

    /// Requests embeddings, resolving the model alias the same way as chat calls.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the provider stack.
    pub async fn embed(&self, mut req: EmbedRequest) -> LlmResult<EmbedResponse> {
        req.model = self.resolve_model(req.model.take());
        self.provider.embed(req).await
    }

    /// Total cost reported by the attached `SessionCost` handle, or `0.0` when
    /// no handle was attached via `LlmClientBuilder::track_cost`.
    pub fn session_cost_usd(&self) -> f64 {
        self.session_cost.as_ref().map_or(0.0, SessionCost::total_usd)
    }

    /// The fully layered provider stack backing this client.
    pub fn provider(&self) -> &Arc<dyn LlmProvider> {
        &self.provider
    }
}