Skip to main content

chamber_api/
auth.rs

1use axum::extract::{FromRequestParts, State};
2use axum::http::request::Parts;
3use chrono::{Duration, Utc};
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use uuid::Uuid;
8
9use crate::error::{ApiError, ApiResult};
10use crate::server::AppState;
11
/// Shared authentication state: the token signing secret and the vault-unlock flag.
///
/// `Debug` is implemented by hand (not derived) so that formatting this value,
/// e.g. in a log line, never prints the signing secret.
#[derive(Clone)]
pub struct AuthState {
    /// Symmetric key used both to sign (`generate_token`) and to verify
    /// (`verify_token`) tokens. Treat as confidential.
    pub secret: Vec<u8>,
    /// Whether the vault is unlocked. Behind an `Arc`, so every clone of this
    /// state observes the same flag.
    pub vault_unlocked: Arc<std::sync::Mutex<bool>>,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            // Never render the key material itself.
            .field("secret", &format_args!("<redacted>"))
            .field("vault_unlocked", &self.vault_unlocked)
            .finish()
    }
}
17
18impl Default for AuthState {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl AuthState {
25    #[must_use]
26    pub fn new() -> Self {
27        let mut secret = vec![0u8; 32];
28        rand::Rng::fill(&mut rand::rng(), &mut secret[..]);
29
30        Self {
31            secret,
32            vault_unlocked: Arc::new(std::sync::Mutex::new(false)),
33        }
34    }
35
36    pub fn set_vault_unlocked(&self, unlocked: bool) {
37        if let Ok(mut status) = self.vault_unlocked.lock() {
38            *status = unlocked;
39        }
40    }
41
42    /// # Errors
43    ///
44    /// This function does not return errors, but may return false if:
45    /// - The mutex lock is poisoned or cannot be acquired
46    #[must_use]
47    pub fn is_vault_unlocked(&self) -> bool {
48        self.vault_unlocked.lock().map(|status| *status).unwrap_or(false)
49    }
50
51    /// # Errors
52    ///
53    /// This function returns an error if:
54    /// - Token expiration timestamp conversion fails
55    /// - Token issue timestamp conversion fails
56    /// - Token generation process fails
57    pub fn generate_token(&self, scopes: Vec<String>) -> ApiResult<String> {
58        let expiration = Utc::now() + Duration::hours(1);
59
60        let claims = TokenClaims {
61            sub: "api-user".to_string(),
62            exp: usize::try_from(expiration.timestamp())
63                .map_err(|_| ApiError::InternalError("Token expiration timestamp overflow".to_string()))?,
64            iat: usize::try_from(Utc::now().timestamp())
65                .map_err(|_| ApiError::InternalError("Token issue timestamp overflow".to_string()))?,
66            jti: Uuid::new_v4().to_string(),
67            scopes,
68        };
69
70        encode(&Header::default(), &claims, &EncodingKey::from_secret(&self.secret))
71            .map_err(|e| ApiError::InternalError(format!("Token generation failed: {e}")))
72    }
73
74    /// # Errors
75    ///
76    /// This function returns an error if:
77    /// - The token is invalid or expired
78    /// - The token signature is invalid
79    /// - The token format is incorrect
80    pub fn verify_token(&self, token: &str) -> ApiResult<TokenClaims> {
81        decode::<TokenClaims>(token, &DecodingKey::from_secret(&self.secret), &Validation::default())
82            .map(|data| data.claims)
83            .map_err(|_| ApiError::Unauthorized)
84    }
85}
86
87#[derive(Debug, Serialize, Deserialize, Clone)]
88pub struct TokenClaims {
89    pub sub: String,
90    pub exp: usize,
91    pub iat: usize,
92    pub jti: String,
93    pub scopes: Vec<String>,
94}
95
96impl TokenClaims {
97    #[must_use]
98    pub fn has_scope(&self, required_scope: &str) -> bool {
99        self.scopes.contains(&required_scope.to_string())
100    }
101}
102
103// Simplified approach: Use the AuthState directly from extensions
104#[derive(Debug)]
105pub struct AuthenticatedUser(pub TokenClaims);
106
107impl<S> FromRequestParts<S> for AuthenticatedUser
108where
109    S: Send + Sync,
110{
111    type Rejection = ApiError;
112
113    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
114        // Get the authorization header
115        let auth_header = parts
116            .headers
117            .get("authorization")
118            .and_then(|header| header.to_str().ok())
119            .and_then(|header| {
120                // Make Bearer prefix case-insensitive
121                if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") {
122                    Some(&header[7..])
123                } else {
124                    None
125                }
126            })
127            .ok_or(ApiError::Unauthorized)?;
128
129        // Get the auth state from request extensions (added by middleware)
130        let auth_state = parts
131            .extensions
132            .get::<AuthState>()
133            .ok_or(ApiError::InternalError("Auth state not found".to_string()))?;
134
135        // Verify the token using the auth state
136        let claims = auth_state.verify_token(auth_header)?;
137        Ok(AuthenticatedUser(claims))
138    }
139}
140
141// Simple middleware that just adds AuthState to extensions
142use axum::body::Body;
143use axum::{http::Request, middleware::Next, response::Response};
144
145pub async fn auth_middleware(State(state): State<Arc<AppState>>, mut request: Request<Body>, next: Next) -> Response {
146    // Add auth state to request extensions so extractors can access it
147    request.extensions_mut().insert(state.auth.clone());
148    next.run(request).await
149}