Skip to main content

crabtalk_proxy/ext/
rate_limit.rs

1use crabtalk_core::{
2    BoxFuture, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ExtensionError,
3    Prefix, RequestContext, Storage,
4};
5use std::{
6    sync::Arc,
7    time::{SystemTime, UNIX_EPOCH},
8};
9
10pub struct RateLimit {
11    storage: Arc<dyn Storage>,
12    requests_per_minute: u64,
13    tokens_per_minute: Option<u64>,
14}
15
16impl RateLimit {
17    pub fn new(config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
18        let rpm = config
19            .get("requests_per_minute")
20            .and_then(|v| v.as_i64())
21            .ok_or("rate_limit: missing or invalid 'requests_per_minute'")?;
22
23        if rpm <= 0 {
24            return Err("rate_limit: 'requests_per_minute' must be positive".to_string());
25        }
26
27        let tpm = config
28            .get("tokens_per_minute")
29            .and_then(|v| v.as_i64())
30            .map(|v| {
31                if v <= 0 {
32                    Err("rate_limit: 'tokens_per_minute' must be positive".to_string())
33                } else {
34                    Ok(v as u64)
35                }
36            })
37            .transpose()?;
38
39        Ok(Self {
40            storage,
41            requests_per_minute: rpm as u64,
42            tokens_per_minute: tpm,
43        })
44    }
45}
46
/// Returns the number of whole minutes elapsed since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_minute() -> u64 {
    const SECONDS_PER_MINUTE: u64 = 60;

    let elapsed_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before the epoch: fall back to a zero duration.
        Err(_) => 0,
    };

    elapsed_secs / SECONDS_PER_MINUTE
}
54
55impl crabtalk_core::Extension for RateLimit {
56    fn name(&self) -> &str {
57        "rate_limit"
58    }
59
60    fn prefix(&self) -> Prefix {
61        *b"rlim"
62    }
63
64    fn on_request(&self, ctx: &RequestContext) -> BoxFuture<'_, Result<(), ExtensionError>> {
65        let key_name = ctx.key_name.as_deref().unwrap_or("__global");
66        let minute = current_minute();
67
68        let rpm_suffix = format!("{key_name}:{minute}");
69        let rpm_key = self.storage_key(rpm_suffix.as_bytes());
70        let rpm_limit = self.requests_per_minute;
71
72        let tpm_limit = self.tokens_per_minute;
73        let tpm_key = tpm_limit.map(|_| {
74            let tpm_suffix = format!("{key_name}:tpm:{minute}");
75            self.storage_key(tpm_suffix.as_bytes())
76        });
77
78        Box::pin(async move {
79            // Check RPM.
80            let count = self
81                .storage
82                .increment(&rpm_key, 1)
83                .await
84                .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
85
86            if count as u64 > rpm_limit {
87                return Err(ExtensionError::new(
88                    429,
89                    "rate limit exceeded (RPM)",
90                    "rate_limit_error",
91                ));
92            }
93
94            // Check TPM.
95            if let (Some(limit), Some(key)) = (tpm_limit, &tpm_key) {
96                let tokens = self
97                    .storage
98                    .increment(key, 0)
99                    .await
100                    .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
101
102                if tokens as u64 > limit {
103                    return Err(ExtensionError::new(
104                        429,
105                        "rate limit exceeded (TPM)",
106                        "rate_limit_error",
107                    ));
108                }
109            }
110
111            Ok(())
112        })
113    }
114
115    fn on_response(
116        &self,
117        ctx: &RequestContext,
118        _request: &ChatCompletionRequest,
119        response: &ChatCompletionResponse,
120    ) -> BoxFuture<'_, ()> {
121        if self.tokens_per_minute.is_none() {
122            return Box::pin(async {});
123        }
124
125        let total_tokens = response
126            .usage
127            .as_ref()
128            .map(|u| u.total_tokens as i64)
129            .unwrap_or(0);
130
131        if total_tokens == 0 {
132            return Box::pin(async {});
133        }
134
135        let key_name = ctx.key_name.as_deref().unwrap_or("__global");
136        let minute = current_minute();
137        let tpm_suffix = format!("{key_name}:tpm:{minute}");
138        let tpm_key = self.storage_key(tpm_suffix.as_bytes());
139
140        Box::pin(async move {
141            let _ = self.storage.increment(&tpm_key, total_tokens).await;
142        })
143    }
144
145    fn on_chunk(&self, ctx: &RequestContext, chunk: &ChatCompletionChunk) -> BoxFuture<'_, ()> {
146        if self.tokens_per_minute.is_none() {
147            return Box::pin(async {});
148        }
149
150        let total_tokens = chunk
151            .usage
152            .as_ref()
153            .map(|u| u.total_tokens as i64)
154            .unwrap_or(0);
155
156        if total_tokens == 0 {
157            return Box::pin(async {});
158        }
159
160        let key_name = ctx.key_name.as_deref().unwrap_or("__global");
161        let minute = current_minute();
162        let tpm_suffix = format!("{key_name}:tpm:{minute}");
163        let tpm_key = self.storage_key(tpm_suffix.as_bytes());
164
165        Box::pin(async move {
166            let _ = self.storage.increment(&tpm_key, total_tokens).await;
167        })
168    }
169}