reasonkit_web/portal/
middleware.rs1use axum::{
16 extract::Request,
17 http::{header, StatusCode},
18 middleware::Next,
19 response::Response,
20 Json,
21};
22use serde::Serialize;
23
24use crate::portal::auth::{AuthService, Claims};
25
26#[derive(Debug, Serialize)]
28pub struct AuthErrorResponse {
29 pub error: String,
30 pub code: String,
31}
32
33fn extract_token(req: &Request) -> Option<&str> {
35 req.headers()
36 .get(header::AUTHORIZATION)
37 .and_then(|value| value.to_str().ok())
38 .and_then(|value| value.strip_prefix("Bearer "))
39}
40
41pub async fn require_auth(
45 req: Request,
46 next: Next,
47) -> Result<Response, (StatusCode, Json<AuthErrorResponse>)> {
48 let token = extract_token(&req).ok_or_else(|| {
50 (
51 StatusCode::UNAUTHORIZED,
52 Json(AuthErrorResponse {
53 error: "Missing or invalid Authorization header".to_string(),
54 code: "MISSING_TOKEN".to_string(),
55 }),
56 )
57 })?;
58
59 let auth_service = AuthService::new(Default::default());
61
62 let claims = auth_service.validate_token(token).map_err(|e| {
64 (
65 StatusCode::UNAUTHORIZED,
66 Json(AuthErrorResponse {
67 error: e.to_string(),
68 code: "INVALID_TOKEN".to_string(),
69 }),
70 )
71 })?;
72
73 if claims.token_type != crate::portal::auth::TokenType::Access {
75 return Err((
76 StatusCode::UNAUTHORIZED,
77 Json(AuthErrorResponse {
78 error: "Invalid token type".to_string(),
79 code: "WRONG_TOKEN_TYPE".to_string(),
80 }),
81 ));
82 }
83
84 let mut req = req;
86 req.extensions_mut().insert(claims);
87
88 Ok(next.run(req).await)
89}
90
91pub async fn optional_auth(req: Request, next: Next) -> Response {
95 if let Some(token) = extract_token(&req) {
97 let auth_service = AuthService::new(Default::default());
98 if let Ok(claims) = auth_service.validate_token(token) {
99 if claims.token_type == crate::portal::auth::TokenType::Access {
100 let mut req = req;
101 req.extensions_mut().insert(claims);
102 return next.run(req).await;
103 }
104 }
105 }
106
107 next.run(req).await
109}
110
111#[allow(clippy::type_complexity)]
115pub fn require_scope(
116 required_scope: &'static str,
117) -> impl Fn(
118 Request,
119 Next,
120) -> std::pin::Pin<
121 Box<
122 dyn std::future::Future<Output = Result<Response, (StatusCode, Json<AuthErrorResponse>)>>
123 + Send,
124 >,
125> + Clone {
126 move |req: Request, next: Next| {
127 Box::pin(async move {
128 let claims = req.extensions().get::<Claims>().ok_or_else(|| {
130 (
131 StatusCode::UNAUTHORIZED,
132 Json(AuthErrorResponse {
133 error: "Not authenticated".to_string(),
134 code: "NOT_AUTHENTICATED".to_string(),
135 }),
136 )
137 })?;
138
139 if !claims
141 .scopes
142 .iter()
143 .any(|s| s == required_scope || s == "admin")
144 {
145 return Err((
146 StatusCode::FORBIDDEN,
147 Json(AuthErrorResponse {
148 error: format!("Missing required scope: {}", required_scope),
149 code: "INSUFFICIENT_SCOPE".to_string(),
150 }),
151 ));
152 }
153
154 Ok(next.run(req).await)
155 })
156 }
157}
158
159pub fn get_claims(req: &Request) -> Option<&Claims> {
161 req.extensions().get::<Claims>()
162}
163
164#[derive(Debug, Clone)]
166pub struct AuthClaims(pub Claims);
167
168#[axum::async_trait]
169impl<S> axum::extract::FromRequestParts<S> for AuthClaims
170where
171 S: Send + Sync,
172{
173 type Rejection = (StatusCode, Json<AuthErrorResponse>);
174
175 async fn from_request_parts(
176 parts: &mut axum::http::request::Parts,
177 _state: &S,
178 ) -> Result<Self, Self::Rejection> {
179 parts
180 .extensions
181 .get::<Claims>()
182 .cloned()
183 .map(AuthClaims)
184 .ok_or_else(|| {
185 (
186 StatusCode::UNAUTHORIZED,
187 Json(AuthErrorResponse {
188 error: "Not authenticated".to_string(),
189 code: "NOT_AUTHENTICATED".to_string(),
190 }),
191 )
192 })
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    /// Builds an empty-bodied request, optionally with an `Authorization` header.
    fn request_with_authorization(value: Option<&str>) -> Request<Body> {
        let builder = Request::builder();
        let builder = match value {
            Some(v) => builder.header("Authorization", v),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    #[test]
    fn test_extract_token_valid() {
        let req = request_with_authorization(Some("Bearer test_token_123"));
        assert_eq!(extract_token(&req), Some("test_token_123"));
    }

    #[test]
    fn test_extract_token_missing() {
        let req = request_with_authorization(None);
        assert_eq!(extract_token(&req), None);
    }

    #[test]
    fn test_extract_token_invalid_format() {
        let req = request_with_authorization(Some("Basic user:pass"));
        assert_eq!(extract_token(&req), None);
    }
}