1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::db::Db;
7use crate::error::AuthError;
8use crate::password::hash_password;
9use crate::types::{Email, User, UserId, Username};
10
11pub(crate) fn map_unique_violation(err: sqlx::Error) -> AuthError {
16 if let sqlx::Error::Database(ref db_err) = err {
17 let msg = db_err.message();
18 if msg.contains("UNIQUE constraint failed") {
19 if msg.contains("email") {
20 return AuthError::Conflict("email already exists".into());
21 }
22 if msg.contains("username") {
23 return AuthError::Conflict("username already exists".into());
24 }
25 return AuthError::Conflict(msg.to_string());
26 }
27 }
28 AuthError::Database(err)
29}
30
/// Filters and paging options accepted by `Db::search_users`.
///
/// Every filter is optional; the filters that are present are combined with
/// `AND`.
#[derive(Debug, Clone)]
pub struct SearchUsersParams<'a> {
    /// Free-text filter. `search_users` trims it, ignores it when empty, and
    /// otherwise matches it as a literal substring of the email or username.
    pub query: Option<&'a str>,
    /// When set, only users whose `is_active` flag equals this value match.
    pub is_active: Option<bool>,
    /// When set, keeps only users that do (`true`) or do not (`false`) have
    /// an enabled row in `allowthem_mfa_secrets`.
    pub has_mfa: Option<bool>,
    /// Maximum number of rows to return (bound to SQL `LIMIT`).
    pub limit: u32,
    /// Number of matching rows to skip (bound to SQL `OFFSET`).
    pub offset: u32,
}
39
/// Row shape returned by the listing queries in this file
/// (`list_users_paginated` and `search_users`).
///
/// Unlike `User`, it carries no password hash, verification flag or custom
/// data; it adds the computed `has_mfa` flag instead.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct UserListEntry {
    /// Primary key of the user.
    pub id: UserId,
    /// The user's email address.
    pub email: Email,
    /// Optional username.
    pub username: Option<Username>,
    /// Value of the `is_active` column.
    pub is_active: bool,
    /// Computed by the listing queries: true when an enabled row for this
    /// user exists in `allowthem_mfa_secrets`.
    pub has_mfa: bool,
    /// Value of the `created_at` column.
    pub created_at: DateTime<Utc>,
}
50
51pub struct SearchUsersResult {
53 pub users: Vec<UserListEntry>,
54 pub total: u32,
55}
56
57pub struct UserCursor {
61 pub created_at: DateTime<Utc>,
62 pub id: UserId,
63}
64
/// JSON wire form of a `UserCursor` before base64url encoding.
///
/// Both fields are plain strings, so decoding untrusted input validates in
/// two steps: the JSON shape first, then the individual values.
#[derive(Serialize, Deserialize)]
struct RawUserCursor {
    /// `created_at` rendered as an RFC 3339 timestamp.
    ca: String,
    /// The user id rendered with its `to_string()` form (parsed back as a
    /// UUID by `UserCursor::decode`).
    id: String,
}
70
71impl UserCursor {
72 pub fn from_entry(entry: &UserListEntry) -> Self {
73 Self {
74 created_at: entry.created_at,
75 id: entry.id,
76 }
77 }
78
79 pub fn encode(&self) -> String {
80 let raw = RawUserCursor {
81 ca: self.created_at.to_rfc3339(),
82 id: self.id.to_string(),
83 };
84 let json = serde_json::to_string(&raw).expect("RawUserCursor serializes");
85 Base64UrlUnpadded::encode_string(json.as_bytes())
86 }
87
88 pub fn decode(s: &str) -> Option<Self> {
89 let bytes = Base64UrlUnpadded::decode_vec(s).ok()?;
90 let raw: RawUserCursor = serde_json::from_slice(&bytes).ok()?;
91 let created_at = chrono::DateTime::parse_from_rfc3339(&raw.ca)
92 .ok()?
93 .with_timezone(&Utc);
94 let id = raw.id.parse::<uuid::Uuid>().ok().map(UserId::from_uuid)?;
95 Some(Self { created_at, id })
96 }
97}
98
99impl Db {
100 pub async fn create_user(
105 &self,
106 email: Email,
107 password: &str,
108 username: Option<Username>,
109 custom_data: Option<&Value>,
110 ) -> Result<User, AuthError> {
111 let id = UserId::new();
112 let pw_hash = hash_password(password)?;
113 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
114
115 sqlx::query(
116 "INSERT INTO allowthem_users \
117 (id, email, username, password_hash, email_verified, is_active, created_at, updated_at, custom_data) \
118 VALUES (?1, ?2, ?3, ?4, 0, 1, ?5, ?5, ?6)",
119 )
120 .bind(id)
121 .bind(&email)
122 .bind(&username)
123 .bind(&pw_hash)
124 .bind(&now)
125 .bind(custom_data.map(sqlx::types::Json))
126 .execute(self.pool())
127 .await
128 .map_err(map_unique_violation)?;
129
130 self.get_user(id).await
131 }
132
133 pub async fn create_user_with_hash(
136 &self,
137 email: Email,
138 password_hash: &str,
139 username: Option<Username>,
140 custom_data: Option<&Value>,
141 ) -> Result<User, AuthError> {
142 let id = UserId::new();
143 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
144
145 sqlx::query(
146 "INSERT INTO allowthem_users (id, email, username, password_hash, email_verified, is_active, created_at, updated_at, custom_data)
147 VALUES (?1, ?2, ?3, ?4, 0, 1, ?5, ?5, ?6)",
148 )
149 .bind(id)
150 .bind(&email)
151 .bind(&username)
152 .bind(password_hash)
153 .bind(&now)
154 .bind(custom_data.map(sqlx::types::Json))
155 .execute(self.pool())
156 .await
157 .map_err(map_unique_violation)?;
158
159 self.get_user(id).await
160 }
161
162 pub async fn get_user(&self, id: UserId) -> Result<User, AuthError> {
164 sqlx::query_as::<_, User>(
165 "SELECT id, email, username, NULL as password_hash, \
166 email_verified, is_active, created_at, updated_at, custom_data \
167 FROM allowthem_users WHERE id = ?",
168 )
169 .bind(id)
170 .fetch_optional(self.pool())
171 .await?
172 .ok_or(AuthError::NotFound)
173 }
174
175 pub async fn get_user_by_email(&self, email: &Email) -> Result<User, AuthError> {
177 sqlx::query_as::<_, User>(
178 "SELECT id, email, username, NULL as password_hash, \
179 email_verified, is_active, created_at, updated_at, custom_data \
180 FROM allowthem_users WHERE email = ?",
181 )
182 .bind(email)
183 .fetch_optional(self.pool())
184 .await?
185 .ok_or(AuthError::NotFound)
186 }
187
188 pub async fn get_user_by_username(&self, username: &Username) -> Result<User, AuthError> {
190 sqlx::query_as::<_, User>(
191 "SELECT id, email, username, NULL as password_hash, \
192 email_verified, is_active, created_at, updated_at, custom_data \
193 FROM allowthem_users WHERE username = ?",
194 )
195 .bind(username)
196 .fetch_optional(self.pool())
197 .await?
198 .ok_or(AuthError::NotFound)
199 }
200
201 pub async fn find_for_login(&self, identifier: &str) -> Result<User, AuthError> {
206 sqlx::query_as::<_, User>(
207 "SELECT id, email, username, password_hash, \
208 email_verified, is_active, created_at, updated_at, custom_data \
209 FROM allowthem_users WHERE email = ?1 OR username = ?1",
210 )
211 .bind(identifier)
212 .fetch_optional(self.pool())
213 .await?
214 .ok_or(AuthError::NotFound)
215 }
216
217 pub async fn update_user_email(&self, id: UserId, email: Email) -> Result<(), AuthError> {
219 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
220 let result =
221 sqlx::query("UPDATE allowthem_users SET email = ?1, updated_at = ?2 WHERE id = ?3")
222 .bind(&email)
223 .bind(&now)
224 .bind(id)
225 .execute(self.pool())
226 .await
227 .map_err(map_unique_violation)?;
228
229 if result.rows_affected() == 0 {
230 return Err(AuthError::NotFound);
231 }
232 Ok(())
233 }
234
235 pub async fn update_user_username(
237 &self,
238 id: UserId,
239 username: Option<Username>,
240 ) -> Result<(), AuthError> {
241 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
242 let result =
243 sqlx::query("UPDATE allowthem_users SET username = ?1, updated_at = ?2 WHERE id = ?3")
244 .bind(&username)
245 .bind(&now)
246 .bind(id)
247 .execute(self.pool())
248 .await
249 .map_err(map_unique_violation)?;
250
251 if result.rows_affected() == 0 {
252 return Err(AuthError::NotFound);
253 }
254 Ok(())
255 }
256
257 pub async fn update_user_active(&self, id: UserId, is_active: bool) -> Result<(), AuthError> {
259 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
260 let result =
261 sqlx::query("UPDATE allowthem_users SET is_active = ?1, updated_at = ?2 WHERE id = ?3")
262 .bind(is_active)
263 .bind(&now)
264 .bind(id)
265 .execute(self.pool())
266 .await?;
267
268 if result.rows_affected() == 0 {
269 return Err(AuthError::NotFound);
270 }
271 Ok(())
272 }
273
274 pub async fn delete_user(&self, id: UserId) -> Result<(), AuthError> {
276 let result = sqlx::query("DELETE FROM allowthem_users WHERE id = ?")
277 .bind(id)
278 .execute(self.pool())
279 .await?;
280
281 if result.rows_affected() == 0 {
282 return Err(AuthError::NotFound);
283 }
284 Ok(())
285 }
286
287 pub async fn list_users(&self) -> Result<Vec<User>, AuthError> {
289 sqlx::query_as::<_, User>(
290 "SELECT id, email, username, NULL as password_hash, \
291 email_verified, is_active, created_at, updated_at, custom_data \
292 FROM allowthem_users ORDER BY created_at ASC",
293 )
294 .fetch_all(self.pool())
295 .await
296 .map_err(AuthError::Database)
297 }
298
299 pub async fn list_users_paginated(
304 &self,
305 limit: u32,
306 cursor: Option<&UserCursor>,
307 ) -> Result<Vec<UserListEntry>, AuthError> {
308 let limit = (limit as i64).min(200);
309 match cursor {
310 None => sqlx::query_as::<_, UserListEntry>(
311 "SELECT u.id, u.email, u.username, u.is_active, \
312 EXISTS (SELECT 1 FROM allowthem_mfa_secrets \
313 WHERE user_id = u.id AND enabled = 1) AS has_mfa, \
314 u.created_at \
315 FROM allowthem_users u \
316 ORDER BY u.created_at ASC, u.id ASC \
317 LIMIT ?",
318 )
319 .bind(limit)
320 .fetch_all(self.pool())
321 .await
322 .map_err(AuthError::Database),
323 Some(c) => {
324 let ca = c.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
325 sqlx::query_as::<_, UserListEntry>(
326 "SELECT u.id, u.email, u.username, u.is_active, \
327 EXISTS (SELECT 1 FROM allowthem_mfa_secrets \
328 WHERE user_id = u.id AND enabled = 1) AS has_mfa, \
329 u.created_at \
330 FROM allowthem_users u \
331 WHERE (u.created_at > ?1 OR (u.created_at = ?1 AND u.id > ?2)) \
332 ORDER BY u.created_at ASC, u.id ASC \
333 LIMIT ?3",
334 )
335 .bind(&ca)
336 .bind(c.id)
337 .bind(limit)
338 .fetch_all(self.pool())
339 .await
340 .map_err(AuthError::Database)
341 }
342 }
343 }
344
345 pub async fn search_users(
351 &self,
352 params: SearchUsersParams<'_>,
353 ) -> Result<SearchUsersResult, AuthError> {
354 let mut where_clauses: Vec<String> = Vec::new();
355 let mut bind_values: Vec<String> = Vec::new();
356
357 if let Some(q) = params.query {
358 let trimmed = q.trim();
359 if !trimmed.is_empty() {
360 let escaped = trimmed
361 .replace('\\', "\\\\")
362 .replace('%', "\\%")
363 .replace('_', "\\_");
364 let pattern = format!("%{escaped}%");
365 where_clauses
366 .push("(u.email LIKE ? ESCAPE '\\' OR u.username LIKE ? ESCAPE '\\')".into());
367 bind_values.push(pattern.clone());
368 bind_values.push(pattern);
369 }
370 }
371
372 if let Some(active) = params.is_active {
373 where_clauses.push("u.is_active = ?".into());
374 bind_values.push(if active { "1".into() } else { "0".into() });
375 }
376
377 if let Some(has_mfa) = params.has_mfa {
378 let exists = if has_mfa { "EXISTS" } else { "NOT EXISTS" };
379 where_clauses.push(format!(
380 "{exists} (SELECT 1 FROM allowthem_mfa_secrets WHERE user_id = u.id AND enabled = 1)"
381 ));
382 }
383
384 let where_sql = if where_clauses.is_empty() {
385 String::new()
386 } else {
387 format!("WHERE {}", where_clauses.join(" AND "))
388 };
389
390 let count_sql: &'static str = Box::leak(
391 format!("SELECT COUNT(*) FROM allowthem_users u {where_sql}").into_boxed_str(),
392 );
393 let mut count_query = sqlx::query_scalar::<_, i64>(count_sql);
394 for val in &bind_values {
395 count_query = count_query.bind(val);
396 }
397 let total = count_query
398 .fetch_one(self.pool())
399 .await
400 .map_err(AuthError::Database)? as u32;
401
402 let data_sql: &'static str = Box::leak(
403 format!(
404 "SELECT u.id, u.email, u.username, u.is_active, \
405 EXISTS (SELECT 1 FROM allowthem_mfa_secrets \
406 WHERE user_id = u.id AND enabled = 1) as has_mfa, \
407 u.created_at \
408 FROM allowthem_users u {where_sql} \
409 ORDER BY u.created_at ASC \
410 LIMIT ? OFFSET ?"
411 )
412 .into_boxed_str(),
413 );
414 let mut data_query = sqlx::query_as::<_, UserListEntry>(data_sql);
415 for val in &bind_values {
416 data_query = data_query.bind(val);
417 }
418 data_query = data_query.bind(params.limit).bind(params.offset);
419
420 let users = data_query
421 .fetch_all(self.pool())
422 .await
423 .map_err(AuthError::Database)?;
424
425 Ok(SearchUsersResult { users, total })
426 }
427
428 pub async fn update_user_password(
432 &self,
433 id: UserId,
434 new_password: &str,
435 ) -> Result<(), AuthError> {
436 let pw_hash = hash_password(new_password)?;
437 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
438 let result = sqlx::query(
439 "UPDATE allowthem_users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3",
440 )
441 .bind(&pw_hash)
442 .bind(&now)
443 .bind(id)
444 .execute(self.pool())
445 .await?;
446
447 if result.rows_affected() == 0 {
448 return Err(AuthError::NotFound);
449 }
450 Ok(())
451 }
452
453 pub async fn clear_password_hash(&self, id: UserId) -> Result<(), AuthError> {
459 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
460 let result = sqlx::query(
461 "UPDATE allowthem_users SET password_hash = NULL, updated_at = ? WHERE id = ?",
462 )
463 .bind(&now)
464 .bind(id)
465 .execute(self.pool())
466 .await?;
467
468 if result.rows_affected() == 0 {
469 return Err(AuthError::NotFound);
470 }
471 Ok(())
472 }
473
474 pub async fn get_custom_data(&self, id: &UserId) -> Result<Option<Value>, AuthError> {
479 let row: Option<(Option<Value>,)> =
480 sqlx::query_as("SELECT custom_data FROM allowthem_users WHERE id = ?")
481 .bind(id)
482 .fetch_optional(self.pool())
483 .await?;
484
485 match row {
486 None => Err(AuthError::NotFound),
487 Some((data,)) => Ok(data),
488 }
489 }
490
491 pub async fn set_custom_data(&self, id: &UserId, data: &Value) -> Result<(), AuthError> {
495 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
496 let result = sqlx::query(
497 "UPDATE allowthem_users SET custom_data = ?1, updated_at = ?2 WHERE id = ?3",
498 )
499 .bind(sqlx::types::Json(data))
500 .bind(&now)
501 .bind(id)
502 .execute(self.pool())
503 .await?;
504
505 if result.rows_affected() == 0 {
506 return Err(AuthError::NotFound);
507 }
508 Ok(())
509 }
510
511 pub async fn delete_custom_data(&self, id: &UserId) -> Result<(), AuthError> {
515 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
516 sqlx::query("UPDATE allowthem_users SET custom_data = NULL, updated_at = ?1 WHERE id = ?2")
517 .bind(&now)
518 .bind(id)
519 .execute(self.pool())
520 .await?;
521
522 Ok(())
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::{AllowThem, AllowThemBuilder};

    /// Builds a handle backed by a fresh in-memory SQLite database.
    async fn setup() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap()
    }

    /// Inserts a user whose email is derived from `tag`, so callers can
    /// create several distinct users.
    async fn make_user(db: &Db, tag: u32) -> crate::types::User {
        let email = Email::new(format!("user{tag}@example.com")).unwrap();
        db.create_user(email, "pw123456", None, None).await.unwrap()
    }

    #[tokio::test]
    async fn user_cursor_encode_decode_roundtrip() {
        let ath = setup().await;
        let db = ath.db();
        let user = make_user(db, 1).await;
        let entries = db.list_users_paginated(10, None).await.unwrap();
        assert_eq!(entries.len(), 1);
        let cursor = UserCursor::from_entry(&entries[0]);
        let encoded = cursor.encode();
        let decoded = UserCursor::decode(&encoded).unwrap();
        // Both halves of the cursor must survive the round trip.
        assert_eq!(decoded.id, user.id);
        assert_eq!(decoded.created_at, cursor.created_at);
    }

    #[test]
    fn user_cursor_decode_rejects_malformed_input() {
        // Not valid base64url at all.
        assert!(UserCursor::decode("not a cursor!").is_none());

        // Valid base64url, but the payload is not JSON.
        let not_json = Base64UrlUnpadded::encode_string(b"not json");
        assert!(UserCursor::decode(&not_json).is_none());

        // Valid JSON with the right shape, but unparseable field values.
        let bad_fields =
            Base64UrlUnpadded::encode_string(br#"{"ca":"yesterday","id":"nope"}"#);
        assert!(UserCursor::decode(&bad_fields).is_none());
    }

    #[tokio::test]
    async fn list_users_paginated_returns_first_page() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            make_user(db, i).await;
        }
        let page = db.list_users_paginated(3, None).await.unwrap();
        assert_eq!(page.len(), 3);
    }

    #[tokio::test]
    async fn list_users_paginated_cursor_advances() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            make_user(db, i + 10).await;
        }
        let page1 = db.list_users_paginated(3, None).await.unwrap();
        assert_eq!(page1.len(), 3);
        let cursor = UserCursor::from_entry(page1.last().unwrap());
        let page2 = db.list_users_paginated(3, Some(&cursor)).await.unwrap();
        assert_eq!(page2.len(), 2);
        // The second page must not repeat anything from the first.
        assert!(!page2.iter().any(|u| page1.iter().any(|v| v.id == u.id)));
    }
}