//! Token-bucket rate limiting for chat models.
//!
//! Provides [`TokenBucket`], a simple async token-bucket limiter, and
//! [`TokenBucketChatModel`], a `ChatModel` wrapper that acquires one token
//! per request before delegating to an inner model.
1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, SynapticError};
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
/// A token bucket rate limiter.
///
/// Starts full at `capacity` tokens and refills at `refill_rate` tokens per second.
/// Calling [`acquire`](TokenBucket::acquire) waits until a token is available, then
/// consumes one token.
pub struct TokenBucket {
    // Maximum number of tokens the bucket can hold; `refill` clamps to this.
    capacity: f64,
    // Current token count; fractional amounts accumulate between refills.
    tokens: Mutex<f64>,
    // Tokens credited per second of elapsed wall-clock time.
    refill_rate: f64,
    // Instant of the last credit; advanced by `refill` when tokens are added.
    last_refill: Mutex<Instant>,
}
19
20impl TokenBucket {
21    /// Create a new token bucket that starts full.
22    ///
23    /// - `capacity`: maximum number of tokens the bucket can hold
24    /// - `refill_rate`: tokens added per second
25    pub fn new(capacity: f64, refill_rate: f64) -> Self {
26        Self {
27            capacity,
28            tokens: Mutex::new(capacity),
29            refill_rate,
30            last_refill: Mutex::new(Instant::now()),
31        }
32    }
33
34    /// Wait until a token is available and consume it.
35    pub async fn acquire(&self) {
36        loop {
37            self.refill().await;
38
39            let mut tokens = self.tokens.lock().await;
40            if *tokens >= 1.0 {
41                *tokens -= 1.0;
42                return;
43            }
44            drop(tokens);
45
46            // Wait a short interval before checking again
47            // Calculate how long until we have at least 1 token
48            let wait = std::time::Duration::from_secs_f64(1.0 / self.refill_rate);
49            tokio::time::sleep(wait).await;
50        }
51    }
52
53    async fn refill(&self) {
54        let now = Instant::now();
55        let mut last_refill = self.last_refill.lock().await;
56        let elapsed = now.duration_since(*last_refill);
57        let new_tokens = elapsed.as_secs_f64() * self.refill_rate;
58
59        if new_tokens > 0.0 {
60            let mut tokens = self.tokens.lock().await;
61            *tokens = (*tokens + new_tokens).min(self.capacity);
62            *last_refill = now;
63        }
64    }
65}
66
/// A ChatModel wrapper that uses a [`TokenBucket`] to rate-limit requests.
///
/// Each call to `chat` or `stream_chat` acquires one token before delegating
/// to the inner model.
pub struct TokenBucketChatModel {
    // The wrapped model that actually serves requests.
    inner: Arc<dyn ChatModel>,
    // Shared limiter; `Arc` so `stream_chat` can hand an owned handle to the
    // stream it returns.
    bucket: Arc<TokenBucket>,
}
75
76impl TokenBucketChatModel {
77    /// Create a new token-bucket rate-limited model.
78    ///
79    /// - `inner`: the model to wrap
80    /// - `capacity`: maximum burst size (tokens)
81    /// - `refill_rate`: tokens per second
82    pub fn new(inner: Arc<dyn ChatModel>, capacity: f64, refill_rate: f64) -> Self {
83        Self {
84            inner,
85            bucket: Arc::new(TokenBucket::new(capacity, refill_rate)),
86        }
87    }
88}
89
90#[async_trait]
91impl ChatModel for TokenBucketChatModel {
92    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
93        self.bucket.acquire().await;
94        self.inner.chat(request).await
95    }
96
97    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
98        let inner = self.inner.clone();
99        let bucket = self.bucket.clone();
100
101        Box::pin(async_stream::stream! {
102            bucket.acquire().await;
103
104            use futures::StreamExt;
105            let mut stream = inner.stream_chat(request);
106            while let Some(result) = stream.next().await {
107                yield result;
108            }
109        })
110    }
111}