1use axum::extract::{FromRequestParts, State};
2use axum::http::request::Parts;
3use chrono::{Duration, Utc};
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use uuid::Uuid;
8
9use crate::error::{ApiError, ApiResult};
10use crate::server::AppState;
11
/// Shared authentication state: the token-signing secret and the vault lock flag.
///
/// `Debug` is implemented by hand rather than derived so the signing secret never ends up
/// in logs or panic messages; only its length is rendered.
#[derive(Clone)]
pub struct AuthState {
    /// Secret used to sign and verify API tokens. Must never be logged.
    pub secret: Vec<u8>,
    /// Whether the vault is unlocked; clones share the same flag through the `Arc`.
    pub vault_unlocked: Arc<std::sync::Mutex<bool>>,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            // Redact the key material; the length is enough for diagnostics.
            .field("secret", &format_args!("<redacted {} bytes>", self.secret.len()))
            .field("vault_unlocked", &self.vault_unlocked)
            .finish()
    }
}
17
18impl Default for AuthState {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl AuthState {
25 #[must_use]
26 pub fn new() -> Self {
27 let mut secret = vec![0u8; 32];
28 rand::Rng::fill(&mut rand::rng(), &mut secret[..]);
29
30 Self {
31 secret,
32 vault_unlocked: Arc::new(std::sync::Mutex::new(false)),
33 }
34 }
35
36 pub fn set_vault_unlocked(&self, unlocked: bool) {
37 if let Ok(mut status) = self.vault_unlocked.lock() {
38 *status = unlocked;
39 }
40 }
41
42 #[must_use]
47 pub fn is_vault_unlocked(&self) -> bool {
48 self.vault_unlocked.lock().map(|status| *status).unwrap_or(false)
49 }
50
51 pub fn generate_token(&self, scopes: Vec<String>) -> ApiResult<String> {
58 let expiration = Utc::now() + Duration::hours(1);
59
60 let claims = TokenClaims {
61 sub: "api-user".to_string(),
62 exp: usize::try_from(expiration.timestamp())
63 .map_err(|_| ApiError::InternalError("Token expiration timestamp overflow".to_string()))?,
64 iat: usize::try_from(Utc::now().timestamp())
65 .map_err(|_| ApiError::InternalError("Token issue timestamp overflow".to_string()))?,
66 jti: Uuid::new_v4().to_string(),
67 scopes,
68 };
69
70 encode(&Header::default(), &claims, &EncodingKey::from_secret(&self.secret))
71 .map_err(|e| ApiError::InternalError(format!("Token generation failed: {e}")))
72 }
73
74 pub fn verify_token(&self, token: &str) -> ApiResult<TokenClaims> {
81 decode::<TokenClaims>(token, &DecodingKey::from_secret(&self.secret), &Validation::default())
82 .map(|data| data.claims)
83 .map_err(|_| ApiError::Unauthorized)
84 }
85}
86
87#[derive(Debug, Serialize, Deserialize, Clone)]
88pub struct TokenClaims {
89 pub sub: String,
90 pub exp: usize,
91 pub iat: usize,
92 pub jti: String,
93 pub scopes: Vec<String>,
94}
95
96impl TokenClaims {
97 #[must_use]
98 pub fn has_scope(&self, required_scope: &str) -> bool {
99 self.scopes.contains(&required_scope.to_string())
100 }
101}
102
103#[derive(Debug)]
105pub struct AuthenticatedUser(pub TokenClaims);
106
107impl<S> FromRequestParts<S> for AuthenticatedUser
108where
109 S: Send + Sync,
110{
111 type Rejection = ApiError;
112
113 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
114 let auth_header = parts
116 .headers
117 .get("authorization")
118 .and_then(|header| header.to_str().ok())
119 .and_then(|header| {
120 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") {
122 Some(&header[7..])
123 } else {
124 None
125 }
126 })
127 .ok_or(ApiError::Unauthorized)?;
128
129 let auth_state = parts
131 .extensions
132 .get::<AuthState>()
133 .ok_or(ApiError::InternalError("Auth state not found".to_string()))?;
134
135 let claims = auth_state.verify_token(auth_header)?;
137 Ok(AuthenticatedUser(claims))
138 }
139}
140
141use axum::body::Body;
143use axum::{http::Request, middleware::Next, response::Response};
144
145pub async fn auth_middleware(State(state): State<Arc<AppState>>, mut request: Request<Body>, next: Next) -> Response {
146 request.extensions_mut().insert(state.auth.clone());
148 next.run(request).await
149}