Skip to main content

nidus_http/middleware/
rate_limit.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    pin::Pin,
5    sync::{Arc, Mutex},
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9
10use axum::{body::Body, extract::Request};
11use http::{HeaderValue, Response, StatusCode};
12use tower::{Layer, Service};
13
14use crate::context::{IdentityExtractor, RequestIdentity};
15
/// Shared, thread-safe callback that derives an optional [`RequestIdentity`]
/// from the head (`http::request::Parts`) of a request.
type IdentityFn =
    Arc<dyn Fn(&http::request::Parts) -> Option<RequestIdentity> + Send + Sync + 'static>;
18
19/// Error returned by rate-limit stores.
20#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
21#[error("{message}")]
22pub struct RateLimitError {
23    message: String,
24}
25
26impl RateLimitError {
27    /// Creates a rate-limit store error.
28    pub fn new(message: impl Into<String>) -> Self {
29        Self {
30            message: message.into(),
31        }
32    }
33}
34
/// Decision returned by a rate-limit store.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitDecision {
    /// Whether the request is allowed.
    pub allowed: bool,
    /// Limit for the active window.
    pub limit: u64,
    /// Remaining requests in the active window.
    pub remaining: u64,
    /// Time left until the window resets.
    pub reset_after: Duration,
}
47
/// Store adapter used by rate-limit layers.
///
/// Implementations must be `Send + Sync + 'static` because the layer keeps the
/// store behind an `Arc<dyn RateLimitStore>`.
pub trait RateLimitStore: Send + Sync + 'static {
    /// Checks and consumes one request for an identity.
    ///
    /// `limit` is the number of requests permitted per `window` for `identity`.
    ///
    /// # Errors
    ///
    /// Returns a [`RateLimitError`] when the store cannot produce a decision.
    fn check(
        &self,
        identity: &RequestIdentity,
        limit: u64,
        window: Duration,
    ) -> Result<RateLimitDecision, RateLimitError>;
}
58
59/// In-memory rate-limit store intended for local development and single-process apps.
60///
61/// The store tracks counters in process memory and opportunistically removes
62/// expired identity windows whenever [`RateLimitStore::check`] runs. It is not
63/// shared across processes, not durable across restarts, and not a substitute
64/// for a distributed limiter at multi-instance production boundaries.
65#[derive(Clone, Default)]
66pub struct InMemoryRateLimitStore {
67    state: Arc<Mutex<HashMap<String, WindowState>>>,
68}
69
70impl InMemoryRateLimitStore {
71    /// Creates an empty in-memory rate-limit store.
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    /// Returns the number of identity windows currently retained by the store.
77    ///
78    /// Expired windows are pruned opportunistically during [`RateLimitStore::check`],
79    /// so this value is mainly useful for tests, diagnostics, and local tools.
80    pub fn len(&self) -> usize {
81        self.state
82            .lock()
83            .map(|state| state.len())
84            .unwrap_or_default()
85    }
86
87    /// Returns whether the store currently retains no identity windows.
88    pub fn is_empty(&self) -> bool {
89        self.len() == 0
90    }
91}
92
93impl RateLimitStore for InMemoryRateLimitStore {
94    fn check(
95        &self,
96        identity: &RequestIdentity,
97        limit: u64,
98        window: Duration,
99    ) -> Result<RateLimitDecision, RateLimitError> {
100        let now = Instant::now();
101        let mut state = self
102            .state
103            .lock()
104            .map_err(|_| RateLimitError::new("rate limit store poisoned"))?;
105        state.retain(|_, window_state| now.duration_since(window_state.started_at) < window);
106        let window_state = state
107            .entry(identity.as_str().to_owned())
108            .or_insert_with(|| WindowState {
109                started_at: now,
110                count: 0,
111            });
112        if now.duration_since(window_state.started_at) >= window {
113            window_state.started_at = now;
114            window_state.count = 0;
115        }
116
117        let allowed = window_state.count < limit;
118        if allowed {
119            window_state.count += 1;
120        }
121        let remaining = limit.saturating_sub(window_state.count);
122        let reset_after = window.saturating_sub(now.duration_since(window_state.started_at));
123        Ok(RateLimitDecision {
124            allowed,
125            limit,
126            remaining,
127            reset_after,
128        })
129    }
130}
131
132#[derive(Clone)]
133struct WindowState {
134    started_at: Instant,
135    count: u64,
136}
137
138/// Typed config for production-shaped rate limiting.
139#[derive(Clone)]
140pub struct RateLimitConfig {
141    limit: u64,
142    window: Duration,
143    store: Arc<dyn RateLimitStore>,
144    identity: IdentityFn,
145    fail_open: bool,
146}
147
148impl RateLimitConfig {
149    /// Creates a rate-limit config with an explicit store.
150    pub fn new(limit: u64, window: Duration, store: impl RateLimitStore) -> Self {
151        Self {
152            limit,
153            window,
154            store: Arc::new(store),
155            identity: Arc::new(|_parts| Some(RequestIdentity::new("anonymous"))),
156            fail_open: true,
157        }
158    }
159
160    /// Replaces the identity extractor.
161    pub fn identity(mut self, extractor: impl IdentityExtractor) -> Self {
162        self.identity = Arc::new(move |parts| extractor.extract(parts));
163        self
164    }
165
166    /// Allows requests when the backing store fails.
167    pub fn fail_open(mut self) -> Self {
168        self.fail_open = true;
169        self
170    }
171
172    /// Rejects requests when the backing store fails.
173    pub fn fail_closed(mut self) -> Self {
174        self.fail_open = false;
175        self
176    }
177
178    /// Creates a Tower layer from this config.
179    pub fn layer(self) -> RateLimitLayer {
180        RateLimitLayer { config: self }
181    }
182}
183
184/// Tower layer that applies configured rate limiting.
185#[derive(Clone)]
186pub struct RateLimitLayer {
187    config: RateLimitConfig,
188}
189
190impl<S> Layer<S> for RateLimitLayer {
191    type Service = RateLimitService<S>;
192
193    fn layer(&self, inner: S) -> Self::Service {
194        RateLimitService {
195            inner,
196            config: self.config.clone(),
197        }
198    }
199}
200
201/// Service produced by [`RateLimitLayer`].
202#[derive(Clone)]
203pub struct RateLimitService<S> {
204    inner: S,
205    config: RateLimitConfig,
206}
207
208impl<S> Service<Request> for RateLimitService<S>
209where
210    S: Service<Request, Response = Response<Body>> + Send + 'static,
211    S::Future: Send + 'static,
212    S::Error: Send + 'static,
213{
214    type Response = Response<Body>;
215    type Error = S::Error;
216    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
217
218    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
219        self.inner.poll_ready(cx)
220    }
221
222    fn call(&mut self, request: Request) -> Self::Future {
223        let config = self.config.clone();
224        let (parts, body) = request.into_parts();
225        let identity =
226            (config.identity)(&parts).unwrap_or_else(|| RequestIdentity::new("anonymous"));
227        let decision = config
228            .store
229            .check(&identity, config.limit, config.window)
230            .unwrap_or(RateLimitDecision {
231                allowed: config.fail_open,
232                limit: config.limit,
233                remaining: if config.fail_open { config.limit } else { 0 },
234                reset_after: config.window,
235            });
236
237        if !decision.allowed {
238            return Box::pin(async move { Ok(rate_limited_response(decision)) });
239        }
240
241        let future = self.inner.call(Request::from_parts(parts, body));
242        Box::pin(async move {
243            let mut response = future.await?;
244            insert_rate_limit_headers(response.headers_mut(), &decision);
245            Ok(response)
246        })
247    }
248}
249
250fn rate_limited_response(decision: RateLimitDecision) -> Response<Body> {
251    let mut response = Response::new(Body::from("rate limit exceeded"));
252    *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
253    insert_rate_limit_headers(response.headers_mut(), &decision);
254    response.headers_mut().insert(
255        http::header::RETRY_AFTER,
256        HeaderValue::from_str(&decision.reset_after.as_secs().max(1).to_string())
257            .expect("retry-after must be a valid header"),
258    );
259    response
260}
261
262fn insert_rate_limit_headers(headers: &mut http::HeaderMap, decision: &RateLimitDecision) {
263    headers.insert(
264        "ratelimit-limit",
265        HeaderValue::from_str(&decision.limit.to_string()).expect("limit header must be valid"),
266    );
267    headers.insert(
268        "ratelimit-remaining",
269        HeaderValue::from_str(&decision.remaining.to_string())
270            .expect("remaining header must be valid"),
271    );
272    headers.insert(
273        "ratelimit-reset",
274        HeaderValue::from_str(&decision.reset_after.as_secs().max(1).to_string())
275            .expect("reset header must be valid"),
276    );
277}