Skip to main content

llmkit/
builder.rs

//! [`LlmClientBuilder`] — compose a provider, fallbacks, and Tower layers.
2
3use std::sync::Arc;
4
5use llmkit_core::{
6    ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
7    LlmResult,
8};
9use llmkit_tower::{FallbackProvider, LlmLayer, SessionCost};
10
11use crate::alias::ModelAliases;
12use crate::tool_loop::ChatBuilder;
13
14/// A boxed layer application: wraps an inner provider, returning a new one.
15type BoxedLayer = Box<dyn FnOnce(Arc<dyn LlmProvider>) -> Arc<dyn LlmProvider>>;
16
17/// Builds an [`LlmClient`] from a primary provider, optional fallbacks, and a
18/// stack of Tower layers.
19///
20/// Layers are applied outermost-first, matching the order they are added: the
21/// first `.layer(...)` call sees the request first and the response last.
22#[derive(Default)]
23pub struct LlmClientBuilder {
24    providers: Vec<Arc<dyn LlmProvider>>,
25    layers: Vec<BoxedLayer>,
26    aliases: ModelAliases,
27    session_cost: Option<SessionCost>,
28}
29
30impl LlmClientBuilder {
31    /// Start a new builder with default model aliases.
32    pub fn new() -> Self {
33        Self { aliases: ModelAliases::with_defaults(), ..Default::default() }
34    }
35
36    /// Set the primary provider.
37    pub fn provider<P: LlmProvider>(mut self, provider: P) -> Self {
38        self.providers.insert(0, Arc::new(provider));
39        self
40    }
41
42    /// Add a fallback provider, tried in registration order after the primary.
43    pub fn fallback<P: LlmProvider>(mut self, provider: P) -> Self {
44        self.providers.push(Arc::new(provider));
45        self
46    }
47
48    /// Add a Tower layer. Outermost layers are added first.
49    pub fn layer<L>(mut self, layer: L) -> Self
50    where
51        L: LlmLayer + 'static,
52        L::Provider: 'static,
53    {
54        self.layers
55            .push(Box::new(move |inner| Arc::new(layer.layer(inner)) as Arc<dyn LlmProvider>));
56        self
57    }
58
59    /// Register a model alias (e.g. `"fast"` → `"gpt-4o-mini"`).
60    pub fn alias(mut self, alias: impl Into<String>, model: impl Into<String>) -> Self {
61        self.aliases.set(alias, model);
62        self
63    }
64
65    /// Track the session-cost handle from a `CostTrackingLayer` so the client
66    /// can report `session_cost_usd()`.
67    pub fn track_cost(mut self, handle: SessionCost) -> Self {
68        self.session_cost = Some(handle);
69        self
70    }
71
72    /// Finish building.
73    pub fn build(self) -> LlmResult<LlmClient> {
74        if self.providers.is_empty() {
75            return Err(LlmError::invalid("LlmClientBuilder requires a provider"));
76        }
77
78        // Base provider: a single provider, or a fallback chain.
79        let base: Arc<dyn LlmProvider> = if self.providers.len() == 1 {
80            self.providers.into_iter().next().unwrap()
81        } else {
82            Arc::new(FallbackProvider::new(self.providers))
83        };
84
85        // Apply layers so the first-added layer is outermost. We fold the layers
86        // in reverse: the last-added wraps the base first, the first-added wraps
87        // last and therefore sits on the outside.
88        let mut provider = base;
89        for layer in self.layers.into_iter().rev() {
90            provider = layer(provider);
91        }
92
93        Ok(LlmClient { provider, aliases: self.aliases, session_cost: self.session_cost })
94    }
95}
96
97/// A composed, ready-to-use LLM client.
98///
99/// Wraps the layered provider stack and applies model-alias resolution to each
100/// request. Clone is cheap (`Arc` internally).
101#[derive(Clone)]
102pub struct LlmClient {
103    provider: Arc<dyn LlmProvider>,
104    aliases: ModelAliases,
105    session_cost: Option<SessionCost>,
106}
107
108impl LlmClient {
109    fn resolve(&self, mut req: ChatRequest) -> ChatRequest {
110        if let Some(model) = &req.model {
111            req.model = Some(self.aliases.resolve(model).to_string());
112        }
113        req
114    }
115
116    /// Begin a chat request, optionally attaching tools via
117    /// [`ChatBuilder::with_tool`]. Await the result to run it.
118    pub fn chat(&self, req: ChatRequest) -> ChatBuilder {
119        ChatBuilder::new(self.provider.clone(), self.resolve(req))
120    }
121
122    /// Single-shot chat with no tool loop.
123    pub async fn chat_once(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
124        self.provider.chat(self.resolve(req)).await
125    }
126
127    /// Streaming chat.
128    pub async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
129        self.provider.chat_stream(self.resolve(req)).await
130    }
131
132    /// Generate embeddings.
133    pub async fn embed(&self, mut req: EmbedRequest) -> LlmResult<EmbedResponse> {
134        if let Some(model) = &req.model {
135            req.model = Some(self.aliases.resolve(model).to_string());
136        }
137        self.provider.embed(req).await
138    }
139
140    /// Cumulative session cost in USD, if a `CostTrackingLayer` handle was
141    /// registered via [`LlmClientBuilder::track_cost`].
142    pub fn session_cost_usd(&self) -> f64 {
143        self.session_cost.as_ref().map(SessionCost::total_usd).unwrap_or(0.0)
144    }
145
146    /// The underlying composed provider.
147    pub fn provider(&self) -> &Arc<dyn LlmProvider> {
148        &self.provider
149    }
150}