aiclient_api/server/
middleware.rs1use axum::extract::{ConnectInfo, Request, State};
2use axum::middleware::Next;
3use axum::response::Response;
4use std::collections::HashMap;
5use std::net::{IpAddr, SocketAddr};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tokio::time::Instant;
9use tower_http::cors::CorsLayer;
10use uuid::Uuid;
11
12use crate::server::state::AppState;
13use crate::util::error::AppError;
14
15pub type RateLimitMap = Arc<RwLock<HashMap<IpAddr, Instant>>>;
16
17pub fn new_rate_limit_map() -> RateLimitMap {
18 Arc::new(RwLock::new(HashMap::new()))
19}
20
21pub async fn request_id(mut req: Request, next: Next) -> Response {
22 let id = Uuid::new_v4().to_string();
23 req.headers_mut().insert("x-request-id", id.parse().unwrap());
24 next.run(req).await
25}
26
27pub fn cors_layer() -> CorsLayer {
28 CorsLayer::very_permissive()
29}
30
31fn is_anthropic_path(uri: &axum::http::Uri) -> bool {
32 uri.path().contains("/messages")
33}
34
35fn middleware_error(uri: &axum::http::Uri, err: AppError) -> Response {
36 let (status, msg) = err.status_and_message();
37 if is_anthropic_path(uri) {
38 AppError::anthropic_error(status, &msg)
39 } else {
40 AppError::openai_error(status, &msg)
41 }
42}
43
44pub async fn auth(
48 State(state): State<AppState>,
49 req: Request,
50 next: Next,
51) -> Response {
52 let config = state.config.load();
53 let api_key = &config.api_key;
54
55 if api_key.is_empty() {
56 return next.run(req).await;
57 }
58
59 let auth_header = req
60 .headers()
61 .get("authorization")
62 .and_then(|v| v.to_str().ok())
63 .map(String::from);
64
65 let uri = req.uri().clone();
66
67 match auth_header.as_deref() {
68 Some(header) if header.starts_with("Bearer ") => {
69 let token = &header[7..];
70 if token == api_key {
71 next.run(req).await
72 } else {
73 middleware_error(&uri, AppError::Unauthorized("Invalid API key".to_string()))
74 }
75 }
76 Some(_) => middleware_error(
77 &uri,
78 AppError::Unauthorized(
79 "Invalid authorization format, expected Bearer token".to_string(),
80 ),
81 ),
82 None => middleware_error(
83 &uri,
84 AppError::Unauthorized("Missing Authorization header".to_string()),
85 ),
86 }
87}
88
89pub async fn rate_limit(
93 State(state): State<AppState>,
94 State(limiter): State<RateLimitMap>,
95 req: Request,
96 next: Next,
97) -> Response {
98 let config = state.config.load();
99 let limit_secs = config.server.rate_limit_seconds;
100
101 if limit_secs == 0 {
102 return next.run(req).await;
103 }
104
105 let ip = req
107 .extensions()
108 .get::<ConnectInfo<SocketAddr>>()
109 .map(|ci| ci.0.ip())
110 .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
111
112 let uri = req.uri().clone();
113 let now = Instant::now();
114 let interval = std::time::Duration::from_secs(limit_secs);
115
116 {
117 let mut map = limiter.write().await;
118 if let Some(last) = map.get(&ip) {
119 if now.duration_since(*last) < interval {
120 return middleware_error(&uri, AppError::RateLimited);
121 }
122 }
123 map.insert(ip, now);
124
125 if map.len() > 10_000 {
127 map.retain(|_, last| now.duration_since(*last) < interval);
128 }
129 }
130
131 next.run(req).await
132}