klieo-core 0.3.0

Core traits + runtime for the klieo agent framework.
Documentation
//! Client-side token-bucket rate limiter for any [`LlmClient`].
//!
//! Wraps an existing [`LlmClient`] with per-instance request-per-minute (RPM)
//! and optional tokens-per-minute (TPM) gates. Both gates use the
//! token-bucket algorithm: capacity equal to the per-minute limit, refilled
//! at `limit / 60.0` units per second up to capacity.
//!
//! The wrapper itself implements [`LlmClient`] and forwards every call —
//! `name`, `capabilities`, `complete`, `stream`, `embed` — to the inner
//! client. RPM is consumed once per call; TPM (if enabled) is consumed by
//! `complete` and `stream` based on a coarse length-based token estimate
//! (`total message content.len() / 4 + max_tokens.unwrap_or(256)`; the
//! length is measured in bytes, not characters). `embed` is RPM-only:
//! providers price embeddings separately and the per-text token count is
//! out of scope for this primitive.
//!
//! ## Why client-side
//!
//! Server-side 429s already exist; this wrapper exists to keep us *under*
//! the published limit so we do not depend on retry-with-backoff alone for
//! steady-state throughput. The runtime's
//! [`crate::runtime::run_steps`] retry loop still handles incidental 429s.
//!
//! ## Example
//!
//! ```
//! # tokio_test::block_on(async {
//! use klieo_core::rate_limit::{RateLimitConfig, RateLimitedClient};
//! use klieo_core::test_utils::{FakeLlmClient, FakeLlmStep};
//! use klieo_core::{ChatRequest, LlmClient};
//! use std::sync::Arc;
//!
//! let inner: Arc<dyn LlmClient> = Arc::new(
//!     FakeLlmClient::new("ollama")
//!         .with_steps(vec![FakeLlmStep::Text("hi".into())]),
//! );
//! let limited = RateLimitedClient::new(
//!     inner,
//!     RateLimitConfig::new(60).with_tokens_per_minute(60_000),
//! );
//! let resp = limited.complete(ChatRequest::new(vec![])).await.unwrap();
//! assert_eq!(resp.message.content, "hi");
//! # });
//! ```

use crate::error::LlmError;
use crate::llm::{Capabilities, ChatRequest, ChatResponse, ChunkStream, Embedding, LlmClient};
use async_trait::async_trait;
use parking_lot::Mutex;
use std::sync::Arc;
use tokio::time::{Duration, Instant};

/// Limits enforced by a [`RateLimitedClient`].
///
/// Each limit is backed by a token bucket whose capacity equals the
/// per-minute value and which refills continuously at `limit / 60` units per
/// second. A limit of `0` is treated as `1`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests allowed per minute. Every `complete`, `stream` and `embed`
    /// call spends one unit of this budget.
    pub requests_per_minute: u32,
    /// Estimated tokens allowed per minute, or `None` to gate on request
    /// count alone. Set it for token-metered providers (OpenAI, Anthropic);
    /// leave it unset for providers without token quotas (e.g. Ollama).
    pub tokens_per_minute: Option<u32>,
}

impl RateLimitConfig {
    /// Create a config that enforces only `requests_per_minute`; the
    /// tokens-per-minute gate starts disabled.
    pub fn new(requests_per_minute: u32) -> Self {
        let tokens_per_minute = None;
        Self {
            requests_per_minute,
            tokens_per_minute,
        }
    }

    /// Enable the tokens-per-minute gate, replacing any value set earlier.
    pub fn with_tokens_per_minute(self, tokens_per_minute: u32) -> Self {
        Self {
            tokens_per_minute: Some(tokens_per_minute),
            ..self
        }
    }
}

/// A single token bucket: holds up to `capacity` units and regains
/// `capacity / 60` units per second, i.e. one full refill per minute.
#[derive(Debug)]
struct Bucket {
    /// Units currently available; never negative and never above `capacity`.
    tokens: f64,
    /// Maximum number of units the bucket can hold.
    capacity: f64,
    /// Units regained per second of elapsed time.
    refill_per_sec: f64,
    /// Instant up to which elapsed time has already been credited.
    last_refill: Instant,
}

impl Bucket {
    /// Shortest wait `try_consume` ever reports.
    ///
    /// After a refill, float rounding can leave a shortfall far below one
    /// nanosecond, which `Duration::from_secs_f64` turns into `Duration::ZERO`.
    /// Under a paused Tokio clock a zero-length sleep never advances time, so
    /// the caller's retry loop would spin without making progress. Tokio
    /// timers only resolve to 1 ms anyway, so this floor costs nothing.
    const MIN_WAIT: Duration = Duration::from_millis(1);

    /// Create a full bucket holding `capacity` units. A capacity of `0` is
    /// treated as `1` so the refill rate is always positive.
    fn new(capacity: u32) -> Self {
        let cap = f64::from(capacity.max(1));
        Self {
            tokens: cap,
            capacity: cap,
            refill_per_sec: cap / 60.0,
            last_refill: Instant::now(),
        }
    }

    /// Credit the time elapsed since `last_refill`, capping at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now
            .saturating_duration_since(self.last_refill)
            .as_secs_f64();
        if elapsed > 0.0 {
            let added = elapsed * self.refill_per_sec;
            self.tokens = (self.tokens + added).min(self.capacity);
            self.last_refill = now;
        }
    }

    /// Try to take `n` units from the bucket.
    ///
    /// Returns `Ok(())` when the units were consumed, or `Err(wait)` with how
    /// long to sleep before enough units are expected to be available (never
    /// less than [`Self::MIN_WAIT`]).
    ///
    /// Refill stops at `capacity`, so a request for more than `capacity`
    /// units could never be satisfied and the caller would retry forever.
    /// Such requests are clamped to `capacity`: they wait for a full bucket
    /// and then drain it.
    fn try_consume(&mut self, n: f64) -> Result<(), Duration> {
        self.refill();
        let n = n.min(self.capacity);
        if self.tokens >= n {
            self.tokens -= n;
            return Ok(());
        }
        let shortfall = n - self.tokens;
        let wait = Duration::from_secs_f64(shortfall / self.refill_per_sec);
        Err(wait.max(Self::MIN_WAIT))
    }
}

/// [`LlmClient`] decorator that admits each call only after it clears a
/// requests-per-minute token bucket and, when configured, a
/// tokens-per-minute token bucket.
///
/// Build one with [`RateLimitedClient::new`]; it can then stand in wherever
/// an `Arc<dyn LlmClient>` is expected.
///
/// Locking (W5.A24): each bucket sits behind a `parking_lot::Mutex` rather
/// than a `tokio::sync::Mutex`. The critical section is only the bucket
/// arithmetic in `try_consume`, so a synchronous lock avoids a trip through
/// the Tokio scheduler on every acquire. The trade-off is that a guard must
/// never be held across an `.await`: a blocking lock kept while the task is
/// suspended would stall every other caller contending for it. The acquire
/// paths therefore release the guard before they sleep.
pub struct RateLimitedClient {
    /// Wrapped client that receives each call once the gates admit it.
    inner: Arc<dyn LlmClient>,
    /// Requests-per-minute gate; one unit per call.
    rpm: Mutex<Bucket>,
    /// Optional tokens-per-minute gate; `None` disables token gating.
    tpm: Option<Mutex<Bucket>>,
}

impl RateLimitedClient {
    /// Wrap `inner` with the supplied [`RateLimitConfig`].
    pub fn new(inner: Arc<dyn LlmClient>, config: RateLimitConfig) -> Self {
        let rpm = Mutex::new(Bucket::new(config.requests_per_minute));
        let tpm = config.tokens_per_minute.map(|t| Mutex::new(Bucket::new(t)));
        Self { inner, rpm, tpm }
    }

    async fn acquire_rpm(&self) {
        loop {
            // parking_lot::Mutex::lock() is sync. Drop the guard
            // (returned by the inner block) before the `sleep` await
            // so the mutex isn't held across .await.
            let wait = {
                let mut b = self.rpm.lock();
                b.try_consume(1.0)
            };
            match wait {
                Ok(()) => return,
                Err(d) => tokio::time::sleep(d).await,
            }
        }
    }

    async fn acquire_tpm(&self, tokens: u32) {
        let Some(tpm) = &self.tpm else { return };
        let n = f64::from(tokens.max(1));
        loop {
            let wait = {
                let mut b = tpm.lock();
                b.try_consume(n)
            };
            match wait {
                Ok(()) => return,
                Err(d) => tokio::time::sleep(d).await,
            }
        }
    }
}

/// Rough, tokenizer-free estimate of how many tokens `req` will consume.
///
/// Prompt side: the summed `content.len()` of every message divided by 4,
/// applying the common "~4 characters per token" rule of thumb. For string
/// content `len()` counts UTF-8 bytes, not characters, so non-ASCII text
/// weighs more. Completion side: `max_tokens`, or 256 when unset. The
/// conversion to `u32` and the final addition both saturate at `u32::MAX`.
/// An exact count would need the provider's tokenizer, which is out of
/// scope here.
fn estimate_tokens(req: &ChatRequest) -> u32 {
    let prompt_len = req
        .messages
        .iter()
        .map(|message| message.content.len())
        .sum::<usize>();
    let prompt = u32::try_from(prompt_len / 4).unwrap_or(u32::MAX);
    let completion = req.max_tokens.unwrap_or(256);
    prompt.saturating_add(completion)
}

#[async_trait]
impl LlmClient for RateLimitedClient {
    /// Reports the wrapped client's name; the limiter is transparent.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Reports the wrapped client's capabilities unchanged.
    fn capabilities(&self) -> &Capabilities {
        self.inner.capabilities()
    }

    /// Clears the RPM gate, then the TPM gate with the request's estimated
    /// token cost, then forwards the request to the wrapped client.
    async fn complete(&self, req: ChatRequest) -> Result<ChatResponse, LlmError> {
        self.acquire_rpm().await;
        self.acquire_tpm(estimate_tokens(&req)).await;
        self.inner.complete(req).await
    }

    /// Same gating as `complete`, then opens the wrapped client's stream.
    async fn stream(&self, req: ChatRequest) -> Result<ChunkStream, LlmError> {
        self.acquire_rpm().await;
        self.acquire_tpm(estimate_tokens(&req)).await;
        self.inner.stream(req).await
    }

    /// Clears only the RPM gate before forwarding. Embeddings skip the TPM
    /// gate because they are priced separately and per-text token counts are
    /// out of scope for this primitive.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Embedding>, LlmError> {
        self.acquire_rpm().await;
        self.inner.embed(texts).await
    }
}

#[cfg(test)]
mod tests {
    //! Tests run with `start_paused = true`. Bucket internals use
    //! [`tokio::time::Instant`], so refills + sleeps both honour virtual
    //! time and the schedule (RPM=2 → 30s per refilled token) collapses
    //! instantly while still exercising the gate logic.
    use super::*;
    use crate::test_utils::{FakeLlmClient, FakeLlmStep};
    use crate::ChatRequest;
    use std::sync::Arc;

    /// Fake inner client scripted with six `"ok"` replies. Tests that make
    /// more calls than that ignore the inner result and check timing only.
    fn fake() -> Arc<dyn LlmClient> {
        Arc::new(FakeLlmClient::new("fake").with_steps(vec![
            FakeLlmStep::Text("ok".into()),
            FakeLlmStep::Text("ok".into()),
            FakeLlmStep::Text("ok".into()),
            FakeLlmStep::Text("ok".into()),
            FakeLlmStep::Text("ok".into()),
            FakeLlmStep::Text("ok".into()),
        ]))
    }

    #[tokio::test(start_paused = true)]
    async fn calls_within_rpm_capacity_do_not_block() {
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(60));
        let start = Instant::now();
        for _ in 0..5 {
            limited.complete(ChatRequest::new(vec![])).await.unwrap();
        }
        // No sleep should have been needed: bucket capacity 60 ≫ 5 calls.
        assert!(
            Instant::now().saturating_duration_since(start) < Duration::from_secs(1),
            "expected near-instant completion under capacity"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn blocks_when_rpm_bucket_drained_then_refills() {
        // Capacity = 2, refill = 2/60 per sec → one new token every 30s.
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(2));
        // Drain the bucket.
        limited.complete(ChatRequest::new(vec![])).await.unwrap();
        limited.complete(ChatRequest::new(vec![])).await.unwrap();
        let start = Instant::now();
        // Third call must wait ~30s of virtual time for one token.
        limited.complete(ChatRequest::new(vec![])).await.unwrap();
        let waited = Instant::now().saturating_duration_since(start);
        assert!(
            waited >= Duration::from_secs(29),
            "expected ≥29s wait, got {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn tpm_disabled_lets_large_requests_through() {
        // No TPM gate ⇒ huge max_tokens does not block.
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(60));
        let mut req = ChatRequest::new(vec![]);
        req.max_tokens = Some(100_000);
        let start = Instant::now();
        limited.complete(req).await.unwrap();
        assert!(Instant::now().saturating_duration_since(start) < Duration::from_secs(1));
    }

    #[tokio::test(start_paused = true)]
    async fn tpm_gates_independently_of_rpm() {
        // RPM is generous (1000/min), so it never blocks here; TPM is tight
        // (1000/min). With no messages and max_tokens = 600, each call is
        // estimated at 0 + 600 = 600 tokens.
        let limited = RateLimitedClient::new(
            fake(),
            RateLimitConfig::new(1000).with_tokens_per_minute(1000),
        );
        let mut req = ChatRequest::new(vec![]);
        req.max_tokens = Some(600);
        // First call: 600 of 1000 tokens consumed, 400 left.
        limited.complete(req.clone()).await.unwrap();
        // Second call needs 600 but only 400 are left, so it must wait for
        // the missing 200. Refill is 1000/60 ≈ 16.67 tokens/s ⇒ ≈ 12s; we
        // assert ≥ 11s as a lower bound.
        let start = Instant::now();
        limited.complete(req).await.unwrap();
        let waited = Instant::now().saturating_duration_since(start);
        assert!(
            waited >= Duration::from_secs(11),
            "expected TPM gate to delay ≥11s, got {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn concurrent_callers_serialize_through_bucket() {
        // Capacity = 2. Spawn 4 concurrent callers; first two run
        // immediately, last two wait for refill (one new token every 30s).
        let limited = Arc::new(RateLimitedClient::new(fake(), RateLimitConfig::new(2)));
        let start = Instant::now();
        let mut handles = vec![];
        for _ in 0..4 {
            let l = limited.clone();
            handles.push(tokio::spawn(async move {
                l.complete(ChatRequest::new(vec![])).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        // Two calls had to wait for the bucket to refill from 0 → 1
        // and 0 → 1 again ⇒ ~60s total virtual time.
        let waited = Instant::now().saturating_duration_since(start);
        assert!(
            waited >= Duration::from_secs(59),
            "expected ≥59s, got {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn refill_caps_at_capacity() {
        // Drain partially, then idle far longer than a full refill window.
        // Exactly `capacity` calls must then pass without blocking, and the
        // next one must wait. That shows the idle time did not bank tokens
        // beyond capacity.
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(10));
        for _ in 0..3 {
            limited.complete(ChatRequest::new(vec![])).await.unwrap();
        }
        // 7 tokens left, capacity 10. Sleep 5 minutes. Uncapped, that would
        // bank 7 + 300 × 10/60 = 57 tokens; capped, the bucket holds 10.
        tokio::time::sleep(Duration::from_secs(300)).await;
        // The fake's scripted steps run out during the calls below, after
        // which the inner client returns an error (Unsupported). Results are
        // ignored because only the bucket's timing is under test.
        let start = Instant::now();
        for _ in 0..10 {
            let _ = limited.complete(ChatRequest::new(vec![])).await;
        }
        assert!(
            Instant::now().saturating_duration_since(start) < Duration::from_secs(1),
            "10 calls within capacity must not block after a long idle"
        );
        // The 11th call finds the bucket empty and must wait one refill
        // interval (60s / 10 = 6s of virtual time).
        let start = Instant::now();
        let _ = limited.complete(ChatRequest::new(vec![])).await;
        let waited = Instant::now().saturating_duration_since(start);
        assert!(
            waited >= Duration::from_secs(5),
            "11th call must wait for refill once capacity is spent, got {waited:?}"
        );
    }
}