api_gateway/middleware/
rate_limit.rs1use crate::config::ApiGatewayConfig;
2use anyhow::{Context, Result, anyhow};
3use axum::http::{HeaderValue, Method, StatusCode, header};
4use axum::{
5 extract::Request,
6 middleware::Next,
7 response::{IntoResponse, Response},
8};
9use governor::clock::Clock;
10use governor::middleware::StateInformationMiddleware;
11use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
12use std::collections::HashMap;
13use std::num::NonZeroU32;
14use std::sync::Arc;
15use tokio::sync::Semaphore;
16
17use crate::middleware::common;
18
19type RateLimitKey = (Method, String);
20type BucketMap = Arc<HashMap<RateLimitKey, Arc<BucketMapEntry>>>;
21type InflightMap = Arc<HashMap<RateLimitKey, Arc<Semaphore>>>;
22
23#[derive(Default, Clone)]
24pub struct RateLimiterMap {
25 buckets: BucketMap,
26 inflight: InflightMap,
27}
28
29struct BucketMapEntry {
30 bucket: DefaultDirectRateLimiter<StateInformationMiddleware>,
31 policy: HeaderValue,
32 burst: HeaderValue,
33}
34
35impl BucketMapEntry {
36 pub fn new(rps: u32, burst: u32) -> Result<Self> {
37 let bucket = RateLimiter::direct(
38 Quota::per_second(NonZeroU32::new(rps).with_context(|| anyhow!("rps is zero"))?)
39 .allow_burst(NonZeroU32::new(burst).with_context(|| anyhow!("burst is zero"))?),
40 )
41 .with_middleware::<StateInformationMiddleware>();
42 let policy = HeaderValue::from_str(&format!("\"burst\";q={burst};w={rps}"))
43 .context("Failed to create rate limit policy")?;
44 Ok(Self {
45 bucket,
46 policy,
47 burst: burst.into(),
48 })
49 }
50}
51
52impl RateLimiterMap {
53 pub fn from_specs(
56 specs: &Vec<modkit::api::OperationSpec>,
57 cfg: &ApiGatewayConfig,
58 ) -> Result<Self> {
59 let mut buckets = HashMap::new();
60 let mut inflight = HashMap::new();
61 for spec in specs {
63 let (rps, burst, max_in_flight) = spec.rate_limit.as_ref().map_or(
64 (
65 cfg.defaults.rate_limit.rps,
66 cfg.defaults.rate_limit.burst,
67 cfg.defaults.rate_limit.in_flight,
68 ),
69 |r| (r.rps, r.burst, r.in_flight),
70 );
71
72 let key = (spec.method.clone(), spec.path.clone());
73 buckets.insert(
74 key.clone(),
75 Arc::new(
76 BucketMapEntry::new(rps, burst)
77 .with_context(|| anyhow!("RateLimit spec invalid {spec:?} invalid"))?,
78 ),
79 );
80 inflight.insert(key, Arc::new(Semaphore::new(max_in_flight as usize)));
81 }
82 Ok(Self {
83 buckets: Arc::new(buckets),
84 inflight: Arc::new(inflight),
85 })
86 }
87}
88
89pub async fn rate_limit_middleware(map: RateLimiterMap, mut req: Request, next: Next) -> Response {
91 let method = req.method().clone();
92 let path = req
94 .extensions()
95 .get::<axum::extract::MatchedPath>()
96 .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
97
98 let path = common::resolve_path(&req, path.as_str());
99
100 let key = (method, path);
101
102 if let Some(bucker_map_entry) = map.buckets.get(&key) {
103 let headers = req.headers_mut();
104 headers.insert("RateLimit-Policy", bucker_map_entry.policy.clone());
105 match bucker_map_entry.bucket.check() {
106 Ok(state) => {
107 headers.insert("RateLimit-Limit", bucker_map_entry.burst.clone());
108 headers.insert(
109 "RateLimit-Limit-Remaining",
110 state.remaining_burst_capacity().into(),
111 );
112 headers.insert("X-RateLimit-Limit", bucker_map_entry.burst.clone());
113 headers.insert(
114 "X-RateLimit-Remaining",
115 state.remaining_burst_capacity().into(),
116 );
117 }
118 Err(not_until) => {
119 let wait = not_until.wait_time_from(bucker_map_entry.bucket.clock().now());
120 headers.insert(header::RETRY_AFTER, wait.as_secs().into());
121 return StatusCode::TOO_MANY_REQUESTS.into_response();
122 }
123 }
124 }
125
126 if let Some(sem) = map.inflight.get(&key) {
127 match sem.clone().try_acquire_owned() {
128 Ok(_permit) => {
129 return next.run(req).await;
131 }
132 Err(_) => {
133 return StatusCode::SERVICE_UNAVAILABLE.into_response();
134 }
135 }
136 }
137
138 next.run(req).await
139}