cognee_core/rate_limiter.rs
1//! Proactive request-rate throttling for pipeline tasks.
2// The semaphore is never closed (only dropped on teardown) so acquire() cannot
3// fail — expect() is safe here.
4#![allow(
5 clippy::expect_used,
6 reason = "semaphore is never closed; acquire() cannot return Err"
7)]
8//!
9//! See `docs/cog-4454-core/03-rate-limiting.md` for the design rationale and how
10//! this differs from `Pipeline::with_concurrency` (item parallelism) and
11//! `RetryPolicy` (reactive backoff).
12//!
13//! # Choosing the right tool
14//!
15//! - **Proactive request-rate throttle** → [`RateLimiter`] (this module)
16//! - **Bounded item-level parallelism** → [`Pipeline::with_concurrency`](crate::pipeline::Pipeline::with_concurrency)
17//! - **Reactive backoff on failure** → [`RetryPolicy`](crate::pipeline::RetryPolicy)
18
19use std::sync::Arc;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use tokio::sync::Semaphore;
24
25/// Admission throttle: `acquire().await` returns when the caller is permitted to
26/// start an external call. Object-safe; hold as `Arc<dyn RateLimiter>`.
27///
28/// Both [`TokenBucketLimiter`] and [`SemaphoreLimiter`] implement this trait.
29/// The `acquire()` contract models *admission* (rate of starts), not
30/// concurrency-with-hold. For true hold-until-done concurrency limiting,
31/// prefer [`Pipeline::with_concurrency`](crate::pipeline::Pipeline::with_concurrency).
32#[async_trait]
33pub trait RateLimiter: Send + Sync {
34 /// Wait until the caller is permitted to start an external call.
35 async fn acquire(&self);
36}
37
38/// Caps the number of starts allowed per refill window (token-bucket algorithm).
39///
40/// `capacity` tokens are available at startup; a background task adds one token
41/// every `1 / refill_per_sec` seconds, up to `capacity`. `acquire()` waits for a
42/// token and consumes it (the refiller restores tokens over time).
43///
44/// # Panics (constructor)
45///
46/// `new` asserts `capacity > 0` and `refill_per_sec > 0.0`. These are API
47/// misuse guards: a zero capacity would permanently block all callers and a
48/// non-positive rate is mathematically undefined, so a panic at construction
49/// time is the right signal.
50///
51/// # Refiller lifecycle
52///
53/// The background refiller task self-terminates one tick after the limiter is
54/// dropped (`Arc::strong_count == 1`). This avoids leaking a task per limiter.
55/// If a stricter shutdown is needed later, store a `tokio::task::JoinHandle`
56/// + `Notify` and abort on `Drop`; not required for the initial implementation.
57pub struct TokenBucketLimiter {
58 semaphore: Arc<Semaphore>,
59}
60
61impl TokenBucketLimiter {
62 /// Create a new token-bucket limiter.
63 ///
64 /// * `capacity` — maximum burst size (initial token count, upper refill bound).
65 /// * `refill_per_sec` — tokens restored per second.
66 ///
67 /// # Panics
68 ///
69 /// Panics if `capacity == 0` or `refill_per_sec <= 0.0` (API misuse guards).
70 pub fn new(capacity: usize, refill_per_sec: f64) -> Self {
71 assert!(capacity > 0, "capacity must be > 0");
72 assert!(refill_per_sec > 0.0, "refill_per_sec must be > 0");
73
74 let semaphore = Arc::new(Semaphore::new(capacity));
75 let refill = semaphore.clone();
76 let interval = Duration::from_secs_f64(1.0 / refill_per_sec);
77
78 // Background refiller. Stops when the semaphore (last Arc) is dropped.
79 // Use interval_at (start = now + interval) so the first tick is one full
80 // interval away; tokio::time::interval's first tick fires immediately, which
81 // would refill a permit right away if all tokens were consumed before the
82 // spawned task first runs.
83 tokio::spawn(async move {
84 let mut ticker =
85 tokio::time::interval_at(tokio::time::Instant::now() + interval, interval);
86 loop {
87 ticker.tick().await;
88 // Only add back up to `capacity` (avoid unbounded growth).
89 if refill.available_permits() < capacity {
90 refill.add_permits(1);
91 }
92 // Stop if we are the only holder left (limiter was dropped).
93 if Arc::strong_count(&refill) == 1 {
94 break;
95 }
96 }
97 });
98
99 Self { semaphore }
100 }
101}
102
103#[async_trait]
104impl RateLimiter for TokenBucketLimiter {
105 async fn acquire(&self) {
106 // `forget()` consumes the permit without releasing it back to the
107 // semaphore; the refiller restores tokens over time.
108 // The semaphore is never closed (only dropped on limiter teardown),
109 // so `acquire()` cannot return `Err`.
110 let permit = self
111 .semaphore
112 .acquire()
113 .await
114 .expect("rate-limiter semaphore is never closed");
115 permit.forget();
116 }
117}
118
119/// Admission-style concurrency limit: at most `max_per_sec` starts may be
120/// issued per second.
121///
122/// Distinct from [`Pipeline::with_concurrency`](crate::pipeline::Pipeline::with_concurrency),
123/// which bounds data-item parallelism at the executor level. `SemaphoreLimiter`
124/// is a proactive request-rate throttle implemented as a token bucket whose
125/// capacity and refill rate are both set to `max_per_sec`.
126pub struct SemaphoreLimiter {
127 inner: TokenBucketLimiter,
128}
129
130impl SemaphoreLimiter {
131 /// Create a new semaphore-style limiter that allows at most `max_per_sec`
132 /// acquisitions per second.
133 pub fn new(max_per_sec: usize) -> Self {
134 Self {
135 inner: TokenBucketLimiter::new(max_per_sec, max_per_sec as f64),
136 }
137 }
138}
139
140#[async_trait]
141impl RateLimiter for SemaphoreLimiter {
142 async fn acquire(&self) {
143 self.inner.acquire().await;
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// `TokenBucketLimiter` lets `capacity` immediate acquisitions through (the
    /// initial token pool is full). The next acquire must then wait for a
    /// refill tick.
    #[tokio::test]
    async fn token_bucket_burst_then_wait() {
        // 2 tokens, refill at 10/sec (100ms per token).
        let limiter = TokenBucketLimiter::new(2, 10.0);

        // First two acquires should be nearly instant (tokens available).
        let t0 = Instant::now();
        limiter.acquire().await;
        limiter.acquire().await;
        let burst_elapsed = t0.elapsed();
        assert!(
            burst_elapsed < Duration::from_millis(80),
            "burst acquires should be fast, took {burst_elapsed:?}"
        );

        // Third acquire must wait ~100ms for the next refill tick.
        let t1 = Instant::now();
        limiter.acquire().await;
        let wait_elapsed = t1.elapsed();
        assert!(
            wait_elapsed >= Duration::from_millis(50),
            "third acquire should wait for refill, took {wait_elapsed:?}"
        );
    }

    /// Idle refill ticks never push the bucket above `capacity`. After several
    /// idle periods only one token is available, so a second acquire must wait
    /// for the next tick.
    #[tokio::test]
    async fn token_bucket_refill_is_capped_at_capacity() {
        // 1 token, refill at 10/sec (ticks at ~100, 200, 300, 400 ms).
        let limiter = TokenBucketLimiter::new(1, 10.0);

        // Let three refill ticks elapse while the bucket is already full.
        tokio::time::sleep(Duration::from_millis(350)).await;

        // Only the single capped token should be available.
        limiter.acquire().await;

        // Without the cap the bucket would hold extra tokens and this would be
        // instant. With it, we wait ~50ms for the tick at ~400ms.
        let t0 = Instant::now();
        limiter.acquire().await;
        let wait_elapsed = t0.elapsed();
        assert!(
            wait_elapsed >= Duration::from_millis(20),
            "refill must not accumulate beyond capacity, took {wait_elapsed:?}"
        );
    }

    /// `SemaphoreLimiter::new` panics if `max_per_sec == 0` (delegated through
    /// `TokenBucketLimiter::new`'s capacity assert).
    #[test]
    #[should_panic(expected = "capacity must be > 0")]
    fn semaphore_limiter_rejects_zero() {
        let _ = SemaphoreLimiter::new(0);
    }

    /// `TokenBucketLimiter::new` panics on zero capacity.
    #[test]
    #[should_panic(expected = "capacity must be > 0")]
    fn token_bucket_rejects_zero_capacity() {
        let _ = TokenBucketLimiter::new(0, 1.0);
    }

    /// `TokenBucketLimiter::new` panics on a zero refill rate.
    #[test]
    #[should_panic(expected = "refill_per_sec must be > 0")]
    fn token_bucket_rejects_zero_rate() {
        let _ = TokenBucketLimiter::new(1, 0.0);
    }

    /// `TokenBucketLimiter::new` panics on a negative refill rate.
    #[test]
    #[should_panic(expected = "refill_per_sec must be > 0")]
    fn token_bucket_rejects_negative_rate() {
        let _ = TokenBucketLimiter::new(1, -1.0);
    }

    /// `TokenBucketLimiter::new` panics on a NaN refill rate (`NaN > 0.0` is
    /// false, so the positivity assert rejects it).
    #[test]
    #[should_panic(expected = "refill_per_sec must be > 0")]
    fn token_bucket_rejects_nan_rate() {
        let _ = TokenBucketLimiter::new(1, f64::NAN);
    }
}