Skip to main content

ares/middleware/
api_key_auth.rs

1use crate::db::tenants::TenantDb;
2use axum::{
3    extract::Request,
4    http::StatusCode,
5    middleware::Next,
6    response::{IntoResponse, Response},
7    Json,
8};
9use std::sync::Arc;
10
11pub async fn api_key_auth_middleware(req: Request, next: Next) -> Response {
12    let auth_header = match req.headers().get("authorization") {
13        Some(h) => h,
14        None => {
15            return error_response(StatusCode::UNAUTHORIZED, "Missing Authorization header");
16        }
17    };
18
19    let auth_str = match auth_header.to_str() {
20        Ok(s) => s,
21        Err(_) => {
22            return error_response(StatusCode::UNAUTHORIZED, "Invalid Authorization header");
23        }
24    };
25
26    let api_key = match auth_str.strip_prefix("Bearer ") {
27        Some(k) => k,
28        None => {
29            return error_response(
30                StatusCode::UNAUTHORIZED,
31                "Invalid Authorization format. Expected: Bearer ares_...",
32            );
33        }
34    };
35
36    if !api_key.starts_with("ares_") {
37        return error_response(
38            StatusCode::UNAUTHORIZED,
39            "Invalid API key format. Must start with ares_",
40        );
41    }
42
43    let extensions = req.extensions();
44    let tenant_db: Arc<TenantDb> = match extensions.get::<Arc<TenantDb>>() {
45        Some(db) => db.clone(),
46        None => {
47            return error_response(
48                StatusCode::INTERNAL_SERVER_ERROR,
49                "Tenant database not configured",
50            );
51        }
52    };
53
54    let tenant_ctx = match tenant_db.verify_api_key(api_key).await {
55        Ok(Some(ctx)) => ctx,
56        Ok(None) => {
57            return error_response(StatusCode::UNAUTHORIZED, "Invalid API key");
58        }
59        Err(e) => {
60            tracing::error!("API key verification error: {}", e);
61            return error_response(
62                StatusCode::INTERNAL_SERVER_ERROR,
63                "Failed to verify API key",
64            );
65        }
66    };
67
68    let monthly_usage = match tenant_db.get_monthly_requests(&tenant_ctx.tenant_id).await {
69        Ok(m) => m,
70        Err(_) => {
71            return error_response(StatusCode::INTERNAL_SERVER_ERROR, "Failed to check usage");
72        }
73    };
74
75    let daily_usage = match tenant_db.get_daily_requests(&tenant_ctx.tenant_id).await {
76        Ok(d) => d,
77        Err(_) => {
78            return error_response(
79                StatusCode::INTERNAL_SERVER_ERROR,
80                "Failed to check rate limit",
81            );
82        }
83    };
84
85    if !tenant_ctx.can_make_request(monthly_usage, daily_usage) {
86        if monthly_usage >= tenant_ctx.quota.requests_per_month {
87            return error_response(
88                StatusCode::TOO_MANY_REQUESTS,
89                "Monthly request quota exceeded",
90            );
91        }
92        if daily_usage >= tenant_ctx.quota.requests_per_day {
93            return error_response(StatusCode::TOO_MANY_REQUESTS, "Daily rate limit exceeded");
94        }
95    }
96
97    let mut req = req;
98    req.extensions_mut().insert(tenant_ctx);
99
100    next.run(req).await
101}
102
103fn error_response(status: StatusCode, message: &str) -> Response {
104    let body = Json(serde_json::json!({
105        "error": message
106    }));
107    (status, body).into_response()
108}
109
110pub use crate::auth::middleware::AuthUser;
111
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn protected_handler() -> &'static str {
        "protected content"
    }

    /// Router with one protected route behind the API-key middleware.
    /// No `Arc<TenantDb>` extension is installed, so any request that passes
    /// the header checks stops at the tenant-database lookup.
    fn app() -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(axum::middleware::from_fn(api_key_auth_middleware))
    }

    /// Sends `GET /protected`, with an `Authorization` header if one is
    /// given, and returns the response status.
    async fn status_for(authorization: Option<&str>) -> StatusCode {
        let mut builder = Request::builder().uri("/protected");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }

        let response = app()
            .oneshot(builder.body(Body::empty()).unwrap())
            .await
            .unwrap();

        response.status()
    }

    #[tokio::test]
    async fn test_middleware_no_auth_header() {
        assert_eq!(status_for(None).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_invalid_format() {
        assert_eq!(
            status_for(Some("Basic abc123")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn test_middleware_missing_prefix() {
        assert_eq!(
            status_for(Some("Bearer abc123")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn test_middleware_missing_tenant_db() {
        // A well-formed key passes the header checks; with no TenantDb
        // extension the middleware must answer 500 rather than forward.
        assert_eq!(
            status_for(Some("Bearer ares_test")).await,
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }
}