1pub mod cred;
2pub mod token;
3
4use std::{io::Write, net::SocketAddr};
5
6use argon2::{
7 password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
8 Argon2, PasswordHash, PasswordVerifier,
9};
10use axum::{body::Body, extract::State, http::Request, middleware::Next, response::IntoResponse};
11use axum_extra::{
12 headers::{authorization::Bearer, Authorization},
13 TypedHeader,
14};
15use rand::{rngs::StdRng, RngCore, SeedableRng};
16use sea_orm::{entity::prelude::*, Set};
17
18use crate::{
19 config::InfraPool,
20 entity::{state::UserState, users as User, workers as Worker},
21 error::{ApiError, AuthError},
22 schema::{UserChangePasswordReq, UserLoginReq},
23};
24use token::{generate_token, verify_token};
25
/// Identity of a user authenticated by `user_auth_middleware`.
///
/// The middleware inserts this into the request extensions after the bearer
/// token has been validated by `user_auth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AuthUser {
    /// Primary key of the authenticated user's row.
    pub id: i64,
}
30
/// Identity of an authenticated user together with the stored username.
///
/// Inserted into the request extensions by `user_auth_with_name_middleware`
/// for handlers that need the canonical username (e.g. password changes).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AuthUserWithName {
    /// Primary key of the authenticated user's row.
    pub id: i64,
    /// Username as stored in the database row.
    pub username: String,
}
36
/// Identity of an authenticated user whose `admin` flag is set.
///
/// Produced by `admin_auth` and inserted into the request extensions by
/// `admin_auth_middleware`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AuthAdminUser {
    /// Primary key of the admin user's row.
    pub id: i64,
}
41
/// Identity of an authenticated worker.
///
/// Produced by `worker_auth` and inserted into the request extensions by
/// `worker_auth_middleware`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AuthWorker {
    /// Database primary key of the worker row (not its `worker_id` UUID).
    pub id: i64,
}
46
47pub(crate) fn get_and_prompt_username(
48 username: Option<String>,
49 prompt: &str,
50) -> crate::error::Result<String> {
51 let username = username
52 .map(|u| {
53 println!("{prompt}: {u}");
54 Ok::<_, std::io::Error>(u.clone())
55 })
56 .unwrap_or_else(|| {
57 let mut user = String::new();
58 print!("{prompt}: ");
59 std::io::stdout().flush()?;
60 std::io::stdin().read_line(&mut user)?;
61 user.pop();
62 Ok(user)
63 })?;
64 Ok(username)
65}
66
67pub(crate) fn get_and_prompt_password(
68 password: Option<String>,
69 prompt: &str,
70) -> crate::error::Result<[u8; 16]> {
71 let md5_password = password
72 .map(|p| {
73 println!("{prompt} Already Given");
74 Ok::<_, std::io::Error>(md5::compute(p.as_bytes()).0)
75 })
76 .unwrap_or_else(|| {
77 let password = rpassword::prompt_password(format!("Please Input {prompt}: "))?;
78 Ok(md5::compute(password.as_bytes()).0)
79 })?;
80 Ok(md5_password)
81}
82
83pub(crate) fn fill_user_login(
84 username: Option<String>,
85 password: Option<String>,
86 retain: bool,
87) -> crate::error::Result<UserLoginReq> {
88 match (username, password) {
89 (Some(username), Some(password)) => Ok(UserLoginReq {
90 username,
91 md5_password: md5::compute(password.as_bytes()).0,
92 retain,
93 }),
94 (username, password) => {
95 let username = get_and_prompt_username(username, "Username")?;
96 let md5_password = get_and_prompt_password(password, "Password")?;
97 Ok(UserLoginReq {
98 username,
99 md5_password,
100 retain,
101 })
102 }
103 }
104}
105
106pub async fn user_login(
107 db: &DatabaseConnection,
108 username: &str,
109 md5_password: &[u8; 16],
110 retain: bool,
111 ip: SocketAddr,
112) -> crate::error::Result<String> {
113 match User::Entity::find()
114 .filter(User::Column::Username.eq(username))
115 .one(db)
116 .await?
117 {
118 Some(user) => {
119 if user.state != UserState::Active {
120 return Err(AuthError::PermissionDenied.into());
121 }
122 let parsed_hash = PasswordHash::new(&user.encrypted_password)?;
123 if Argon2::default()
124 .verify_password(md5_password, &parsed_hash)
125 .is_ok()
126 {
127 let sign = if retain {
128 user.auth_signature
129 .unwrap_or_else(|| StdRng::from_os_rng().next_u32() as i64)
130 } else {
131 (1 + StdRng::from_os_rng().next_u32()) as i64
132 };
133 let token = generate_token(username, sign)?;
134 let now = TimeDateTimeWithTimeZone::now_utc();
135 let active_user = User::ActiveModel {
136 id: Set(user.id),
137 auth_signature: Set(Some(sign)),
138 current_sign_in_at: Set(Some(now)),
139 last_sign_in_at: Set(user.current_sign_in_at),
140 current_sign_in_ip: Set(Some(ip.ip().to_string())),
141 last_sign_in_ip: Set(user.current_sign_in_ip),
142 updated_at: Set(now),
143 ..Default::default()
144 };
145 active_user.update(db).await?;
146 tracing::debug!("User {} logged in", username);
147 Ok(token)
148 } else {
149 tracing::debug!("Wrong password for user {}", username);
150 Err(AuthError::WrongCredentials.into())
151 }
152 }
153 None => {
154 tracing::debug!("User {} not found", username);
155 Err(AuthError::WrongCredentials.into())
156 }
157 }
158}
159
160pub async fn user_change_password(
161 db: &DatabaseConnection,
162 user_id: i64,
163 ip: SocketAddr,
164 username: String,
165 UserChangePasswordReq {
166 old_md5_password,
167 new_md5_password,
168 }: UserChangePasswordReq,
169) -> crate::error::Result<String> {
170 let user = User::Entity::find_by_id(user_id)
171 .one(db)
172 .await?
173 .ok_or(ApiError::NotFound("User not found".to_string()))?;
174 if user.username != username {
175 return Err(AuthError::WrongCredentials.into());
176 }
177 if user.state != UserState::Active {
178 return Err(AuthError::PermissionDenied.into());
179 }
180 let parsed_hash = PasswordHash::new(&user.encrypted_password)?;
181 if Argon2::default()
182 .verify_password(&old_md5_password, &parsed_hash)
183 .is_ok()
184 {
185 let salt = SaltString::generate(&mut OsRng);
186 let argon2 = Argon2::default();
187 let password_hash = argon2.hash_password(&new_md5_password, &salt)?.to_string();
188 let sign = StdRng::from_os_rng().next_u32() as i64;
189 let token = generate_token(&username, sign)?;
190 let now = TimeDateTimeWithTimeZone::now_utc();
191 let active_user = User::ActiveModel {
192 id: Set(user.id),
193 encrypted_password: Set(password_hash),
194 auth_signature: Set(Some(sign)),
195 current_sign_in_at: Set(Some(now)),
196 last_sign_in_at: Set(user.current_sign_in_at),
197 current_sign_in_ip: Set(Some(ip.ip().to_string())),
198 last_sign_in_ip: Set(user.current_sign_in_ip),
199 updated_at: Set(now),
200 ..Default::default()
201 };
202 tracing::debug!("User {} change password and logged in", username);
203 active_user.update(db).await?;
204 Ok(token)
205 } else {
206 tracing::debug!("Wrong password for user {}", username);
207 Err(AuthError::WrongCredentials.into())
208 }
209}
210
211pub async fn admin_change_password(
212 db: &DatabaseConnection,
213 username: String,
214 new_md5_password: [u8; 16],
215) -> crate::error::Result<()> {
216 let user = User::Entity::find()
217 .filter(User::Column::Username.eq(&username))
218 .one(db)
219 .await?
220 .ok_or(ApiError::NotFound("User not found".to_string()))?;
221 let salt = SaltString::generate(&mut OsRng);
222 let argon2 = Argon2::default();
223 let password_hash = argon2.hash_password(&new_md5_password, &salt)?.to_string();
224 let now = TimeDateTimeWithTimeZone::now_utc();
225 let sign = StdRng::from_os_rng().next_u32() as i64;
226 let active_user = User::ActiveModel {
227 id: Set(user.id),
228 encrypted_password: Set(password_hash),
229 auth_signature: Set(Some(sign)),
230 updated_at: Set(now),
231 ..Default::default()
232 };
233 tracing::debug!("User {} change password", username);
234 active_user.update(db).await?;
235 Ok(())
236}
237
238pub async fn user_auth_middleware(
239 State(pool): State<InfraPool>,
240 TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
241 mut req: Request<Body>,
242 next: Next,
243) -> Result<impl IntoResponse, ApiError> {
244 let auth_user = user_auth(&pool.db, &bearer).await?;
245 req.extensions_mut().insert(AuthUser { id: auth_user.id });
246 Ok(next.run(req).await)
247}
248
249pub async fn user_auth_with_name_middleware(
250 State(pool): State<InfraPool>,
251 TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
252 mut req: Request<Body>,
253 next: Next,
254) -> Result<impl IntoResponse, ApiError> {
255 let auth_user = user_auth(&pool.db, &bearer).await?;
256 req.extensions_mut().insert(AuthUserWithName {
257 id: auth_user.id,
258 username: auth_user.username,
259 });
260 Ok(next.run(req).await)
261}
262
263async fn user_auth(db: &DatabaseConnection, bearer: &Bearer) -> Result<User::Model, AuthError> {
264 let token = bearer.token();
265 let claims = verify_token(token).map_err(|_| AuthError::InvalidToken)?;
266 let now = TimeDateTimeWithTimeZone::now_utc();
267 if claims.exp < now {
268 return Err(AuthError::WrongCredentials);
269 }
270
271 let user = User::Entity::find()
272 .filter(User::Column::Username.eq(claims.sub))
273 .one(db)
274 .await
275 .map_err(|_| AuthError::WrongCredentials)?
276 .ok_or(AuthError::WrongCredentials)?;
277
278 if user.state != UserState::Active {
279 Err(AuthError::PermissionDenied)
280 } else if user.auth_signature != Some(claims.sign) {
281 Err(AuthError::WrongCredentials)
282 } else {
283 Ok(user)
284 }
285}
286
287pub async fn admin_auth_middleware(
288 State(pool): State<InfraPool>,
289 TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
290 mut req: Request<Body>,
291 next: Next,
292) -> Result<impl IntoResponse, ApiError> {
293 let admin_user = admin_auth(&pool.db, &bearer).await?;
294 req.extensions_mut().insert(admin_user);
295 Ok(next.run(req).await)
296}
297
298async fn admin_auth(db: &DatabaseConnection, bearer: &Bearer) -> Result<AuthAdminUser, AuthError> {
299 let token = bearer.token();
300 let claims = verify_token(token).map_err(|_| AuthError::InvalidToken)?;
301 let now = TimeDateTimeWithTimeZone::now_utc();
302 if claims.exp < now {
303 return Err(AuthError::WrongCredentials);
304 }
305
306 let user = User::Entity::find()
307 .filter(User::Column::Username.eq(claims.sub))
308 .one(db)
309 .await
310 .map_err(|_| AuthError::WrongCredentials)?
311 .ok_or(AuthError::WrongCredentials)?;
312 if user.admin {
313 if user.state != UserState::Active {
314 Err(AuthError::PermissionDenied)
315 } else if user.auth_signature != Some(claims.sign) {
316 Err(AuthError::WrongCredentials)
317 } else {
318 Ok(AuthAdminUser { id: user.id })
319 }
320 } else {
321 Err(AuthError::PermissionDenied)
322 }
323}
324
325pub async fn worker_auth_middleware(
326 State(pool): State<InfraPool>,
327 TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
328 mut req: Request<Body>,
329 next: Next,
330) -> Result<impl IntoResponse, ApiError> {
331 let auth_worker = worker_auth(&pool.db, &bearer).await?;
332 req.extensions_mut().insert(auth_worker);
333 Ok(next.run(req).await)
334}
335
336async fn worker_auth(db: &DatabaseConnection, bearer: &Bearer) -> Result<AuthWorker, AuthError> {
337 let token = bearer.token();
338 let claims = verify_token(token).map_err(|_| AuthError::InvalidToken)?;
339 let uuid = Uuid::parse_str(&claims.sub).map_err(|_| AuthError::InvalidToken)?;
340
341 let worker = Worker::Entity::find()
342 .filter(Worker::Column::WorkerId.eq(uuid))
343 .one(db)
344 .await
345 .map_err(|_| AuthError::WrongCredentials)?
346 .ok_or(AuthError::WrongCredentials)?;
347 Ok(AuthWorker { id: worker.id })
348}