1use std::sync::Arc;
4
5use llmkit_core::{
6 ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
7 LlmResult,
8};
9use llmkit_tower::{FallbackProvider, LlmLayer, SessionCost};
10
11use crate::alias::ModelAliases;
12use crate::tool_loop::ChatBuilder;
13
14type BoxedLayer = Box<dyn FnOnce(Arc<dyn LlmProvider>) -> Arc<dyn LlmProvider>>;
16
17#[derive(Default)]
23pub struct LlmClientBuilder {
24 providers: Vec<Arc<dyn LlmProvider>>,
25 layers: Vec<BoxedLayer>,
26 aliases: ModelAliases,
27 session_cost: Option<SessionCost>,
28}
29
30impl LlmClientBuilder {
31 pub fn new() -> Self {
33 Self { aliases: ModelAliases::with_defaults(), ..Default::default() }
34 }
35
36 pub fn provider<P: LlmProvider>(mut self, provider: P) -> Self {
38 self.providers.insert(0, Arc::new(provider));
39 self
40 }
41
42 pub fn fallback<P: LlmProvider>(mut self, provider: P) -> Self {
44 self.providers.push(Arc::new(provider));
45 self
46 }
47
48 pub fn layer<L>(mut self, layer: L) -> Self
50 where
51 L: LlmLayer + 'static,
52 L::Provider: 'static,
53 {
54 self.layers
55 .push(Box::new(move |inner| Arc::new(layer.layer(inner)) as Arc<dyn LlmProvider>));
56 self
57 }
58
59 pub fn alias(mut self, alias: impl Into<String>, model: impl Into<String>) -> Self {
61 self.aliases.set(alias, model);
62 self
63 }
64
65 pub fn track_cost(mut self, handle: SessionCost) -> Self {
68 self.session_cost = Some(handle);
69 self
70 }
71
72 pub fn build(self) -> LlmResult<LlmClient> {
74 if self.providers.is_empty() {
75 return Err(LlmError::invalid("LlmClientBuilder requires a provider"));
76 }
77
78 let base: Arc<dyn LlmProvider> = if self.providers.len() == 1 {
80 self.providers.into_iter().next().unwrap()
81 } else {
82 Arc::new(FallbackProvider::new(self.providers))
83 };
84
85 let mut provider = base;
89 for layer in self.layers.into_iter().rev() {
90 provider = layer(provider);
91 }
92
93 Ok(LlmClient { provider, aliases: self.aliases, session_cost: self.session_cost })
94 }
95}
96
97#[derive(Clone)]
102pub struct LlmClient {
103 provider: Arc<dyn LlmProvider>,
104 aliases: ModelAliases,
105 session_cost: Option<SessionCost>,
106}
107
108impl LlmClient {
109 fn resolve(&self, mut req: ChatRequest) -> ChatRequest {
110 if let Some(model) = &req.model {
111 req.model = Some(self.aliases.resolve(model).to_string());
112 }
113 req
114 }
115
116 pub fn chat(&self, req: ChatRequest) -> ChatBuilder {
119 ChatBuilder::new(self.provider.clone(), self.resolve(req))
120 }
121
122 pub async fn chat_once(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
124 self.provider.chat(self.resolve(req)).await
125 }
126
127 pub async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
129 self.provider.chat_stream(self.resolve(req)).await
130 }
131
132 pub async fn embed(&self, mut req: EmbedRequest) -> LlmResult<EmbedResponse> {
134 if let Some(model) = &req.model {
135 req.model = Some(self.aliases.resolve(model).to_string());
136 }
137 self.provider.embed(req).await
138 }
139
140 pub fn session_cost_usd(&self) -> f64 {
143 self.session_cost.as_ref().map(SessionCost::total_usd).unwrap_or(0.0)
144 }
145
146 pub fn provider(&self) -> &Arc<dyn LlmProvider> {
148 &self.provider
149 }
150}