1use std::{
2 collections::HashMap,
3 future::Future,
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::{Context, Poll},
7 time::{Duration, Instant},
8};
9
10use axum::{body::Body, extract::Request};
11use http::{HeaderValue, Response, StatusCode};
12use tower::{Layer, Service};
13
14use crate::context::{IdentityExtractor, RequestIdentity};
15
16type IdentityFn =
17 Arc<dyn Fn(&http::request::Parts) -> Option<RequestIdentity> + Send + Sync + 'static>;
18
19#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
21#[error("{message}")]
22pub struct RateLimitError {
23 message: String,
24}
25
26impl RateLimitError {
27 pub fn new(message: impl Into<String>) -> Self {
29 Self {
30 message: message.into(),
31 }
32 }
33}
34
/// Outcome of a single [`RateLimitStore::check`] call, as reported by the store.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitDecision {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// Request budget per window that the decision was made against.
    pub limit: u64,
    /// Requests still available in the current window.
    pub remaining: u64,
    /// Time left until the current window resets.
    pub reset_after: Duration,
}
47
48pub trait RateLimitStore: Send + Sync + 'static {
50 fn check(
52 &self,
53 identity: &RequestIdentity,
54 limit: u64,
55 window: Duration,
56 ) -> Result<RateLimitDecision, RateLimitError>;
57}
58
59#[derive(Clone, Default)]
66pub struct InMemoryRateLimitStore {
67 state: Arc<Mutex<HashMap<String, WindowState>>>,
68}
69
70impl InMemoryRateLimitStore {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn len(&self) -> usize {
81 self.state
82 .lock()
83 .map(|state| state.len())
84 .unwrap_or_default()
85 }
86
87 pub fn is_empty(&self) -> bool {
89 self.len() == 0
90 }
91}
92
93impl RateLimitStore for InMemoryRateLimitStore {
94 fn check(
95 &self,
96 identity: &RequestIdentity,
97 limit: u64,
98 window: Duration,
99 ) -> Result<RateLimitDecision, RateLimitError> {
100 let now = Instant::now();
101 let mut state = self
102 .state
103 .lock()
104 .map_err(|_| RateLimitError::new("rate limit store poisoned"))?;
105 state.retain(|_, window_state| now.duration_since(window_state.started_at) < window);
106 let window_state = state
107 .entry(identity.as_str().to_owned())
108 .or_insert_with(|| WindowState {
109 started_at: now,
110 count: 0,
111 });
112 if now.duration_since(window_state.started_at) >= window {
113 window_state.started_at = now;
114 window_state.count = 0;
115 }
116
117 let allowed = window_state.count < limit;
118 if allowed {
119 window_state.count += 1;
120 }
121 let remaining = limit.saturating_sub(window_state.count);
122 let reset_after = window.saturating_sub(now.duration_since(window_state.started_at));
123 Ok(RateLimitDecision {
124 allowed,
125 limit,
126 remaining,
127 reset_after,
128 })
129 }
130}
131
/// Per-identity bookkeeping for one fixed window in `InMemoryRateLimitStore`.
#[derive(Clone)]
struct WindowState {
    /// Moment the current window opened.
    started_at: Instant,
    /// Requests admitted since `started_at`.
    count: u64,
}
137
138#[derive(Clone)]
140pub struct RateLimitConfig {
141 limit: u64,
142 window: Duration,
143 store: Arc<dyn RateLimitStore>,
144 identity: IdentityFn,
145 fail_open: bool,
146}
147
148impl RateLimitConfig {
149 pub fn new(limit: u64, window: Duration, store: impl RateLimitStore) -> Self {
151 Self {
152 limit,
153 window,
154 store: Arc::new(store),
155 identity: Arc::new(|_parts| Some(RequestIdentity::new("anonymous"))),
156 fail_open: true,
157 }
158 }
159
160 pub fn identity(mut self, extractor: impl IdentityExtractor) -> Self {
162 self.identity = Arc::new(move |parts| extractor.extract(parts));
163 self
164 }
165
166 pub fn fail_open(mut self) -> Self {
168 self.fail_open = true;
169 self
170 }
171
172 pub fn fail_closed(mut self) -> Self {
174 self.fail_open = false;
175 self
176 }
177
178 pub fn layer(self) -> RateLimitLayer {
180 RateLimitLayer { config: self }
181 }
182}
183
184#[derive(Clone)]
186pub struct RateLimitLayer {
187 config: RateLimitConfig,
188}
189
190impl<S> Layer<S> for RateLimitLayer {
191 type Service = RateLimitService<S>;
192
193 fn layer(&self, inner: S) -> Self::Service {
194 RateLimitService {
195 inner,
196 config: self.config.clone(),
197 }
198 }
199}
200
201#[derive(Clone)]
203pub struct RateLimitService<S> {
204 inner: S,
205 config: RateLimitConfig,
206}
207
208impl<S> Service<Request> for RateLimitService<S>
209where
210 S: Service<Request, Response = Response<Body>> + Send + 'static,
211 S::Future: Send + 'static,
212 S::Error: Send + 'static,
213{
214 type Response = Response<Body>;
215 type Error = S::Error;
216 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
217
218 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
219 self.inner.poll_ready(cx)
220 }
221
222 fn call(&mut self, request: Request) -> Self::Future {
223 let config = self.config.clone();
224 let (parts, body) = request.into_parts();
225 let identity =
226 (config.identity)(&parts).unwrap_or_else(|| RequestIdentity::new("anonymous"));
227 let decision = config
228 .store
229 .check(&identity, config.limit, config.window)
230 .unwrap_or(RateLimitDecision {
231 allowed: config.fail_open,
232 limit: config.limit,
233 remaining: if config.fail_open { config.limit } else { 0 },
234 reset_after: config.window,
235 });
236
237 if !decision.allowed {
238 return Box::pin(async move { Ok(rate_limited_response(decision)) });
239 }
240
241 let future = self.inner.call(Request::from_parts(parts, body));
242 Box::pin(async move {
243 let mut response = future.await?;
244 insert_rate_limit_headers(response.headers_mut(), &decision);
245 Ok(response)
246 })
247 }
248}
249
250fn rate_limited_response(decision: RateLimitDecision) -> Response<Body> {
251 let mut response = Response::new(Body::from("rate limit exceeded"));
252 *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
253 insert_rate_limit_headers(response.headers_mut(), &decision);
254 response.headers_mut().insert(
255 http::header::RETRY_AFTER,
256 HeaderValue::from_str(&decision.reset_after.as_secs().max(1).to_string())
257 .expect("retry-after must be a valid header"),
258 );
259 response
260}
261
262fn insert_rate_limit_headers(headers: &mut http::HeaderMap, decision: &RateLimitDecision) {
263 headers.insert(
264 "ratelimit-limit",
265 HeaderValue::from_str(&decision.limit.to_string()).expect("limit header must be valid"),
266 );
267 headers.insert(
268 "ratelimit-remaining",
269 HeaderValue::from_str(&decision.remaining.to_string())
270 .expect("remaining header must be valid"),
271 );
272 headers.insert(
273 "ratelimit-reset",
274 HeaderValue::from_str(&decision.reset_after.as_secs().max(1).to_string())
275 .expect("reset header must be valid"),
276 );
277}