// llmkit-tower 0.1.0
//
// Tower middleware (retry, rate limit, cost tracking, tracing) for llmkit-rs
// Documentation
//! Token-bucket rate limiting, per provider.

use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use llmkit_core::{
    ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmProvider,
    LlmResult,
};
use tokio::sync::Mutex;

use crate::layer::LlmLayer;

/// Configuration for a token-bucket rate limiter: a bucket of `capacity`
/// tokens that refills continuously at `refill_per_sec`.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitLayer {
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Refill rate, in tokens per second.
    refill_per_sec: f64,
}

impl RateLimitLayer {
    /// Build a limiter whose bucket holds `capacity` tokens and refills from
    /// empty to full over `window`.
    ///
    /// e.g. `token_bucket(60_000, Duration::from_secs(60))` ≈ 60k tokens/min.
    ///
    /// A zero-length `window` is replaced by the smallest positive normal
    /// `f64` number of seconds, so computing the rate never divides by zero.
    pub fn token_bucket(capacity: u64, window: Duration) -> Self {
        let capacity = capacity as f64;
        let window_secs = f64::max(window.as_secs_f64(), f64::MIN_POSITIVE);
        Self {
            capacity,
            refill_per_sec: capacity / window_secs,
        }
    }
}

impl LlmLayer for RateLimitLayer {
    type Provider = RateLimit;

    /// Wrap `inner` in a [`RateLimit`] backed by a freshly created bucket.
    ///
    /// Every call builds its own bucket, so providers produced by separate
    /// `layer` calls do not share a budget.
    fn layer(self, inner: Arc<dyn LlmProvider>) -> RateLimit {
        let Self { capacity, refill_per_sec } = self;
        // The bucket starts full: the initial balance equals `capacity`.
        let state = Bucket {
            tokens: capacity,
            capacity,
            refill_per_sec,
            last: Instant::now(),
        };
        RateLimit {
            inner,
            bucket: Arc::new(Mutex::new(state)),
        }
    }
}

/// Mutable token-bucket state.
///
/// The balance may be negative: a negative `tokens` value is debt owed by
/// callers that have already been handed a wait time and are sleeping until
/// the refill repays it.
struct Bucket {
    /// Current balance; never above `capacity`, negative while reservations
    /// are outstanding.
    tokens: f64,
    /// Upper bound on the balance, and on the cost of a single request.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_per_sec: f64,
    /// Instant at which the balance was last brought up to date.
    last: Instant,
}

impl Bucket {
    /// Credit the tokens earned since `last`, capping the balance at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last = now;
    }

    /// Reserve `cost` tokens and return how long the caller must wait before
    /// using them.
    ///
    /// `cost` is clamped to `0.0..=capacity` so an oversized request can still
    /// be served eventually. The tokens are always deducted immediately. When
    /// the balance cannot cover the request it goes negative, and the returned
    /// wait is the time the refill needs to bring it back to zero. Keeping the
    /// debt in the balance is what makes concurrent callers queue behind each
    /// other instead of all being handed the same wait and bursting together.
    ///
    /// Returns [`Duration::ZERO`] when the balance covered the request, and
    /// saturates to [`Duration::MAX`] when the wait is not representable
    /// (for example with a zero refill rate).
    fn time_until_available(&mut self, cost: f64) -> Duration {
        self.refill();
        let cost = cost.min(self.capacity).max(0.0);
        self.tokens -= cost;
        if self.tokens >= 0.0 {
            return Duration::ZERO;
        }
        // The debt left in `tokens` is repaid by future refills; this caller's
        // turn comes once everything reserved so far has been paid off.
        let wait = -self.tokens / self.refill_per_sec;
        // `from_secs_f64` would panic on an infinite or overflowing wait.
        Duration::try_from_secs_f64(wait).unwrap_or(Duration::MAX)
    }
}

/// Provider produced by [`RateLimitLayer`].
///
/// Each request first reserves its estimated token cost from the bucket and
/// is then forwarded to `inner`.
pub struct RateLimit {
    /// The wrapped provider that actually serves requests.
    inner: Arc<dyn LlmProvider>,
    /// Token-bucket state, guarded by an async mutex.
    bucket: Arc<Mutex<Bucket>>,
}

impl RateLimit {
    /// Estimate the token cost of a request for bucket accounting.
    ///
    /// The prompt is approximated as one token per four bytes of text (the
    /// text of every message plus the system prompt, integer division). The
    /// completion is charged at `max_tokens`, or 256 when it is unset.
    fn estimated_cost(req: &ChatRequest) -> f64 {
        let message_bytes: usize = req
            .messages
            .iter()
            .filter_map(|m| m.content.as_text())
            .map(|t| t.len())
            .sum();
        let system_bytes = req.system.as_deref().map_or(0, str::len);

        let prompt_tokens = ((message_bytes + system_bytes) / 4) as f64;
        let completion_tokens = req.max_tokens.unwrap_or(256) as f64;
        prompt_tokens + completion_tokens
    }

    /// Reserve `cost` tokens from the bucket, sleeping for as long as the
    /// bucket reports is necessary.
    async fn acquire(&self, cost: f64) {
        // The guard is a temporary of this statement, so the lock is released
        // before the sleep below and other callers are not blocked by it.
        let delay = self.bucket.lock().await.time_until_available(cost);
        if delay.is_zero() {
            return;
        }
        tokio::time::sleep(delay).await;
    }
}

#[async_trait]
impl LlmProvider for RateLimit {
    /// Reserve the request's estimated cost, then forward it to the inner provider.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let cost = Self::estimated_cost(&req);
        self.acquire(cost).await;
        self.inner.chat(req).await
    }

    /// Streaming variant of [`chat`](Self::chat); the cost is estimated the
    /// same way and reserved once, before the stream is opened.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let cost = Self::estimated_cost(&req);
        self.acquire(cost).await;
        self.inner.chat_stream(req).await
    }

    /// Reserve a cost of one token per four units of each input's `len()`
    /// (integer division per input), then forward the request.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        let cost: f64 = req
            .input
            .iter()
            .map(|text| (text.len() / 4) as f64)
            .sum();
        self.acquire(cost).await;
        self.inner.embed(req).await
    }

    /// Name reported by the wrapped provider.
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Model reported by the wrapped provider.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Cost estimate from the wrapped provider; the rate limiter adds nothing to it.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        self.inner.estimate_cost(req)
    }
}