1use std::sync::Arc;
4
5use axum::{
6 Json,
7 extract::{FromRef, FromRequestParts},
8 http::{StatusCode, header::AUTHORIZATION, request::Parts},
9 response::{IntoResponse, Response},
10};
11use serde::Serialize;
12use tokio::sync::RwLock;
13
14use super::AuthError;
15use super::config::AuthConfig;
16use super::jwt::{Claims, JwtManager};
17use super::setup::BootstrapManager;
18use super::users::{UserRole, UserStore};
19
20pub struct AuthState {
22 pub config: AuthConfig,
24 pub jwt: JwtManager,
26 pub users: UserStore,
28 pub bootstrap: RwLock<BootstrapManager>,
30}
31
32impl AuthState {
33 #[must_use]
35 pub fn new(config: AuthConfig, jwt: JwtManager, users: UserStore) -> Self {
36 Self {
37 config,
38 jwt,
39 users,
40 bootstrap: RwLock::new(BootstrapManager::new()),
41 }
42 }
43
44 pub fn initialize(
50 mut config: AuthConfig,
51 data_dir: &std::path::Path,
52 ) -> Result<Self, AuthError> {
53 let users = UserStore::open(data_dir)?;
55
56 let jwt_secret = if let Some(secret) = &config.jwt_secret {
58 secret.clone()
59 } else {
60 let secret = JwtManager::generate_hex_secret();
61 config.jwt_secret = Some(secret.clone());
62 tracing::info!("Generated new JWT secret");
64 secret
65 };
66
67 let jwt = JwtManager::from_hex_secret(
68 &jwt_secret,
69 config.token_expiry(),
70 config.refresh_expiry(),
71 )?;
72
73 Ok(Self::new(config, jwt, users))
74 }
75
76 #[must_use]
78 pub fn requires_auth(&self, method: &str) -> bool {
79 self.config.enabled && !self.config.is_public_method(method)
80 }
81
82 pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
88 self.jwt.validate_access_token(token)
89 }
90}
91
92impl std::fmt::Debug for AuthState {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("AuthState")
95 .field("config", &self.config)
96 .field("user_count", &self.users.count())
97 .finish_non_exhaustive()
98 }
99}
100
/// Field-less marker type.
///
/// No behavior is attached to it in this file; presumably it names the auth
/// middleware layer elsewhere (TODO confirm).
#[derive(Clone, Debug)]
pub struct AuthLayer;
104
105#[derive(Debug, Clone)]
109pub struct RequireAuth {
110 pub claims: Claims,
112}
113
114impl RequireAuth {
115 #[must_use]
117 pub fn user_id(&self) -> &str {
118 &self.claims.sub
119 }
120
121 #[must_use]
123 pub fn username(&self) -> &str {
124 &self.claims.username
125 }
126
127 #[must_use]
129 pub const fn role(&self) -> UserRole {
130 self.claims.role
131 }
132
133 #[must_use]
135 pub fn is_admin(&self) -> bool {
136 self.claims.role.is_admin()
137 }
138
139 pub fn require_admin(&self) -> Result<(), AuthError> {
145 if self.is_admin() {
146 Ok(())
147 } else {
148 Err(AuthError::PermissionDenied(
149 "Admin role required".to_string(),
150 ))
151 }
152 }
153}
154
155#[derive(Debug, Serialize)]
157struct AuthErrorResponse {
158 error: String,
159 code: &'static str,
160}
161
162impl IntoResponse for AuthError {
163 fn into_response(self) -> Response {
164 let (status, code) = match &self {
165 Self::InvalidCredentials => (StatusCode::UNAUTHORIZED, "invalid_credentials"),
166 Self::TokenError(_) => (StatusCode::UNAUTHORIZED, "invalid_token"),
167 Self::PermissionDenied(_) => (StatusCode::FORBIDDEN, "permission_denied"),
168 Self::SetupRequired => (StatusCode::SERVICE_UNAVAILABLE, "setup_required"),
169 Self::InvalidBootstrapToken => (StatusCode::UNAUTHORIZED, "invalid_bootstrap_token"),
170 Self::UserNotFound(_) => (StatusCode::NOT_FOUND, "user_not_found"),
171 Self::UserExists(_) => (StatusCode::CONFLICT, "user_exists"),
172 Self::Storage(_) | Self::Config(_) => {
173 (StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
174 }
175 };
176
177 let body = AuthErrorResponse {
178 error: self.to_string(),
179 code,
180 };
181
182 (status, Json(body)).into_response()
183 }
184}
185
186impl<S> FromRequestParts<S> for RequireAuth
188where
189 S: Send + Sync,
190 Arc<AuthState>: FromRef<S>,
191{
192 type Rejection = Response;
193
194 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
195 let auth_state = Arc::<AuthState>::from_ref(state);
196 extract_auth(parts, &auth_state).await
197 }
198}
199
200async fn extract_auth(parts: &Parts, auth_state: &AuthState) -> Result<RequireAuth, Response> {
201 if !auth_state.config.enabled {
203 return Ok(RequireAuth {
205 claims: Claims {
206 sub: "system".to_string(),
207 username: "system".to_string(),
208 role: UserRole::Admin,
209 iat: 0,
210 exp: i64::MAX,
211 token_type: super::jwt::TokenType::Access,
212 family_id: None,
213 },
214 });
215 }
216
217 let auth_header = parts
219 .headers
220 .get(AUTHORIZATION)
221 .and_then(|v| v.to_str().ok())
222 .ok_or_else(|| {
223 AuthError::TokenError("Missing Authorization header".to_string()).into_response()
224 })?;
225
226 let token = JwtManager::extract_from_header(auth_header).ok_or_else(|| {
227 AuthError::TokenError("Invalid Authorization header format".to_string()).into_response()
228 })?;
229
230 let claims = auth_state
232 .validate_token(token)
233 .map_err(IntoResponse::into_response)?;
234
235 let user = auth_state
237 .users
238 .get(&claims.sub)
239 .map_err(IntoResponse::into_response)?
240 .ok_or_else(|| AuthError::UserNotFound(claims.sub.clone()).into_response())?;
241
242 if !user.active {
243 return Err(AuthError::PermissionDenied("Account disabled".to_string()).into_response());
244 }
245
246 Ok(RequireAuth { claims })
247}
248
249#[derive(Debug, Clone)]
253pub struct OptionalAuth(pub Option<RequireAuth>);
254
255impl<S> FromRequestParts<S> for OptionalAuth
256where
257 S: Send + Sync,
258 Arc<AuthState>: FromRef<S>,
259{
260 type Rejection = std::convert::Infallible;
261
262 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
263 Ok(Self(
264 RequireAuth::from_request_parts(parts, state).await.ok(),
265 ))
266 }
267}
268
269#[derive(Debug, Clone)]
271pub struct RequireAdmin(pub RequireAuth);
272
273impl<S> FromRequestParts<S> for RequireAdmin
274where
275 S: Send + Sync,
276 Arc<AuthState>: FromRef<S>,
277{
278 type Rejection = Response;
279
280 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
281 let auth = RequireAuth::from_request_parts(parts, state).await?;
282
283 if !auth.is_admin() {
284 return Err(
285 AuthError::PermissionDenied("Admin role required".to_string()).into_response(),
286 );
287 }
288
289 Ok(Self(auth))
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::super::jwt::TokenType;
    use super::*;

    /// Builds a `RequireAuth` around access-token claims that never expire.
    fn auth_for(sub: &str, username: &str, role: UserRole) -> RequireAuth {
        RequireAuth {
            claims: Claims {
                sub: sub.to_string(),
                username: username.to_string(),
                role,
                iat: 0,
                exp: i64::MAX,
                token_type: TokenType::Access,
                family_id: None,
            },
        }
    }

    #[test]
    fn test_require_auth_methods() {
        let operator = auth_for("user_123", "testuser", UserRole::Operator);

        assert_eq!(operator.user_id(), "user_123");
        assert_eq!(operator.username(), "testuser");
        assert_eq!(operator.role(), UserRole::Operator);
        assert!(!operator.is_admin());
        assert!(operator.require_admin().is_err());
    }

    #[test]
    fn test_admin_auth() {
        let admin = auth_for("admin_1", "admin", UserRole::Admin);

        assert!(admin.is_admin());
        assert!(admin.require_admin().is_ok());
    }
}