Skip to main content

cognee_core/
rate_limiter.rs

//! Proactive request-rate throttling for pipeline tasks.
//!
//! See `docs/cog-4454-core/03-rate-limiting.md` for the design rationale and how
//! this differs from `Pipeline::with_concurrency` (item parallelism) and
//! `RetryPolicy` (reactive backoff).
//!
//! # Choosing the right tool
//!
//! - **Proactive request-rate throttle** → [`RateLimiter`] (this module)
//! - **Bounded item-level parallelism** → [`Pipeline::with_concurrency`](crate::pipeline::Pipeline::with_concurrency)
//! - **Reactive backoff on failure** → [`RetryPolicy`](crate::pipeline::RetryPolicy)

// Lint allowance kept after the module docs so they stay one contiguous block.
// The `expect()` on `Semaphore::acquire()` is safe: the semaphore is never
// closed (it is only dropped on teardown), so that call cannot return `Err`.
#![allow(
    clippy::expect_used,
    reason = "semaphore is never closed; acquire() cannot return Err"
)]
13//! # Choosing the right tool
14//!
15//! - **Proactive request-rate throttle** → [`RateLimiter`] (this module)
16//! - **Bounded item-level parallelism** → [`Pipeline::with_concurrency`](crate::pipeline::Pipeline::with_concurrency)
17//! - **Reactive backoff on failure** → [`RetryPolicy`](crate::pipeline::RetryPolicy)
18
19use std::sync::Arc;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use tokio::sync::Semaphore;
24
25/// Admission throttle: `acquire().await` returns when the caller is permitted to
26/// start an external call. Object-safe; hold as `Arc<dyn RateLimiter>`.
27///
28/// Both [`TokenBucketLimiter`] and [`SemaphoreLimiter`] implement this trait.
29/// The `acquire()` contract models *admission* (rate of starts), not
30/// concurrency-with-hold. For true hold-until-done concurrency limiting,
31/// prefer [`Pipeline::with_concurrency`](crate::pipeline::Pipeline::with_concurrency).
32#[async_trait]
33pub trait RateLimiter: Send + Sync {
34    /// Wait until the caller is permitted to start an external call.
35    async fn acquire(&self);
36}
37
38/// Caps the number of starts allowed per refill window (token-bucket algorithm).
39///
40/// `capacity` tokens are available at startup; a background task adds one token
41/// every `1 / refill_per_sec` seconds, up to `capacity`. `acquire()` waits for a
42/// token and consumes it (the refiller restores tokens over time).
43///
44/// # Panics (constructor)
45///
46/// `new` asserts `capacity > 0` and `refill_per_sec > 0.0`. These are API
47/// misuse guards: a zero capacity would permanently block all callers and a
48/// non-positive rate is mathematically undefined, so a panic at construction
49/// time is the right signal.
50///
51/// # Refiller lifecycle
52///
53/// The background refiller task self-terminates one tick after the limiter is
54/// dropped (`Arc::strong_count == 1`). This avoids leaking a task per limiter.
55/// If a stricter shutdown is needed later, store a `tokio::task::JoinHandle`
56/// + `Notify` and abort on `Drop`; not required for the initial implementation.
57pub struct TokenBucketLimiter {
58    semaphore: Arc<Semaphore>,
59}
60
61impl TokenBucketLimiter {
62    /// Create a new token-bucket limiter.
63    ///
64    /// * `capacity` — maximum burst size (initial token count, upper refill bound).
65    /// * `refill_per_sec` — tokens restored per second.
66    ///
67    /// # Panics
68    ///
69    /// Panics if `capacity == 0` or `refill_per_sec <= 0.0` (API misuse guards).
70    pub fn new(capacity: usize, refill_per_sec: f64) -> Self {
71        assert!(capacity > 0, "capacity must be > 0");
72        assert!(refill_per_sec > 0.0, "refill_per_sec must be > 0");
73
74        let semaphore = Arc::new(Semaphore::new(capacity));
75        let refill = semaphore.clone();
76        let interval = Duration::from_secs_f64(1.0 / refill_per_sec);
77
78        // Background refiller. Stops when the semaphore (last Arc) is dropped.
79        // Use interval_at (start = now + interval) so the first tick is one full
80        // interval away; tokio::time::interval's first tick fires immediately, which
81        // would refill a permit right away if all tokens were consumed before the
82        // spawned task first runs.
83        tokio::spawn(async move {
84            let mut ticker =
85                tokio::time::interval_at(tokio::time::Instant::now() + interval, interval);
86            loop {
87                ticker.tick().await;
88                // Only add back up to `capacity` (avoid unbounded growth).
89                if refill.available_permits() < capacity {
90                    refill.add_permits(1);
91                }
92                // Stop if we are the only holder left (limiter was dropped).
93                if Arc::strong_count(&refill) == 1 {
94                    break;
95                }
96            }
97        });
98
99        Self { semaphore }
100    }
101}
102
103#[async_trait]
104impl RateLimiter for TokenBucketLimiter {
105    async fn acquire(&self) {
106        // `forget()` consumes the permit without releasing it back to the
107        // semaphore; the refiller restores tokens over time.
108        // The semaphore is never closed (only dropped on limiter teardown),
109        // so `acquire()` cannot return `Err`.
110        let permit = self
111            .semaphore
112            .acquire()
113            .await
114            .expect("rate-limiter semaphore is never closed");
115        permit.forget();
116    }
117}
118
119/// Admission-style concurrency limit: at most `max_per_sec` starts may be
120/// issued per second.
121///
122/// Distinct from [`Pipeline::with_concurrency`](crate::pipeline::Pipeline::with_concurrency),
123/// which bounds data-item parallelism at the executor level. `SemaphoreLimiter`
124/// is a proactive request-rate throttle implemented as a token bucket whose
125/// capacity and refill rate are both set to `max_per_sec`.
126pub struct SemaphoreLimiter {
127    inner: TokenBucketLimiter,
128}
129
130impl SemaphoreLimiter {
131    /// Create a new semaphore-style limiter that allows at most `max_per_sec`
132    /// acquisitions per second.
133    pub fn new(max_per_sec: usize) -> Self {
134        Self {
135            inner: TokenBucketLimiter::new(max_per_sec, max_per_sec as f64),
136        }
137    }
138}
139
140#[async_trait]
141impl RateLimiter for SemaphoreLimiter {
142    async fn acquire(&self) {
143        self.inner.acquire().await;
144    }
145}
146
147#[cfg(test)]
148mod tests {
149    use super::*;
150    use std::time::Instant;
151
152    /// `TokenBucketLimiter` lets `capacity` immediate acquisitions through (the
153    /// initial token pool is full), then the next acquire must wait for a refill
154    /// tick.
155    #[tokio::test]
156    async fn token_bucket_burst_then_wait() {
157        // 2 tokens, refill at 10/sec (100ms per token).
158        let limiter = TokenBucketLimiter::new(2, 10.0);
159
160        // First two acquires should be nearly instant (tokens available).
161        let t0 = Instant::now();
162        limiter.acquire().await;
163        limiter.acquire().await;
164        let burst_elapsed = t0.elapsed();
165        assert!(
166            burst_elapsed < Duration::from_millis(80),
167            "burst acquires should be fast, took {burst_elapsed:?}"
168        );
169
170        // Third acquire must wait ~100ms for the next refill tick.
171        let t1 = Instant::now();
172        limiter.acquire().await;
173        let wait_elapsed = t1.elapsed();
174        assert!(
175            wait_elapsed >= Duration::from_millis(50),
176            "third acquire should wait for refill, took {wait_elapsed:?}"
177        );
178    }
179
180    /// `SemaphoreLimiter::new` panics if `max_per_sec == 0` (delegated through
181    /// `TokenBucketLimiter::new`'s capacity assert).
182    #[test]
183    #[should_panic(expected = "capacity must be > 0")]
184    fn semaphore_limiter_rejects_zero() {
185        let _ = SemaphoreLimiter::new(0);
186    }
187
188    /// `TokenBucketLimiter::new` panics on zero capacity.
189    #[test]
190    #[should_panic(expected = "capacity must be > 0")]
191    fn token_bucket_rejects_zero_capacity() {
192        let _ = TokenBucketLimiter::new(0, 1.0);
193    }
194
195    /// `TokenBucketLimiter::new` panics on non-positive refill rate.
196    #[test]
197    #[should_panic(expected = "refill_per_sec must be > 0")]
198    fn token_bucket_rejects_zero_rate() {
199        let _ = TokenBucketLimiter::new(1, 0.0);
200    }
201}