1use axum::{
7 extract::{Extension, State},
8 http::StatusCode,
9 Json,
10};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tracing::{debug, info, instrument};
14
15use crate::{
16 auth::AuthUser,
17 error::{ApiError, ApiResult},
18 jwt::{Claims, JwtManager, TokenPair},
19 responses::{ok, ApiResponse},
20};
21
/// Shared state for the authentication handlers.
///
/// The [`JwtManager`] is stored behind an [`Arc`], so the derived `Clone`
/// copies the pointer rather than the manager itself.
#[derive(Clone, Debug)]
pub struct AuthHandlerState {
    /// JWT manager shared by every clone of this state.
    jwt_manager: Arc<JwtManager>,
}
27
28impl AuthHandlerState {
29 pub fn new(jwt_manager: JwtManager) -> Self {
31 Self {
32 jwt_manager: Arc::new(jwt_manager),
33 }
34 }
35
36 pub fn jwt_manager(&self) -> &JwtManager {
38 &self.jwt_manager
39 }
40}
41
42#[derive(Debug, Deserialize, Serialize)]
44pub struct LoginRequest {
45 pub username: String,
47
48 pub password: String,
50}
51
/// Body returned by the `login` handler on success.
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
    /// Issued tokens. `#[serde(flatten)]` places the token pair's fields at the
    /// top level of this object instead of under a `token_pair` key.
    #[serde(flatten)]
    pub token_pair: TokenPair,

    /// Details of the user the tokens were issued for.
    pub user: UserInfo,
}
62
/// User details exposed by the auth endpoints.
#[derive(Debug, Serialize, Deserialize)]
pub struct UserInfo {
    /// User identifier (populated from the `sub` claim by `from_claims`).
    pub id: String,

    /// Email address; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,

    /// Role names; left out of the serialized output when empty, and treated
    /// as empty when the field is absent during deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub roles: Vec<String>,
}
77
78impl UserInfo {
79 pub fn from_claims(claims: &Claims) -> Self {
81 Self {
82 id: claims.sub.clone(),
83 email: claims.email.clone(),
84 roles: claims.roles.clone(),
85 }
86 }
87}
88
89#[derive(Debug, Deserialize, Serialize)]
91pub struct RefreshTokenRequest {
92 pub refresh_token: String,
94}
95
/// Body returned by the `refresh_token` handler on success.
#[derive(Debug, Serialize, Deserialize)]
pub struct RefreshTokenResponse {
    /// Newly issued tokens. `#[serde(flatten)]` places the token pair's fields
    /// at the top level of this object instead of under a `token_pair` key.
    #[serde(flatten)]
    pub token_pair: TokenPair,
}
103
104#[instrument(skip(state, request))]
114pub async fn login(
115 State(state): State<AuthHandlerState>,
116 Json(request): Json<LoginRequest>,
117) -> ApiResult<(StatusCode, Json<ApiResponse<LoginResponse>>)> {
118 info!("Login attempt for user: {}", request.username);
119
120 if request.username.is_empty() || request.password.is_empty() {
123 return Err(ApiError::bad_request("Username and password are required"));
124 }
125
126 let claims = Claims::new(
128 &request.username,
129 state.jwt_manager().config.issuer.clone(),
130 state.jwt_manager().config.audience.clone(),
131 state.jwt_manager().config.expiration_seconds,
132 )
133 .with_email(format!("{}@example.com", request.username))
134 .with_role("user");
135
136 let token_pair = state
138 .jwt_manager()
139 .generate_token_pair(&request.username)
140 .map_err(|e| ApiError::internal_server_error(format!("Failed to generate token: {}", e)))?;
141
142 let response = LoginResponse {
143 token_pair,
144 user: UserInfo::from_claims(&claims),
145 };
146
147 info!("User logged in successfully: {}", request.username);
148 Ok((StatusCode::OK, Json(ok(response))))
149}
150
151#[instrument(skip(state, request))]
153pub async fn refresh_token(
154 State(state): State<AuthHandlerState>,
155 Json(request): Json<RefreshTokenRequest>,
156) -> ApiResult<Json<ApiResponse<RefreshTokenResponse>>> {
157 debug!("Token refresh requested");
158
159 let token_pair = state
161 .jwt_manager()
162 .refresh_access_token(&request.refresh_token)
163 .map_err(|e| match e {
164 crate::jwt::TokenError::Expired => {
165 ApiError::unauthorized("Refresh token has expired")
166 }
167 crate::jwt::TokenError::InvalidClaims(_) => {
168 ApiError::bad_request("Invalid refresh token")
169 }
170 _ => ApiError::unauthorized("Invalid refresh token"),
171 })?;
172
173 let response = RefreshTokenResponse { token_pair };
174
175 debug!("Token refreshed successfully");
176 Ok(Json(ok(response)))
177}
178
179#[instrument(skip(user))]
181pub async fn me(
182 Extension(user): Extension<AuthUser>,
183) -> ApiResult<Json<ApiResponse<UserInfo>>> {
184 debug!("Current user info requested");
185
186 let user_info = UserInfo::from_claims(&user.claims);
187
188 Ok(Json(ok(user_info)))
189}
190
191#[instrument(skip(user))]
199pub async fn logout(
200 Extension(user): Extension<AuthUser>,
201) -> ApiResult<Json<ApiResponse<LogoutResponse>>> {
202 info!("User logout: {}", user.user_id());
203
204 let response = LogoutResponse {
208 message: "Logged out successfully".to_string(),
209 };
210
211 Ok(Json(ok(response)))
212}
213
/// Body returned by the `logout` handler.
#[derive(Debug, Serialize, Deserialize)]
pub struct LogoutResponse {
    /// Human-readable confirmation text.
    pub message: String,
}
220
221#[instrument(skip(user))]
223pub async fn generate_api_key(
224 State(state): State<AuthHandlerState>,
225 Extension(user): Extension<AuthUser>,
226) -> ApiResult<Json<ApiResponse<ApiKeyResponse>>> {
227 info!("Generating API key for user: {}", user.user_id());
228
229 if !user.has_role("admin") && !user.has_role("developer") {
231 return Err(ApiError::forbidden(
232 "Only admin or developer roles can generate API keys",
233 ));
234 }
235
236 let claims = Claims::new(
238 user.user_id(),
239 state.jwt_manager().config.issuer.clone(),
240 state.jwt_manager().config.audience.clone(),
241 86400 * 30, )
243 .with_roles(user.claims.roles.clone())
244 .with_custom("api_key", serde_json::json!(true));
245
246 let api_key = state
247 .jwt_manager()
248 .generate_token_with_claims(claims)
249 .map_err(|e| ApiError::internal_server_error(format!("Failed to generate API key: {}", e)))?;
250
251 let response = ApiKeyResponse { api_key };
252
253 info!("API key generated for user: {}", user.user_id());
254 Ok(Json(ok(response)))
255}
256
/// Body returned by the `generate_api_key` handler.
///
/// NOTE(review): the derived `Debug` prints the key in full; avoid logging
/// this type with `{:?}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiKeyResponse {
    /// The generated API key token.
    pub api_key: String,
}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jwt::JwtConfig;

    /// Builds handler state backed by a JWT manager with a fixed test
    /// secret, issuer and audience.
    fn create_test_state() -> AuthHandlerState {
        let config = JwtConfig::new("test-secret")
            .with_issuer("test")
            .with_audience("test");
        let jwt_manager = JwtManager::new(config).unwrap();
        AuthHandlerState::new(jwt_manager)
    }

    /// `from_claims` copies subject, email and roles into the `UserInfo`.
    #[test]
    fn test_user_info_from_claims() {
        let claims = Claims::new("user123", "test", "test", 3600)
            .with_email("user@example.com")
            .with_role("admin");

        let user_info = UserInfo::from_claims(&claims);

        assert_eq!(user_info.id, "user123");
        assert_eq!(user_info.email, Some("user@example.com".to_string()));
        assert_eq!(user_info.roles, vec!["admin"]);
    }

    /// An empty username is rejected before any token is generated.
    #[tokio::test]
    async fn test_login_request_validation() {
        let state = create_test_state();

        let request = LoginRequest {
            username: "".to_string(),
            password: "password".to_string(),
        };

        let result = login(State(state), Json(request)).await;
        assert!(result.is_err());
    }

    /// An empty password is rejected as well; this covers the second half of
    /// the `||` guard in `login`.
    #[tokio::test]
    async fn test_login_rejects_empty_password() {
        let state = create_test_state();

        let request = LoginRequest {
            username: "user".to_string(),
            password: String::new(),
        };

        let result = login(State(state), Json(request)).await;
        assert!(result.is_err());
    }
}