Skip to main content

mockforge_collab/
middleware.rs

1//! Middleware for authentication and authorization
2
3use crate::auth::AuthService;
4use crate::error::CollabError;
5use axum::{
6    extract::{Request, State},
7    http::StatusCode,
8    middleware::Next,
9    response::Response,
10};
11use std::sync::Arc;
12use uuid::Uuid;
13
14/// Extension for authenticated user
15#[derive(Clone, Debug)]
16pub struct AuthUser {
17    /// User's unique identifier
18    pub user_id: Uuid,
19    /// User's username
20    pub username: String,
21}
22
23/// JWT authentication middleware
24///
25/// # Errors
26///
27/// Returns an error tuple if authentication fails.
28pub async fn auth_middleware(
29    State(auth): State<Arc<AuthService>>,
30    mut request: Request,
31    next: Next,
32) -> Result<Response, (StatusCode, String)> {
33    // Extract Authorization header
34    let auth_header = request
35        .headers()
36        .get("Authorization")
37        .and_then(|h| h.to_str().ok())
38        .ok_or_else(|| (StatusCode::UNAUTHORIZED, "Missing Authorization header".to_string()))?;
39
40    // Check Bearer prefix
41    let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
42        (StatusCode::UNAUTHORIZED, "Invalid Authorization header format".to_string())
43    })?;
44
45    // Verify token
46    let claims = auth
47        .verify_token(token)
48        .map_err(|e| (StatusCode::UNAUTHORIZED, format!("Invalid token: {e}")))?;
49
50    // Parse user ID
51    let user_id = Uuid::parse_str(&claims.sub)
52        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Invalid user ID in token".to_string()))?;
53
54    // Add user to request extensions
55    request.extensions_mut().insert(AuthUser {
56        user_id,
57        username: claims.username,
58    });
59
60    Ok(next.run(request).await)
61}
62
63/// Extract authenticated user from request
64///
65/// # Errors
66///
67/// Returns an error if the user is not authenticated.
68pub fn extract_auth_user(request: &Request) -> Result<&AuthUser, CollabError> {
69    request
70        .extensions()
71        .get::<AuthUser>()
72        .ok_or_else(|| CollabError::AuthenticationFailed("Not authenticated".to_string()))
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// An `AuthUser` built by hand keeps the username it was given.
    #[test]
    fn test_auth_user_creation() {
        let id = Uuid::new_v4();
        let name = String::from("testuser");
        let user = AuthUser {
            user_id: id,
            username: name,
        };

        assert_eq!(user.username, "testuser");
    }
}