ares/middleware/
api_key_auth.rs1use crate::db::tenants::TenantDb;
2use axum::{
3 extract::Request,
4 http::StatusCode,
5 middleware::Next,
6 response::{IntoResponse, Response},
7 Json,
8};
9use std::sync::Arc;
10
11pub async fn api_key_auth_middleware(req: Request, next: Next) -> Response {
12 let auth_header = match req.headers().get("authorization") {
13 Some(h) => h,
14 None => {
15 return error_response(StatusCode::UNAUTHORIZED, "Missing Authorization header");
16 }
17 };
18
19 let auth_str = match auth_header.to_str() {
20 Ok(s) => s,
21 Err(_) => {
22 return error_response(StatusCode::UNAUTHORIZED, "Invalid Authorization header");
23 }
24 };
25
26 let api_key = match auth_str.strip_prefix("Bearer ") {
27 Some(k) => k,
28 None => {
29 return error_response(
30 StatusCode::UNAUTHORIZED,
31 "Invalid Authorization format. Expected: Bearer ares_...",
32 );
33 }
34 };
35
36 if !api_key.starts_with("ares_") {
37 return error_response(
38 StatusCode::UNAUTHORIZED,
39 "Invalid API key format. Must start with ares_",
40 );
41 }
42
43 let extensions = req.extensions();
44 let tenant_db: Arc<TenantDb> = match extensions.get::<Arc<TenantDb>>() {
45 Some(db) => db.clone(),
46 None => {
47 return error_response(
48 StatusCode::INTERNAL_SERVER_ERROR,
49 "Tenant database not configured",
50 );
51 }
52 };
53
54 let tenant_ctx = match tenant_db.verify_api_key(api_key).await {
55 Ok(Some(ctx)) => ctx,
56 Ok(None) => {
57 return error_response(StatusCode::UNAUTHORIZED, "Invalid API key");
58 }
59 Err(e) => {
60 tracing::error!("API key verification error: {}", e);
61 return error_response(
62 StatusCode::INTERNAL_SERVER_ERROR,
63 "Failed to verify API key",
64 );
65 }
66 };
67
68 let monthly_usage = match tenant_db.get_monthly_requests(&tenant_ctx.tenant_id).await {
69 Ok(m) => m,
70 Err(_) => {
71 return error_response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to check usage");
72 }
73 };
74
75 let daily_usage = match tenant_db.get_daily_requests(&tenant_ctx.tenant_id).await {
76 Ok(d) => d,
77 Err(_) => {
78 return error_response(
79 StatusCode::INTERNAL_SERVER_ERROR,
80 "Failed to check rate limit",
81 );
82 }
83 };
84
85 if !tenant_ctx.can_make_request(monthly_usage, daily_usage) {
86 if monthly_usage >= tenant_ctx.quota.requests_per_month {
87 return error_response(
88 StatusCode::TOO_MANY_REQUESTS,
89 "Monthly request quota exceeded",
90 );
91 }
92 if daily_usage >= tenant_ctx.quota.requests_per_day {
93 return error_response(StatusCode::TOO_MANY_REQUESTS, "Daily rate limit exceeded");
94 }
95 }
96
97 let mut req = req;
98 req.extensions_mut().insert(tenant_ctx);
99
100 next.run(req).await
101}
102
103fn error_response(status: StatusCode, message: &str) -> Response {
104 let body = Json(serde_json::json!({
105 "error": message
106 }));
107 (status, body).into_response()
108}
109
110pub use crate::auth::middleware::AuthUser;
111
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn protected_handler() -> &'static str {
        "protected content"
    }

    /// Sends `GET /protected` through `api_key_auth_middleware` and returns the response status.
    ///
    /// If `authorization` is `Some`, it is sent as the `Authorization` header. No
    /// `Arc<TenantDb>` extension is installed, so any request that passes the header checks
    /// ends at the "Tenant database not configured" branch.
    async fn status_for(authorization: Option<&str>) -> StatusCode {
        let app = Router::new()
            .route("/protected", get(protected_handler))
            .layer(axum::middleware::from_fn(api_key_auth_middleware));

        let mut builder = Request::builder().uri("/protected");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }

        app.oneshot(builder.body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn test_middleware_no_auth_header() {
        assert_eq!(status_for(None).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_invalid_format() {
        assert_eq!(
            status_for(Some("Basic abc123")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn test_middleware_missing_prefix() {
        assert_eq!(
            status_for(Some("Bearer abc123")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    /// A well-formed key passes header validation. With no tenant database configured, the
    /// middleware must return a server error rather than forward the request.
    #[tokio::test]
    async fn test_middleware_missing_tenant_db() {
        assert_eq!(
            status_for(Some("Bearer ares_test_key")).await,
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }
}