1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::{
5 body::Body,
6 extract::{Request, State},
7 middleware::Next,
8 response::Response,
9};
10use forge_core::auth::Claims;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{dangerous, decode, Algorithm, DecodingKey, Validation};
13use uuid::Uuid;
14
15#[derive(Debug, Clone)]
17pub struct AuthConfig {
18 pub jwt_secret: String,
20 pub algorithm: JwtAlgorithm,
22 pub allow_anonymous: bool,
24 pub skip_verification: bool,
27}
28
29impl Default for AuthConfig {
30 fn default() -> Self {
31 Self {
32 jwt_secret: String::new(),
33 algorithm: JwtAlgorithm::HS256,
34 allow_anonymous: true,
35 skip_verification: false,
36 }
37 }
38}
39
40impl AuthConfig {
41 pub fn with_secret(secret: impl Into<String>) -> Self {
43 Self {
44 jwt_secret: secret.into(),
45 ..Default::default()
46 }
47 }
48
49 pub fn dev_mode() -> Self {
52 Self {
53 jwt_secret: String::new(),
54 algorithm: JwtAlgorithm::HS256,
55 allow_anonymous: true,
56 skip_verification: true,
57 }
58 }
59}
60
/// HMAC signing algorithms that can be selected in [`AuthConfig`].
///
/// Each variant converts to the `jsonwebtoken` algorithm of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// The algorithm used when none is chosen explicitly.
    #[default]
    HS256,
    /// Converts to `Algorithm::HS384`.
    HS384,
    /// Converts to `Algorithm::HS512`.
    HS512,
}
69
70impl From<JwtAlgorithm> for Algorithm {
71 fn from(alg: JwtAlgorithm) -> Self {
72 match alg {
73 JwtAlgorithm::HS256 => Algorithm::HS256,
74 JwtAlgorithm::HS384 => Algorithm::HS384,
75 JwtAlgorithm::HS512 => Algorithm::HS512,
76 }
77 }
78}
79
80#[derive(Clone)]
82pub struct AuthMiddleware {
83 config: Arc<AuthConfig>,
84 decoding_key: Option<DecodingKey>,
85}
86
87impl std::fmt::Debug for AuthMiddleware {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("AuthMiddleware")
90 .field("config", &self.config)
91 .field("decoding_key", &self.decoding_key.is_some())
92 .finish()
93 }
94}
95
96impl AuthMiddleware {
97 pub fn new(config: AuthConfig) -> Self {
99 let decoding_key = if config.skip_verification || config.jwt_secret.is_empty() {
100 None
101 } else {
102 Some(DecodingKey::from_secret(config.jwt_secret.as_bytes()))
103 };
104
105 Self {
106 config: Arc::new(config),
107 decoding_key,
108 }
109 }
110
111 pub fn permissive() -> Self {
114 Self::new(AuthConfig::dev_mode())
115 }
116
117 pub fn config(&self) -> &AuthConfig {
119 &self.config
120 }
121
122 pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
124 if self.config.skip_verification {
125 self.decode_without_verification(token)
127 } else if let Some(ref key) = self.decoding_key {
128 self.decode_with_verification(token, key)
129 } else {
130 Err(AuthError::InvalidToken(
131 "JWT secret not configured".to_string(),
132 ))
133 }
134 }
135
136 fn decode_with_verification(
138 &self,
139 token: &str,
140 key: &DecodingKey,
141 ) -> Result<Claims, AuthError> {
142 let mut validation = Validation::new(self.config.algorithm.into());
143
144 validation.validate_exp = true;
146 validation.validate_nbf = false;
147 validation.validate_aud = false;
148 validation.leeway = 60; validation.set_required_spec_claims(&["exp", "sub"]);
152
153 let token_data = decode::<Claims>(token, key, &validation).map_err(|e| match e.kind() {
154 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
155 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
156 AuthError::InvalidToken("Invalid signature".to_string())
157 }
158 jsonwebtoken::errors::ErrorKind::InvalidToken => {
159 AuthError::InvalidToken("Invalid token format".to_string())
160 }
161 jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
162 AuthError::InvalidToken(format!("Missing required claim: {}", claim))
163 }
164 _ => AuthError::InvalidToken(e.to_string()),
165 })?;
166
167 Ok(token_data.claims)
168 }
169
170 fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
173 let token_data =
174 dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
175 jsonwebtoken::errors::ErrorKind::InvalidToken => {
176 AuthError::InvalidToken("Invalid token format".to_string())
177 }
178 _ => AuthError::InvalidToken(e.to_string()),
179 })?;
180
181 if token_data.claims.is_expired() {
183 return Err(AuthError::TokenExpired);
184 }
185
186 Ok(token_data.claims)
187 }
188}
189
190#[derive(Debug, Clone, thiserror::Error)]
192pub enum AuthError {
193 #[error("Missing authorization header")]
194 MissingHeader,
195 #[error("Invalid authorization header format")]
196 InvalidHeader,
197 #[error("Invalid token: {0}")]
198 InvalidToken(String),
199 #[error("Token expired")]
200 TokenExpired,
201}
202
203pub fn extract_auth_context(req: &Request<Body>, middleware: &AuthMiddleware) -> AuthContext {
205 let auth_header = req
207 .headers()
208 .get(axum::http::header::AUTHORIZATION)
209 .and_then(|v| v.to_str().ok());
210
211 let token = match auth_header {
212 Some(header) if header.starts_with("Bearer ") => {
213 Some(header.trim_start_matches("Bearer ").trim())
214 }
215 _ => None,
216 };
217
218 match token {
219 Some(token) => match middleware.validate_token(token) {
220 Ok(claims) => {
221 let user_id = claims.user_id().unwrap_or_else(Uuid::nil);
222 let custom_claims: HashMap<String, serde_json::Value> = claims.custom;
223 AuthContext::authenticated(user_id, claims.roles, custom_claims)
224 }
225 Err(_) => AuthContext::unauthenticated(),
226 },
227 None => AuthContext::unauthenticated(),
228 }
229}
230
231pub async fn auth_middleware(
233 State(middleware): State<Arc<AuthMiddleware>>,
234 req: Request<Body>,
235 next: Next,
236) -> Response {
237 let auth_context = extract_auth_context(&req, &middleware);
238
239 let mut req = req;
241 req.extensions_mut().insert(auth_context);
242
243 next.run(req).await
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};

    /// Builds claims for subject `test-user-id` with role `user`, valid for
    /// one hour or — when `expired` is set — with a duration of minus one hour.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let lifetime_secs = if expired { -3600 } else { 3600 };

        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(lifetime_secs)
            .build()
            .unwrap()
    }

    /// Signs `claims` with `secret` using the default JWT header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &signing_key).unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let AuthConfig {
            algorithm,
            allow_anonymous,
            skip_verification,
            ..
        } = AuthConfig::default();

        assert!(allow_anonymous);
        assert_eq!(algorithm, JwtAlgorithm::HS256);
        assert!(!skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let dev = AuthConfig::dev_mode();
        assert!(dev.skip_verification);
        assert!(dev.allow_anonymous);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        assert!(AuthMiddleware::permissive().config.skip_verification);
    }

    #[test]
    fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
        let token = create_test_token(&create_test_claims(false), secret);

        let outcome = middleware.validate_token(&token);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().sub, "test-user-id");
    }

    #[test]
    fn test_valid_token_with_wrong_secret() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("correct-secret"));
        let token = create_test_token(&create_test_claims(false), "wrong-secret");

        let outcome = middleware.validate_token(&token);
        assert!(
            matches!(outcome, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_expired_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
        let token = create_test_token(&create_test_claims(true), secret);

        let outcome = middleware.validate_token(&token);
        assert!(
            matches!(outcome, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[test]
    fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
        let mut token = create_test_token(&create_test_claims(false), secret);

        // Swap the final signature character for a different one.
        let swapped = match token.pop() {
            Some('a') => Some('b'),
            Some(_) => Some('a'),
            None => None,
        };
        token.extend(swapped);

        assert!(middleware.validate_token(&token).is_err());
    }

    #[test]
    fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        // The signing secret is irrelevant when verification is skipped.
        let token = create_test_token(&create_test_claims(false), "any-secret");

        assert!(middleware.validate_token(&token).is_ok());
    }

    #[test]
    fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        let token = create_test_token(&create_test_claims(true), "any-secret");

        let outcome = middleware.validate_token(&token);
        assert!(
            matches!(outcome, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error even in dev mode"
        );
    }

    #[test]
    fn test_invalid_token_format() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));

        let outcome = middleware.validate_token("not-a-valid-jwt");
        assert!(
            matches!(outcome, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
    }
}