Skip to main content

aiclient_api/server/middleware.rs

1use axum::extract::{ConnectInfo, Request, State};
2use axum::middleware::Next;
3use axum::response::Response;
4use std::collections::HashMap;
5use std::net::{IpAddr, SocketAddr};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tokio::time::Instant;
9use tower_http::cors::CorsLayer;
10use uuid::Uuid;
11
12use crate::server::state::AppState;
13use crate::util::error::AppError;
14
15pub type RateLimitMap = Arc<RwLock<HashMap<IpAddr, Instant>>>;
16
17pub fn new_rate_limit_map() -> RateLimitMap {
18    Arc::new(RwLock::new(HashMap::new()))
19}
20
21pub async fn request_id(mut req: Request, next: Next) -> Response {
22    let id = Uuid::new_v4().to_string();
23    req.headers_mut().insert("x-request-id", id.parse().unwrap());
24    next.run(req).await
25}
26
27pub fn cors_layer() -> CorsLayer {
28    CorsLayer::very_permissive()
29}
30
31fn is_anthropic_path(uri: &axum::http::Uri) -> bool {
32    uri.path().contains("/messages")
33}
34
35fn middleware_error(uri: &axum::http::Uri, err: AppError) -> Response {
36    let (status, msg) = err.status_and_message();
37    if is_anthropic_path(uri) {
38        AppError::anthropic_error(status, &msg)
39    } else {
40        AppError::openai_error(status, &msg)
41    }
42}
43
44/// Bearer token auth middleware.
45/// If `config.api_key` is non-empty, validates `Authorization: Bearer <key>`.
46/// If `api_key` is empty, all requests are allowed through.
47pub async fn auth(
48    State(state): State<AppState>,
49    req: Request,
50    next: Next,
51) -> Response {
52    let config = state.config.load();
53    let api_key = &config.api_key;
54
55    if api_key.is_empty() {
56        return next.run(req).await;
57    }
58
59    let auth_header = req
60        .headers()
61        .get("authorization")
62        .and_then(|v| v.to_str().ok())
63        .map(String::from);
64
65    let uri = req.uri().clone();
66
67    match auth_header.as_deref() {
68        Some(header) if header.starts_with("Bearer ") => {
69            let token = &header[7..];
70            if token == api_key {
71                next.run(req).await
72            } else {
73                middleware_error(&uri, AppError::Unauthorized("Invalid API key".to_string()))
74            }
75        }
76        Some(_) => middleware_error(
77            &uri,
78            AppError::Unauthorized(
79                "Invalid authorization format, expected Bearer token".to_string(),
80            ),
81        ),
82        None => middleware_error(
83            &uri,
84            AppError::Unauthorized("Missing Authorization header".to_string()),
85        ),
86    }
87}
88
89/// Per-IP rate limiting middleware.
90/// If `config.server.rate_limit_seconds > 0`, rejects requests that come
91/// faster than the configured interval with 429 Too Many Requests.
92pub async fn rate_limit(
93    State(state): State<AppState>,
94    State(limiter): State<RateLimitMap>,
95    req: Request,
96    next: Next,
97) -> Response {
98    let config = state.config.load();
99    let limit_secs = config.server.rate_limit_seconds;
100
101    if limit_secs == 0 {
102        return next.run(req).await;
103    }
104
105    // Extract client IP from ConnectInfo or fall back to a default
106    let ip = req
107        .extensions()
108        .get::<ConnectInfo<SocketAddr>>()
109        .map(|ci| ci.0.ip())
110        .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
111
112    let uri = req.uri().clone();
113    let now = Instant::now();
114    let interval = std::time::Duration::from_secs(limit_secs);
115
116    {
117        let mut map = limiter.write().await;
118        if let Some(last) = map.get(&ip) {
119            if now.duration_since(*last) < interval {
120                return middleware_error(&uri, AppError::RateLimited);
121            }
122        }
123        map.insert(ip, now);
124
125        // Periodic cleanup
126        if map.len() > 10_000 {
127            map.retain(|_, last| now.duration_since(*last) < interval);
128        }
129    }
130
131    next.run(req).await
132}