llmkit-rs 0.1.0

Unified multi-provider async LLM client for Rust — OpenAI, Anthropic, Ollama, with Tower middleware
Documentation
//! [`LlmClientBuilder`] — compose a provider, fallbacks, and Tower layers.

use std::sync::Arc;

use llmkit_core::{
    ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
    LlmResult,
};
use llmkit_tower::{FallbackProvider, LlmLayer, SessionCost};

use crate::alias::ModelAliases;
use crate::tool_loop::ChatBuilder;

/// A boxed layer application: consumes the inner provider and returns the
/// provider that wraps it.
type BoxedLayer = Box<dyn FnOnce(Arc<dyn LlmProvider>) -> Arc<dyn LlmProvider>>;

/// Builds an [`LlmClient`] from a primary provider, optional fallbacks, and a
/// stack of Tower layers.
///
/// Layers are applied outermost-first, matching the order they are added: the
/// first `.layer(...)` call sees the request first and the response last.
///
/// Note: [`LlmClientBuilder::new`] seeds the default model aliases via
/// `ModelAliases::with_defaults()`, whereas the derived `Default` impl uses
/// `ModelAliases::default()`.
#[derive(Default)]
pub struct LlmClientBuilder {
    /// The primary provider; set (and replaced) by [`LlmClientBuilder::provider`].
    primary: Option<Arc<dyn LlmProvider>>,
    /// Fallback providers, kept in registration order after the primary.
    fallbacks: Vec<Arc<dyn LlmProvider>>,
    /// Layer constructors in registration order (first = outermost).
    layers: Vec<BoxedLayer>,
    /// Alias table handed to the built client.
    aliases: ModelAliases,
    /// Optional cost handle handed to the built client.
    session_cost: Option<SessionCost>,
}

impl LlmClientBuilder {
    /// Start a new builder with default model aliases.
    #[must_use]
    pub fn new() -> Self {
        Self { aliases: ModelAliases::with_defaults(), ..Default::default() }
    }

    /// Set the primary provider.
    ///
    /// Calling this again replaces the previously set primary; the old one is
    /// dropped rather than kept as a fallback. Use
    /// [`LlmClientBuilder::fallback`] to register additional providers.
    #[must_use]
    pub fn provider<P: LlmProvider>(mut self, provider: P) -> Self {
        let primary: Arc<dyn LlmProvider> = Arc::new(provider);
        self.primary = Some(primary);
        self
    }

    /// Add a fallback provider, tried in registration order after the primary.
    ///
    /// If no primary is ever set, the first registered fallback heads the chain.
    #[must_use]
    pub fn fallback<P: LlmProvider>(mut self, provider: P) -> Self {
        self.fallbacks.push(Arc::new(provider));
        self
    }

    /// Add a Tower layer. Outermost layers are added first.
    #[must_use]
    pub fn layer<L>(mut self, layer: L) -> Self
    where
        L: LlmLayer + 'static,
        L::Provider: 'static,
    {
        self.layers
            .push(Box::new(move |inner| Arc::new(layer.layer(inner)) as Arc<dyn LlmProvider>));
        self
    }

    /// Register a model alias (e.g. `"fast"` → `"gpt-4o-mini"`).
    #[must_use]
    pub fn alias(mut self, alias: impl Into<String>, model: impl Into<String>) -> Self {
        self.aliases.set(alias, model);
        self
    }

    /// Track the session-cost handle from a `CostTrackingLayer` so the client
    /// can report `session_cost_usd()`.
    #[must_use]
    pub fn track_cost(mut self, handle: SessionCost) -> Self {
        self.session_cost = Some(handle);
        self
    }

    /// Finish building.
    ///
    /// The primary (if any) followed by the fallbacks forms the base provider:
    /// exactly one provider is used directly, several are wrapped in a
    /// [`FallbackProvider`] in that order. Layers are then wrapped around the
    /// base so the first-added layer is outermost.
    ///
    /// # Errors
    ///
    /// Returns an error built with `LlmError::invalid` if neither a primary nor
    /// any fallback provider was registered.
    pub fn build(self) -> LlmResult<LlmClient> {
        let mut providers: Vec<Arc<dyn LlmProvider>> =
            self.primary.into_iter().chain(self.fallbacks).collect();

        // Base provider: a single provider, or a fallback chain.
        let base: Arc<dyn LlmProvider> = match providers.len() {
            0 => return Err(LlmError::invalid("LlmClientBuilder requires a provider")),
            1 => providers.pop().expect("length checked to be exactly one"),
            _ => Arc::new(FallbackProvider::new(providers)),
        };

        // Fold the layers in reverse: the last-added wraps the base first, the
        // first-added wraps last and therefore sits on the outside.
        let provider = self.layers.into_iter().rev().fold(base, |inner, wrap| wrap(inner));

        Ok(LlmClient { provider, aliases: self.aliases, session_cost: self.session_cost })
    }
}

/// A composed, ready-to-use LLM client.
///
/// Wraps the layered provider stack and applies model-alias resolution to each
/// request. Cloning shares the provider stack via `Arc`; the alias table and
/// cost handle are cloned through their own `Clone` impls.
#[derive(Clone)]
pub struct LlmClient {
    /// The fully composed provider stack produced by `LlmClientBuilder::build`.
    provider: Arc<dyn LlmProvider>,
    /// Alias table applied to the `model` field of every outgoing request.
    aliases: ModelAliases,
    /// Cost handle registered via `LlmClientBuilder::track_cost`, if any.
    session_cost: Option<SessionCost>,
}

impl LlmClient {
    /// Map an optional model name through the alias table.
    ///
    /// `None` is returned unchanged; `Some(name)` becomes the owned result of
    /// `ModelAliases::resolve`. Shared by chat and embedding requests so both
    /// resolve aliases identically.
    fn resolve_model(&self, model: Option<String>) -> Option<String> {
        model.map(|name| self.aliases.resolve(&name).to_string())
    }

    /// Return `req` with its model alias (if any) resolved.
    fn resolve(&self, mut req: ChatRequest) -> ChatRequest {
        req.model = self.resolve_model(req.model.take());
        req
    }

    /// Begin a chat request, optionally attaching tools via
    /// [`ChatBuilder::with_tool`]. Await the result to run it.
    #[must_use = "the returned ChatBuilder does nothing until it is awaited"]
    pub fn chat(&self, req: ChatRequest) -> ChatBuilder {
        ChatBuilder::new(self.provider.clone(), self.resolve(req))
    }

    /// Single-shot chat with no tool loop.
    pub async fn chat_once(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        self.provider.chat(self.resolve(req)).await
    }

    /// Streaming chat.
    pub async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        self.provider.chat_stream(self.resolve(req)).await
    }

    /// Generate embeddings, resolving the model alias first.
    pub async fn embed(&self, mut req: EmbedRequest) -> LlmResult<EmbedResponse> {
        req.model = self.resolve_model(req.model.take());
        self.provider.embed(req).await
    }

    /// Cumulative session cost in USD, if a `CostTrackingLayer` handle was
    /// registered via [`LlmClientBuilder::track_cost`]; `0.0` otherwise.
    pub fn session_cost_usd(&self) -> f64 {
        self.session_cost.as_ref().map(SessionCost::total_usd).unwrap_or(0.0)
    }

    /// The underlying composed provider.
    pub fn provider(&self) -> &Arc<dyn LlmProvider> {
        &self.provider
    }
}