use axum::extract::{ConnectInfo, Request};
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Json, Response};
use dashmap::DashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Requests-per-minute setting whose value always lies in `1..=MAX_RPM`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Rpm(u32);
impl Rpm {
    /// Validates `rpm` and wraps it.
    ///
    /// # Errors
    /// Returns a message when `rpm` is 0 or greater than `MAX_RPM`.
    pub fn new(rpm: u32) -> Result<Self, String> {
        match rpm {
            0 => Err("rpm must be > 0".into()),
            1..=MAX_RPM => Ok(Rpm(rpm)),
            _ => Err(format!("rpm must be <= {MAX_RPM}")),
        }
    }

    /// Returns the raw requests-per-minute value.
    pub fn get(self) -> u32 {
        let Rpm(value) = self;
        value
    }

    /// Wraps `rpm` without validation; callers must already have clamped it
    /// into `1..=MAX_RPM` (checked only in debug builds).
    pub(crate) fn from_raw(rpm: u32) -> Self {
        debug_assert!(rpm > 0 && rpm <= MAX_RPM);
        Rpm(rpm)
    }
}
/// Maximum burst (token-bucket capacity); the wrapped value is never zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Burst(u32);
impl Burst {
    /// Validates `burst` and wraps it.
    ///
    /// # Errors
    /// Returns a message when `burst` is 0.
    pub fn new(burst: u32) -> Result<Self, String> {
        match burst {
            0 => Err("burst must be >= 1".into()),
            size => Ok(Burst(size)),
        }
    }

    /// Returns the raw burst size.
    pub fn get(self) -> u32 {
        let Burst(size) = self;
        size
    }

    /// Wraps `burst` without validation; callers must guarantee it is at
    /// least 1 (checked only in debug builds).
    pub(crate) fn from_raw(burst: u32) -> Self {
        debug_assert!(burst >= 1);
        Burst(burst)
    }
}
/// A single token bucket that holds at most `capacity` tokens and regains
/// `refill_per_ms` tokens for every elapsed millisecond.
#[derive(Debug)]
pub struct TokenBucket {
    /// Upper bound on stored tokens (the burst size).
    capacity: f64,
    /// Tokens added per millisecond of elapsed monotonic time.
    refill_per_ms: f64,
    /// Tokens currently available; may be fractional.
    tokens: f64,
    /// Monotonic instant at which `tokens` was last topped up.
    last_refill: Instant,
    /// Caller-supplied millisecond timestamp of the most recent `try_consume`
    /// (the limiter in this file passes wall-clock Unix milliseconds).
    last_seen_ms: u64,
}
impl TokenBucket {
    /// Creates a bucket that starts completely full.
    pub fn new(capacity: u32, refill_per_ms: f64, now: Instant, now_ms: u64) -> Self {
        let full = f64::from(capacity);
        Self {
            capacity: full,
            refill_per_ms,
            tokens: full,
            last_refill: now,
            last_seen_ms: now_ms,
        }
    }

    /// Tops the bucket up for the time elapsed since the last refill, records
    /// `now_ms` as last-seen, then tries to take one token.
    ///
    /// Returns `true` if a token was taken. An `now` earlier than the last
    /// refill adds nothing and leaves `last_refill` unchanged.
    pub fn try_consume(&mut self, now: Instant, now_ms: u64) -> bool {
        let since_refill = now.saturating_duration_since(self.last_refill);
        let elapsed_ms = since_refill.as_secs_f64() * 1000.0;
        if elapsed_ms > 0.0 {
            let refilled = self.tokens + elapsed_ms * self.refill_per_ms;
            self.tokens = refilled.min(self.capacity);
            self.last_refill = now;
        }
        self.last_seen_ms = now_ms;

        let allowed = self.tokens >= 1.0;
        if allowed {
            self.tokens -= 1.0;
        }
        allowed
    }
}
/// Highest accepted requests-per-minute value: one request per millisecond
/// (60 s × 1000 ms).
pub const MAX_RPM: u32 = 60 * 1_000;
/// Per-client-IP rate limiter: one [`TokenBucket`] per [`IpAddr`], kept in a
/// concurrent `DashMap`.
///
/// Buckets are created lazily on an address's first [`RateLimiter::check`]
/// and are only removed by [`RateLimiter::evict_stale`].
pub struct RateLimiter {
    buckets: DashMap<IpAddr, TokenBucket>,
    /// Capacity (maximum burst) given to every newly created bucket.
    capacity: Burst,
    /// Tokens regained per millisecond, i.e. `effective_rpm / 60_000`.
    refill_per_ms: f64,
    /// Requests per minute after clamping into `1..=MAX_RPM`.
    effective_rpm: Rpm,
}
impl RateLimiter {
    /// Builds a limiter that allows `rpm` requests per minute per IP, with
    /// bursts of up to `burst` back-to-back requests.
    ///
    /// Out-of-range arguments are clamped rather than rejected: `rpm` into
    /// `1..=MAX_RPM` and `burst` up to at least 1. A warning is logged
    /// whenever a value is clamped, so a misconfiguration is visible.
    pub fn new(rpm: u32, burst: u32) -> Self {
        if rpm > MAX_RPM {
            tracing::warn!(
                rpm,
                max_rpm = MAX_RPM,
                "rate_limit_per_minute exceeds {MAX_RPM}; clamped to {MAX_RPM} (1 ms minimum interval)"
            );
        } else if rpm == 0 {
            tracing::warn!(rpm, "rate_limit_per_minute is 0; clamped to 1");
        }
        if burst == 0 {
            tracing::warn!(burst, "rate limit burst is 0; clamped to 1");
        }
        let effective_rpm = rpm.clamp(1, MAX_RPM);
        let refill_per_ms = f64::from(effective_rpm) / 60_000.0;
        Self {
            buckets: DashMap::new(),
            capacity: Burst::from_raw(burst.max(1)),
            refill_per_ms,
            effective_rpm: Rpm::from_raw(effective_rpm),
        }
    }

    /// Steady-state spacing between admitted requests, in whole milliseconds
    /// (`60_000 / rpm`, never less than 1).
    pub fn interval_ms(&self) -> u64 {
        (60_000u64 / u64::from(self.effective_rpm.get().max(1))).max(1)
    }

    /// Charges one request to `ip`, creating a full bucket the first time the
    /// address is seen. Returns `true` if the request is allowed.
    pub fn check(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let now_ms = unix_ms();
        let mut entry = self.buckets.entry(ip).or_insert_with(|| {
            TokenBucket::new(self.capacity.get(), self.refill_per_ms, now, now_ms)
        });
        entry.try_consume(now, now_ms)
    }

    /// Drops every bucket whose most recent `check` happened more than
    /// `older_than` ago, measured on the wall clock (`unix_ms`).
    pub fn evict_stale(&self, older_than: Duration) {
        // `Duration::as_millis` returns u128. Saturate instead of truncating
        // with `as`: an enormous window then means "keep everything" rather
        // than wrapping to an arbitrary small cutoff that evicts live buckets.
        let window_ms = u64::try_from(older_than.as_millis()).unwrap_or(u64::MAX);
        let cutoff = unix_ms().saturating_sub(window_ms);
        self.buckets
            .retain(|_, bucket| bucket.last_seen_ms >= cutoff);
    }

    /// Number of tracked client buckets (test-only introspection).
    #[cfg(test)]
    // Only used by unit tests; an `is_empty` counterpart is not needed.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.buckets.len()
    }
}
/// Current wall-clock time in milliseconds since the Unix epoch; 0 if the
/// system clock reports a time before the epoch.
fn unix_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
/// Determines the client IP address used as the rate-limit key.
///
/// When `trust_proxy` is `false`, only the TCP peer address from
/// [`ConnectInfo`] is used and forwarding headers are ignored.
///
/// When `trust_proxy` is `true`, forwarding headers are consulted only if the
/// direct peer looks like a local proxy: loopback or RFC 1918, including
/// IPv4-mapped IPv6 forms such as `::ffff:127.0.0.1` seen on dual-stack
/// sockets. A public peer is returned as-is, so remote clients cannot inject
/// headers.
///
/// `X-Forwarded-For` is walked right-to-left. Each proxy appends the address
/// it saw, so the rightmost entries are written by our own infrastructure,
/// while the leftmost entries are whatever the client chose to send. Taking
/// the leftmost entry would let any client pick its own rate-limit bucket.
/// The first entry from the right that is not loopback/private is treated as
/// the client. If every entry is loopback/private, the leftmost one is used.
/// A malformed entry ends the search and falls through to `X-Real-IP`, then
/// to the peer address.
///
/// Returns `None` only when there is no `ConnectInfo` extension and no usable
/// header.
pub fn extract_client_ip(req: &Request, trust_proxy: bool) -> Option<IpAddr> {
    let direct_ip = req
        .extensions()
        .get::<ConnectInfo<SocketAddr>>()
        .map(|ci| ci.0.ip());
    if !trust_proxy {
        return direct_ip;
    }
    if let Some(connect_ip) = direct_ip
        && !is_trusted_proxy_hop(connect_ip)
    {
        return Some(connect_ip);
    }
    let headers = req.headers();
    if let Some(value) = headers.get("x-forwarded-for")
        && let Ok(s) = value.to_str()
        && let Some(ip) = client_ip_from_forwarded_for(s)
    {
        return Some(ip);
    }
    if let Some(value) = headers.get("x-real-ip")
        && let Ok(s) = value.to_str()
        && let Ok(ip) = s.trim().parse::<IpAddr>()
    {
        return Some(ip);
    }
    direct_ip
}

/// Returns `true` for addresses treated as trusted proxy hops: loopback or
/// RFC 1918, after unwrapping IPv4-mapped IPv6 (`::ffff:a.b.c.d`).
fn is_trusted_proxy_hop(ip: IpAddr) -> bool {
    let ip = ip.to_canonical();
    ip.is_loopback() || is_rfc1918(ip)
}

/// Picks the client from a comma-separated `X-Forwarded-For` value: the
/// rightmost non-trusted hop, else the leftmost hop when all are trusted.
/// Returns `None` for an empty or malformed chain.
fn client_ip_from_forwarded_for(header: &str) -> Option<IpAddr> {
    let mut leftmost_trusted = None;
    for hop in header.rsplit(',') {
        // An unparseable hop means we cannot tell who wrote the entries to
        // its left, so give up instead of guessing.
        let ip = hop.trim().parse::<IpAddr>().ok()?;
        if !is_trusted_proxy_hop(ip) {
            return Some(ip);
        }
        leftmost_trusted = Some(ip);
    }
    leftmost_trusted
}
/// Reports whether `ip` is an IPv4 private-use address from RFC 1918
/// (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16). IPv6 addresses, including
/// IPv4-mapped ones, always yield `false`.
fn is_rfc1918(ip: IpAddr) -> bool {
    matches!(ip, IpAddr::V4(v4) if v4.is_private())
}
pub async fn rate_limit_middleware(
limiter: Arc<RateLimiter>,
trust_proxy: bool,
metrics: Option<Arc<super::metrics::MetricsRegistry>>,
req: Request,
next: Next,
) -> Response {
let Some(ip) = extract_client_ip(&req, trust_proxy) else {
tracing::debug!("rate limit: could not determine client IP");
return next.run(req).await;
};
if limiter.check(ip) {
next.run(req).await
} else {
tracing::debug!(client_ip = %ip, "rate limit rejected request");
if let Some(ref reg) = metrics {
reg.counter_inc("gigastt_rate_limit_rejections_total", vec![], 1);
}
(
StatusCode::TOO_MANY_REQUESTS,
[(axum::http::header::RETRY_AFTER, "60")],
Json(serde_json::json!({
"error": "Too many requests",
"code": "rate_limited",
})),
)
.into_response()
}
}
// Unit tests for the token bucket, the per-IP limiter, client-IP extraction,
// stale-bucket eviction, and the `Rpm`/`Burst` validating constructors.
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::{HeaderValue, Request as HttpRequest};
use std::net::Ipv4Addr;
// With zero refill, a capacity-5 bucket admits exactly 5 calls at one instant.
#[test]
fn test_token_bucket_allows_within_capacity() {
let now = Instant::now();
let mut bucket = TokenBucket::new(5, 0.0, now, unix_ms());
for i in 0..5 {
assert!(bucket.try_consume(now, unix_ms()), "call {i} must succeed");
}
assert!(
!bucket.try_consume(now, unix_ms()),
"6th call must be rate-limited"
);
}
// Drains a capacity-2 bucket, then advances a synthetic instant by 3 ms at
// 1 token/ms and expects a token to be available again.
#[test]
fn test_token_bucket_refills_over_time() {
let start = Instant::now();
let mut bucket = TokenBucket::new(2, 1.0, start, unix_ms());
assert!(bucket.try_consume(start, unix_ms()));
assert!(bucket.try_consume(start, unix_ms()));
assert!(
!bucket.try_consume(start, unix_ms()),
"should be drained after 2 consumes"
);
let later = start + Duration::from_millis(3);
assert!(
bucket.try_consume(later, unix_ms()),
"should refill after 3 ms"
);
}
// Two addresses get independent buckets: exhausting one does not affect the other.
#[test]
fn test_rate_limiter_per_ip_isolation() {
let limiter = RateLimiter::new(1, 1);
let a: IpAddr = "10.0.0.1".parse().unwrap();
let b: IpAddr = "10.0.0.2".parse().unwrap();
assert!(limiter.check(a), "A first call allowed");
assert!(
limiter.check(b),
"B first call allowed (independent bucket)"
);
assert!(!limiter.check(a), "A second call rate-limited");
assert!(!limiter.check(b), "B second call rate-limited");
}
// For several rpm values, a drained single-token bucket must regain exactly one
// token after `60_000 / rpm` ms (at least 1 ms).
// NOTE(review): the `v1_06` suffix presumably names an earlier release whose
// refill formula this pins; confirm against the changelog.
#[test]
fn test_rate_limiter_refill_formula_matches_v1_06() {
for &rpm in &[1u32, 10, 30, 60, 600, 60_000] {
let limiter = RateLimiter::new(rpm, 1);
let ip: IpAddr = "10.0.0.3".parse().unwrap();
assert!(limiter.check(ip), "rpm={rpm}: initial burst allowed");
assert!(
!limiter.check(ip),
"rpm={rpm}: second immediate call blocked"
);
// Borrow the bucket directly so the refill can be driven with a synthetic
// `later` instant instead of sleeping.
let mut guard = limiter.buckets.get_mut(&ip).expect("bucket exists");
let interval_ms = (60_000u64 / rpm as u64).max(1);
let later = guard.last_refill + Duration::from_millis(interval_ms);
assert!(
guard.try_consume(later, unix_ms()),
"rpm={rpm}: 1 token must refill after {interval_ms} ms",
);
}
}
// Loopback peer with trust_proxy: the client comes from X-Forwarded-For
// (203.0.113.42), not from X-Real-IP or the peer address.
#[test]
fn test_extract_ip_prefers_forwarded_for_when_trusted() {
let mut req = HttpRequest::builder()
.uri("/v1/models")
.body(Body::empty())
.unwrap();
req.headers_mut().insert(
"x-forwarded-for",
HeaderValue::from_static("203.0.113.42 , 10.0.0.1"),
);
req.headers_mut()
.insert("x-real-ip", HeaderValue::from_static("198.51.100.7"));
req.extensions_mut().insert(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
12345,
)));
let got = extract_client_ip(&req, true).expect("XFF must be parsed");
assert_eq!(got, "203.0.113.42".parse::<IpAddr>().unwrap());
}
// Without trust_proxy, forwarding headers are ignored and the peer address is used.
#[test]
fn test_extract_ip_ignores_forwarded_when_not_trusted() {
let mut req = HttpRequest::builder()
.uri("/v1/models")
.body(Body::empty())
.unwrap();
req.headers_mut().insert(
"x-forwarded-for",
HeaderValue::from_static("203.0.113.42 , 10.0.0.1"),
);
req.extensions_mut().insert(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)),
12345,
)));
let got = extract_client_ip(&req, false).expect("ConnectInfo must be used");
assert_eq!(got, IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)));
}
// Trusted loopback peer with no forwarding headers: falls back to the peer address.
#[test]
fn test_extract_ip_falls_back_to_connect_info() {
let mut req = HttpRequest::builder()
.uri("/v1/models")
.body(Body::empty())
.unwrap();
req.extensions_mut().insert(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
55555,
)));
let got = extract_client_ip(&req, true).expect("ConnectInfo fallback");
assert_eq!(got, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
}
// An unparseable X-Forwarded-For value falls through to X-Real-IP.
#[test]
fn test_extract_ip_uses_real_ip_when_forwarded_for_garbage() {
let mut req = HttpRequest::builder()
.uri("/v1/models")
.body(Body::empty())
.unwrap();
req.headers_mut()
.insert("x-forwarded-for", HeaderValue::from_static("not-an-ip"));
req.headers_mut()
.insert("x-real-ip", HeaderValue::from_static("198.51.100.7"));
req.extensions_mut().insert(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
12345,
)));
let got = extract_client_ip(&req, true).expect("X-Real-IP fallback");
assert_eq!(got, "198.51.100.7".parse::<IpAddr>().unwrap());
}
// Even with trust_proxy, a public direct peer is used as-is and its headers are ignored.
#[test]
fn test_extract_ip_skips_headers_when_direct_peer_is_public() {
let mut req = HttpRequest::builder()
.uri("/v1/models")
.body(Body::empty())
.unwrap();
req.headers_mut()
.insert("x-forwarded-for", HeaderValue::from_static("203.0.113.42"));
req.extensions_mut().insert(ConnectInfo(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)),
12345,
)));
let got = extract_client_ip(&req, true).expect("ConnectInfo used");
assert_eq!(got, IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)));
}
// Backdates one bucket's last-seen time by 10 minutes; a 60 s eviction window
// must remove it and keep the freshly used one.
#[test]
fn test_eviction_removes_stale() {
let limiter = RateLimiter::new(60, 1);
let fresh: IpAddr = "10.0.0.4".parse().unwrap();
let stale: IpAddr = "10.0.0.5".parse().unwrap();
assert!(limiter.check(fresh));
assert!(limiter.check(stale));
{
let mut guard = limiter.buckets.get_mut(&stale).expect("stale bucket");
guard.last_seen_ms = unix_ms().saturating_sub(10 * 60_000); }
limiter.evict_stale(Duration::from_secs(60));
assert_eq!(limiter.len(), 1, "stale bucket should be evicted");
assert!(
limiter.buckets.contains_key(&fresh),
"fresh bucket must survive eviction"
);
}
// Validating constructors: `Rpm` accepts 1..=MAX_RPM, `Burst` accepts >= 1.
#[test]
fn test_rpm_new_rejects_zero() {
assert!(Rpm::new(0).is_err());
}
#[test]
fn test_rpm_new_rejects_too_high() {
assert!(Rpm::new(MAX_RPM + 1).is_err());
}
#[test]
fn test_rpm_new_accepts_valid() {
let r = Rpm::new(30).unwrap();
assert_eq!(r.get(), 30);
}
#[test]
fn test_burst_new_rejects_zero() {
assert!(Burst::new(0).is_err());
}
#[test]
fn test_burst_new_accepts_valid() {
let b = Burst::new(5).unwrap();
assert_eq!(b.get(), 5);
}
}