llm_registry_api/
auth_handlers.rs

//! Authentication API handlers
//!
//! This module provides HTTP handlers for the authentication endpoints: login,
//! token refresh, current-user information, logout, and API key generation.
5
6use axum::{
7    extract::{Extension, State},
8    http::StatusCode,
9    Json,
10};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tracing::{debug, info, instrument};
14
15use crate::{
16    auth::AuthUser,
17    error::{ApiError, ApiResult},
18    jwt::{Claims, JwtManager, TokenPair},
19    responses::{ok, ApiResponse},
20};
21
22/// Authentication state for handlers
23#[derive(Clone, Debug)]
24pub struct AuthHandlerState {
25    jwt_manager: Arc<JwtManager>,
26}
27
28impl AuthHandlerState {
29    /// Create new auth handler state
30    pub fn new(jwt_manager: JwtManager) -> Self {
31        Self {
32            jwt_manager: Arc::new(jwt_manager),
33        }
34    }
35
36    /// Get JWT manager reference
37    pub fn jwt_manager(&self) -> &JwtManager {
38        &self.jwt_manager
39    }
40}
41
42/// Login request
43#[derive(Debug, Deserialize, Serialize)]
44pub struct LoginRequest {
45    /// Username or email
46    pub username: String,
47
48    /// Password
49    pub password: String,
50}
51
52/// Login response
53#[derive(Debug, Serialize, Deserialize)]
54pub struct LoginResponse {
55    /// Token pair (access + refresh)
56    #[serde(flatten)]
57    pub token_pair: TokenPair,
58
59    /// User information
60    pub user: UserInfo,
61}
62
63/// User information
64#[derive(Debug, Serialize, Deserialize)]
65pub struct UserInfo {
66    /// User ID
67    pub id: String,
68
69    /// Email
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub email: Option<String>,
72
73    /// Roles
74    #[serde(default, skip_serializing_if = "Vec::is_empty")]
75    pub roles: Vec<String>,
76}
77
78impl UserInfo {
79    /// Create from claims
80    pub fn from_claims(claims: &Claims) -> Self {
81        Self {
82            id: claims.sub.clone(),
83            email: claims.email.clone(),
84            roles: claims.roles.clone(),
85        }
86    }
87}
88
89/// Token refresh request
90#[derive(Debug, Deserialize, Serialize)]
91pub struct RefreshTokenRequest {
92    /// Refresh token
93    pub refresh_token: String,
94}
95
96/// Token refresh response
97#[derive(Debug, Serialize, Deserialize)]
98pub struct RefreshTokenResponse {
99    /// New token pair
100    #[serde(flatten)]
101    pub token_pair: TokenPair,
102}
103
104/// Login handler
105///
106/// NOTE: This is a simplified login handler for demonstration.
107/// In production, you would:
108/// 1. Validate credentials against a database
109/// 2. Use proper password hashing (bcrypt, argon2)
110/// 3. Implement rate limiting
111/// 4. Add audit logging
112/// 5. Handle MFA/2FA if required
113#[instrument(skip(state, request))]
114pub async fn login(
115    State(state): State<AuthHandlerState>,
116    Json(request): Json<LoginRequest>,
117) -> ApiResult<(StatusCode, Json<ApiResponse<LoginResponse>>)> {
118    info!("Login attempt for user: {}", request.username);
119
120    // TODO: In production, validate against database
121    // For now, this is a stub implementation
122    if request.username.is_empty() || request.password.is_empty() {
123        return Err(ApiError::bad_request("Username and password are required"));
124    }
125
126    // Create claims for the user
127    let claims = Claims::new(
128        &request.username,
129        state.jwt_manager().config.issuer.clone(),
130        state.jwt_manager().config.audience.clone(),
131        state.jwt_manager().config.expiration_seconds,
132    )
133    .with_email(format!("{}@example.com", request.username))
134    .with_role("user");
135
136    // Generate token pair
137    let token_pair = state
138        .jwt_manager()
139        .generate_token_pair(&request.username)
140        .map_err(|e| ApiError::internal_server_error(format!("Failed to generate token: {}", e)))?;
141
142    let response = LoginResponse {
143        token_pair,
144        user: UserInfo::from_claims(&claims),
145    };
146
147    info!("User logged in successfully: {}", request.username);
148    Ok((StatusCode::OK, Json(ok(response))))
149}
150
151/// Refresh token handler
152#[instrument(skip(state, request))]
153pub async fn refresh_token(
154    State(state): State<AuthHandlerState>,
155    Json(request): Json<RefreshTokenRequest>,
156) -> ApiResult<Json<ApiResponse<RefreshTokenResponse>>> {
157    debug!("Token refresh requested");
158
159    // Refresh the token
160    let token_pair = state
161        .jwt_manager()
162        .refresh_access_token(&request.refresh_token)
163        .map_err(|e| match e {
164            crate::jwt::TokenError::Expired => {
165                ApiError::unauthorized("Refresh token has expired")
166            }
167            crate::jwt::TokenError::InvalidClaims(_) => {
168                ApiError::bad_request("Invalid refresh token")
169            }
170            _ => ApiError::unauthorized("Invalid refresh token"),
171        })?;
172
173    let response = RefreshTokenResponse { token_pair };
174
175    debug!("Token refreshed successfully");
176    Ok(Json(ok(response)))
177}
178
179/// Get current user information
180#[instrument(skip(user))]
181pub async fn me(
182    Extension(user): Extension<AuthUser>,
183) -> ApiResult<Json<ApiResponse<UserInfo>>> {
184    debug!("Current user info requested");
185
186    let user_info = UserInfo::from_claims(&user.claims);
187
188    Ok(Json(ok(user_info)))
189}
190
191/// Logout handler
192///
193/// NOTE: JWT tokens are stateless, so logout is primarily client-side.
194/// In production, you might:
195/// 1. Maintain a token blacklist in Redis
196/// 2. Use short-lived tokens with refresh tokens
197/// 3. Implement token revocation
198#[instrument(skip(user))]
199pub async fn logout(
200    Extension(user): Extension<AuthUser>,
201) -> ApiResult<Json<ApiResponse<LogoutResponse>>> {
202    info!("User logout: {}", user.user_id());
203
204    // TODO: In production, add token to blacklist
205    // For now, just acknowledge the logout
206
207    let response = LogoutResponse {
208        message: "Logged out successfully".to_string(),
209    };
210
211    Ok(Json(ok(response)))
212}
213
214/// Logout response
215#[derive(Debug, Serialize, Deserialize)]
216pub struct LogoutResponse {
217    /// Success message
218    pub message: String,
219}
220
221/// Generate API key handler (example of protected endpoint)
222#[instrument(skip(user))]
223pub async fn generate_api_key(
224    State(state): State<AuthHandlerState>,
225    Extension(user): Extension<AuthUser>,
226) -> ApiResult<Json<ApiResponse<ApiKeyResponse>>> {
227    info!("Generating API key for user: {}", user.user_id());
228
229    // Check if user has permission
230    if !user.has_role("admin") && !user.has_role("developer") {
231        return Err(ApiError::forbidden(
232            "Only admin or developer roles can generate API keys",
233        ));
234    }
235
236    // Create a long-lived token for API access
237    let claims = Claims::new(
238        user.user_id(),
239        state.jwt_manager().config.issuer.clone(),
240        state.jwt_manager().config.audience.clone(),
241        86400 * 30, // 30 days
242    )
243    .with_roles(user.claims.roles.clone())
244    .with_custom("api_key", serde_json::json!(true));
245
246    let api_key = state
247        .jwt_manager()
248        .generate_token_with_claims(claims)
249        .map_err(|e| ApiError::internal_server_error(format!("Failed to generate API key: {}", e)))?;
250
251    let response = ApiKeyResponse { api_key };
252
253    info!("API key generated for user: {}", user.user_id());
254    Ok(Json(ok(response)))
255}
256
257/// API key response
258#[derive(Debug, Serialize, Deserialize)]
259pub struct ApiKeyResponse {
260    /// Generated API key
261    pub api_key: String,
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jwt::JwtConfig;

    /// Builds handler state backed by a test-only JWT configuration.
    fn create_test_state() -> AuthHandlerState {
        let config = JwtConfig::new("test-secret")
            .with_issuer("test")
            .with_audience("test");
        let jwt_manager = JwtManager::new(config).unwrap();
        AuthHandlerState::new(jwt_manager)
    }

    #[test]
    fn test_user_info_from_claims() {
        let claims = Claims::new("user123", "test", "test", 3600)
            .with_email("user@example.com")
            .with_role("admin");

        let user_info = UserInfo::from_claims(&claims);

        assert_eq!(user_info.id, "user123");
        assert_eq!(user_info.email, Some("user@example.com".to_string()));
        assert_eq!(user_info.roles, vec!["admin"]);
    }

    #[test]
    fn test_user_info_serialization_omits_empty_optional_fields() {
        let user_info = UserInfo {
            id: "user123".to_string(),
            email: None,
            roles: Vec::new(),
        };

        let json = serde_json::to_value(&user_info).unwrap();

        assert_eq!(json, serde_json::json!({ "id": "user123" }));
    }

    #[tokio::test]
    async fn test_login_request_validation() {
        let state = create_test_state();

        let request = LoginRequest {
            username: "".to_string(),
            password: "password".to_string(),
        };

        let result = login(State(state), Json(request)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_login_rejects_empty_password() {
        let state = create_test_state();

        let request = LoginRequest {
            username: "user123".to_string(),
            password: "".to_string(),
        };

        let result = login(State(state), Json(request)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_refresh_token_rejects_malformed_token() {
        let state = create_test_state();

        let request = RefreshTokenRequest {
            refresh_token: "not-a-jwt".to_string(),
        };

        let result = refresh_token(State(state), Json(request)).await;
        assert!(result.is_err());
    }
}