llmkit-tower 0.1.0

Tower middleware (retry, rate limit, cost tracking, tracing) for llmkit-rs
Documentation
//! Cost-tracking layer: per-request cost plus a cumulative session total, with
//! an optional budget cap.

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use llmkit_core::{
    pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
    LlmError, LlmProvider, LlmResult,
};

use crate::layer::LlmLayer;

/// Shared, cloneable handle to the running session cost (USD).
///
/// Clones share the same underlying counter, so every clone observes every
/// cost recorded through any of them. The total is kept as an integer count of
/// nano-dollars (1e-9 USD) so that very cheap requests (e.g. short embedding
/// calls costing a fraction of a micro-dollar) still accumulate instead of
/// being rounded away to zero.
#[derive(Clone, Debug, Default)]
pub struct SessionCost(Arc<AtomicU64>);

impl SessionCost {
    /// Stored units per US dollar (nano-dollar resolution).
    const UNITS_PER_USD: f64 = 1_000_000_000.0;

    /// Record `usd` dollars of spend.
    ///
    /// Non-finite and non-positive amounts are ignored: an infinite cost would
    /// otherwise saturate the `as u64` cast to `u64::MAX`, and a negative or NaN
    /// cost carries no meaningful spend. The addition saturates at `u64::MAX`
    /// rather than wrapping, so the total can never decrease.
    fn add(&self, usd: f64) {
        if !usd.is_finite() || usd <= 0.0 {
            return;
        }
        let units = (usd * Self::UNITS_PER_USD).round() as u64;
        // `fetch_add` wraps on overflow; use a saturating update instead. The
        // closure always returns `Some`, so the result is always `Ok`.
        let _ = self.0.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |total| {
            Some(total.saturating_add(units))
        });
    }

    /// Total cost accumulated so far, in USD.
    pub fn total_usd(&self) -> f64 {
        self.0.load(Ordering::Relaxed) as f64 / Self::UNITS_PER_USD
    }
}

/// Layer that adds per-request and cumulative cost tracking to a provider.
///
/// A layer built with [`CostTrackingLayer::with_budget`] yields a provider that
/// refuses new requests once the session total reaches the cap; one built with
/// [`CostTrackingLayer::new`] (or `Default`) only records spend.
#[derive(Clone)]
pub struct CostTrackingLayer {
    budget_usd: Option<f64>,
    cost: SessionCost,
}

impl CostTrackingLayer {
    /// Shared constructor: a fresh, zeroed session total plus the given cap.
    fn with_optional_budget(budget_usd: Option<f64>) -> Self {
        CostTrackingLayer {
            budget_usd,
            cost: SessionCost::default(),
        }
    }

    /// Create a layer that records spend without enforcing any cap.
    pub fn new() -> Self {
        Self::with_optional_budget(None)
    }

    /// Create a layer that records spend and rejects requests once the
    /// session total reaches `budget_usd` dollars.
    pub fn with_budget(budget_usd: f64) -> Self {
        Self::with_optional_budget(Some(budget_usd))
    }

    /// Return a handle sharing this layer's running session total.
    pub fn session_cost(&self) -> SessionCost {
        SessionCost::clone(&self.cost)
    }
}

impl Default for CostTrackingLayer {
    /// Same as [`CostTrackingLayer::new`]: no spending cap.
    fn default() -> Self {
        CostTrackingLayer::new()
    }
}

impl LlmLayer for CostTrackingLayer {
    type Provider = CostTracking;

    /// Wrap `inner`, handing over this layer's budget and its cost handle.
    fn layer(self, inner: Arc<dyn LlmProvider>) -> CostTracking {
        let Self { budget_usd, cost } = self;
        CostTracking {
            inner,
            cost,
            budget_usd,
        }
    }
}

/// Provider produced by [`CostTrackingLayer`].
///
/// Forwards every call to `inner`. When a price is available, the cost of
/// each completed `chat`/`embed` call is added to a shared [`SessionCost`].
/// If a budget is set, new requests are refused once the session total has
/// reached it.
pub struct CostTracking {
    /// The wrapped provider that actually serves requests.
    inner: Arc<dyn LlmProvider>,
    /// Spending cap in USD; `None` means unlimited.
    budget_usd: Option<f64>,
    /// Running session total, shared with handles returned by `session_cost`.
    cost: SessionCost,
}

impl CostTracking {
    /// Handle to read the running session cost.
    pub fn session_cost(&self) -> SessionCost {
        self.cost.clone()
    }

    /// Pre-flight budget check, run before a request is forwarded.
    ///
    /// # Errors
    /// Returns `LlmError::BudgetExceeded` when a budget is configured and the
    /// spend recorded so far is greater than or equal to it.
    ///
    /// Only spend that has already been recorded is considered. Concurrent
    /// requests can all pass this check before any of them is billed, so the
    /// cap may be overshot by the cost of requests in flight.
    fn check_budget(&self) -> LlmResult<()> {
        if let Some(budget) = self.budget_usd {
            let spent = self.cost.total_usd();
            if spent >= budget {
                // Print both amounts at the same precision. Otherwise a sub-cent
                // budget (e.g. $0.005) is rounded into a contradictory message.
                return Err(LlmError::BudgetExceeded(format!(
                    "session cost ${spent:.4} reached budget ${budget:.4}"
                )));
            }
        }
        Ok(())
    }
}

#[async_trait]
impl LlmProvider for CostTracking {
    /// Forward a chat request and bill its cost to the session.
    ///
    /// The cost is taken from `resp.cost` when the inner provider filled it
    /// in. Otherwise it is derived from the model's pricing table and the
    /// reported usage, and the derived estimate is written back into
    /// `resp.cost`. Responses for models with no known pricing pass through
    /// unbilled.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        self.check_budget()?;
        let mut resp = self.inner.chat(req).await?;

        let cost = resp
            .cost
            .or_else(|| pricing::pricing_for(&resp.model).map(|p| p.cost_for(resp.usage)));
        if let Some(c) = cost {
            self.cost.add(c.total_usd());
            resp.cost = Some(c);
            tracing::info!(
                provider = resp.provider,
                model = %resp.model,
                request_usd = c.total_usd(),
                session_usd = self.cost.total_usd(),
                "cost tracked"
            );
        }
        Ok(resp)
    }

    /// Forward a streaming chat request after the budget check.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        self.check_budget()?;
        // Streaming usage arrives in the terminal Done event; we cannot bill it
        // here without consuming the stream, so the budget check is pre-flight.
        self.inner.chat_stream(req).await
    }

    /// Forward an embedding request and bill it using the model's pricing
    /// table. Models with no known pricing pass through unbilled.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        self.check_budget()?;
        let resp = self.inner.embed(req).await?;
        if let Some(p) = pricing::pricing_for(&resp.model) {
            let c = p.cost_for(resp.usage);
            self.cost.add(c.total_usd());
            // Trace like `chat` does, so embedding spend is visible in logs too.
            tracing::info!(
                provider = self.inner.name(),
                model = %resp.model,
                request_usd = c.total_usd(),
                session_usd = self.cost.total_usd(),
                "cost tracked"
            );
        }
        Ok(resp)
    }

    /// Name of the wrapped provider.
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Model of the wrapped provider.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Delegates cost estimation to the wrapped provider.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        self.inner.estimate_cost(req)
    }
}