Skip to main content

allowthem_core/
users.rs

1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::db::Db;
7use crate::error::AuthError;
8use crate::password::hash_password;
9use crate::types::{Email, User, UserId, Username};
10
11/// Map a SQLite UNIQUE constraint violation to `AuthError::Conflict`.
12///
13/// SQLite UNIQUE violations include the constraint name in the message,
14/// e.g. "UNIQUE constraint failed: allowthem_users.email".
15pub(crate) fn map_unique_violation(err: sqlx::Error) -> AuthError {
16    if let sqlx::Error::Database(ref db_err) = err {
17        let msg = db_err.message();
18        if msg.contains("UNIQUE constraint failed") {
19            if msg.contains("email") {
20                return AuthError::Conflict("email already exists".into());
21            }
22            if msg.contains("username") {
23                return AuthError::Conflict("username already exists".into());
24            }
25            return AuthError::Conflict(msg.to_string());
26        }
27    }
28    AuthError::Database(err)
29}
30
/// Parameters for searching/filtering users in the admin directory.
///
/// Each `Option` filter is ignored when `None`; a `query` that is empty after
/// trimming whitespace is ignored too.
#[derive(Debug, Clone, Copy)]
pub struct SearchUsersParams<'a> {
    /// Free-text term, substring-matched with SQL `LIKE` against email and
    /// username (`%`, `_` and `\` in the term are matched literally).
    pub query: Option<&'a str>,
    /// `Some(true)` keeps only active users, `Some(false)` only inactive ones.
    pub is_active: Option<bool>,
    /// `Some(true)` keeps only users with an enabled MFA secret,
    /// `Some(false)` only users without one.
    pub has_mfa: Option<bool>,
    /// Maximum number of users to return in this page.
    pub limit: u32,
    /// Number of matching users to skip before the page starts.
    pub offset: u32,
}
39
/// User with MFA enrollment status, for list display.
///
/// Produced by `Db::list_users_paginated` and `Db::search_users`. In both,
/// `has_mfa` is computed as "a row with `enabled = 1` exists in
/// `allowthem_mfa_secrets` for this user". The row is built via
/// `sqlx::FromRow`, so those SELECT lists must keep column names/aliases that
/// match these field names.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct UserListEntry {
    /// Primary key of the user.
    pub id: UserId,
    /// The user's email address.
    pub email: Email,
    /// Optional username; `None` when the user has none set.
    pub username: Option<Username>,
    /// Whether the account is active.
    pub is_active: bool,
    /// Whether the user has at least one enabled MFA secret.
    pub has_mfa: bool,
    /// Creation time; also the primary sort key for listings.
    pub created_at: DateTime<Utc>,
}
50
51/// Result of a paginated user search.
52pub struct SearchUsersResult {
53    pub users: Vec<UserListEntry>,
54    pub total: u32,
55}
56
57/// Opaque keyset cursor for paginating `list_users_paginated`.
58///
59/// Encodes `(created_at, id)` as a base64url-encoded JSON blob.
60pub struct UserCursor {
61    pub created_at: DateTime<Utc>,
62    pub id: UserId,
63}
64
/// JSON wire form of a `UserCursor`, before base64url encoding.
///
/// Both values are carried as strings. `ca` is `created_at` rendered with
/// `to_rfc3339()`, and `id` is the user ID's `to_string()` form, which
/// `UserCursor::decode` parses back as a UUID. The short key names presumably
/// keep the encoded token compact.
#[derive(Serialize, Deserialize)]
struct RawUserCursor {
    ca: String,
    id: String,
}
70
71impl UserCursor {
72    pub fn from_entry(entry: &UserListEntry) -> Self {
73        Self {
74            created_at: entry.created_at,
75            id: entry.id,
76        }
77    }
78
79    pub fn encode(&self) -> String {
80        let raw = RawUserCursor {
81            ca: self.created_at.to_rfc3339(),
82            id: self.id.to_string(),
83        };
84        let json = serde_json::to_string(&raw).expect("RawUserCursor serializes");
85        Base64UrlUnpadded::encode_string(json.as_bytes())
86    }
87
88    pub fn decode(s: &str) -> Option<Self> {
89        let bytes = Base64UrlUnpadded::decode_vec(s).ok()?;
90        let raw: RawUserCursor = serde_json::from_slice(&bytes).ok()?;
91        let created_at = chrono::DateTime::parse_from_rfc3339(&raw.ca)
92            .ok()?
93            .with_timezone(&Utc);
94        let id = raw.id.parse::<uuid::Uuid>().ok().map(UserId::from_uuid)?;
95        Some(Self { created_at, id })
96    }
97}
98
99impl Db {
100    /// Create a user with email, plaintext password, optional username, and optional custom data.
101    ///
102    /// Hashes the password with Argon2id (via `password::hash_password`).
103    /// Returns the created User (without password_hash in the returned struct).
104    pub async fn create_user(
105        &self,
106        email: Email,
107        password: &str,
108        username: Option<Username>,
109        custom_data: Option<&Value>,
110    ) -> Result<User, AuthError> {
111        let id = UserId::new();
112        let pw_hash = hash_password(password)?;
113        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
114
115        sqlx::query(
116            "INSERT INTO allowthem_users \
117             (id, email, username, password_hash, email_verified, is_active, created_at, updated_at, custom_data) \
118             VALUES (?1, ?2, ?3, ?4, 0, 1, ?5, ?5, ?6)",
119        )
120        .bind(id)
121        .bind(&email)
122        .bind(&username)
123        .bind(&pw_hash)
124        .bind(&now)
125        .bind(custom_data.map(sqlx::types::Json))
126        .execute(self.pool())
127        .await
128        .map_err(map_unique_violation)?;
129
130        self.get_user(id).await
131    }
132
133    /// Import a user with a pre-existing password hash (for migration from external systems).
134    /// The hash must be a valid Argon2 PHC string. No validation is performed on it.
135    pub async fn create_user_with_hash(
136        &self,
137        email: Email,
138        password_hash: &str,
139        username: Option<Username>,
140        custom_data: Option<&Value>,
141    ) -> Result<User, AuthError> {
142        let id = UserId::new();
143        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
144
145        sqlx::query(
146            "INSERT INTO allowthem_users (id, email, username, password_hash, email_verified, is_active, created_at, updated_at, custom_data)
147             VALUES (?1, ?2, ?3, ?4, 0, 1, ?5, ?5, ?6)",
148        )
149        .bind(id)
150        .bind(&email)
151        .bind(&username)
152        .bind(password_hash)
153        .bind(&now)
154        .bind(custom_data.map(sqlx::types::Json))
155        .execute(self.pool())
156        .await
157        .map_err(map_unique_violation)?;
158
159        self.get_user(id).await
160    }
161
162    /// Look up a user by ID. Returns User with password_hash = None.
163    pub async fn get_user(&self, id: UserId) -> Result<User, AuthError> {
164        sqlx::query_as::<_, User>(
165            "SELECT id, email, username, NULL as password_hash, \
166             email_verified, is_active, created_at, updated_at, custom_data \
167             FROM allowthem_users WHERE id = ?",
168        )
169        .bind(id)
170        .fetch_optional(self.pool())
171        .await?
172        .ok_or(AuthError::NotFound)
173    }
174
175    /// Look up a user by email. Returns User with password_hash = None.
176    pub async fn get_user_by_email(&self, email: &Email) -> Result<User, AuthError> {
177        sqlx::query_as::<_, User>(
178            "SELECT id, email, username, NULL as password_hash, \
179             email_verified, is_active, created_at, updated_at, custom_data \
180             FROM allowthem_users WHERE email = ?",
181        )
182        .bind(email)
183        .fetch_optional(self.pool())
184        .await?
185        .ok_or(AuthError::NotFound)
186    }
187
188    /// Look up a user by username. Returns User with password_hash = None.
189    pub async fn get_user_by_username(&self, username: &Username) -> Result<User, AuthError> {
190        sqlx::query_as::<_, User>(
191            "SELECT id, email, username, NULL as password_hash, \
192             email_verified, is_active, created_at, updated_at, custom_data \
193             FROM allowthem_users WHERE username = ?",
194        )
195        .bind(username)
196        .fetch_optional(self.pool())
197        .await?
198        .ok_or(AuthError::NotFound)
199    }
200
201    /// Look up a user by email OR username for login.
202    ///
203    /// Returns User WITH password_hash populated. The caller is responsible
204    /// for calling `verify_password()` to check the password.
205    pub async fn find_for_login(&self, identifier: &str) -> Result<User, AuthError> {
206        sqlx::query_as::<_, User>(
207            "SELECT id, email, username, password_hash, \
208             email_verified, is_active, created_at, updated_at, custom_data \
209             FROM allowthem_users WHERE email = ?1 OR username = ?1",
210        )
211        .bind(identifier)
212        .fetch_optional(self.pool())
213        .await?
214        .ok_or(AuthError::NotFound)
215    }
216
217    /// Update a user's email. Also updates updated_at.
218    pub async fn update_user_email(&self, id: UserId, email: Email) -> Result<(), AuthError> {
219        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
220        let result =
221            sqlx::query("UPDATE allowthem_users SET email = ?1, updated_at = ?2 WHERE id = ?3")
222                .bind(&email)
223                .bind(&now)
224                .bind(id)
225                .execute(self.pool())
226                .await
227                .map_err(map_unique_violation)?;
228
229        if result.rows_affected() == 0 {
230            return Err(AuthError::NotFound);
231        }
232        Ok(())
233    }
234
235    /// Update a user's username (set or clear). Also updates updated_at.
236    pub async fn update_user_username(
237        &self,
238        id: UserId,
239        username: Option<Username>,
240    ) -> Result<(), AuthError> {
241        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
242        let result =
243            sqlx::query("UPDATE allowthem_users SET username = ?1, updated_at = ?2 WHERE id = ?3")
244                .bind(&username)
245                .bind(&now)
246                .bind(id)
247                .execute(self.pool())
248                .await
249                .map_err(map_unique_violation)?;
250
251        if result.rows_affected() == 0 {
252            return Err(AuthError::NotFound);
253        }
254        Ok(())
255    }
256
257    /// Update a user's is_active flag. Also updates updated_at.
258    pub async fn update_user_active(&self, id: UserId, is_active: bool) -> Result<(), AuthError> {
259        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
260        let result =
261            sqlx::query("UPDATE allowthem_users SET is_active = ?1, updated_at = ?2 WHERE id = ?3")
262                .bind(is_active)
263                .bind(&now)
264                .bind(id)
265                .execute(self.pool())
266                .await?;
267
268        if result.rows_affected() == 0 {
269            return Err(AuthError::NotFound);
270        }
271        Ok(())
272    }
273
274    /// Delete a user by ID. Cascades to sessions, user_roles, user_permissions.
275    pub async fn delete_user(&self, id: UserId) -> Result<(), AuthError> {
276        let result = sqlx::query("DELETE FROM allowthem_users WHERE id = ?")
277            .bind(id)
278            .execute(self.pool())
279            .await?;
280
281        if result.rows_affected() == 0 {
282            return Err(AuthError::NotFound);
283        }
284        Ok(())
285    }
286
287    /// List all users ordered by `created_at ASC`. Returns User with `password_hash = None`.
288    pub async fn list_users(&self) -> Result<Vec<User>, AuthError> {
289        sqlx::query_as::<_, User>(
290            "SELECT id, email, username, NULL as password_hash, \
291             email_verified, is_active, created_at, updated_at, custom_data \
292             FROM allowthem_users ORDER BY created_at ASC",
293        )
294        .fetch_all(self.pool())
295        .await
296        .map_err(AuthError::Database)
297    }
298
299    /// Paginated list of users using a `(created_at, id)` keyset cursor.
300    ///
301    /// Limits are capped at 200. Pass `None` for cursor to start from the beginning.
302    /// Results are ordered oldest-first.
303    pub async fn list_users_paginated(
304        &self,
305        limit: u32,
306        cursor: Option<&UserCursor>,
307    ) -> Result<Vec<UserListEntry>, AuthError> {
308        let limit = (limit as i64).min(200);
309        match cursor {
310            None => sqlx::query_as::<_, UserListEntry>(
311                "SELECT u.id, u.email, u.username, u.is_active, \
312                 EXISTS (SELECT 1 FROM allowthem_mfa_secrets \
313                         WHERE user_id = u.id AND enabled = 1) AS has_mfa, \
314                 u.created_at \
315                 FROM allowthem_users u \
316                 ORDER BY u.created_at ASC, u.id ASC \
317                 LIMIT ?",
318            )
319            .bind(limit)
320            .fetch_all(self.pool())
321            .await
322            .map_err(AuthError::Database),
323            Some(c) => {
324                let ca = c.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
325                sqlx::query_as::<_, UserListEntry>(
326                    "SELECT u.id, u.email, u.username, u.is_active, \
327                     EXISTS (SELECT 1 FROM allowthem_mfa_secrets \
328                             WHERE user_id = u.id AND enabled = 1) AS has_mfa, \
329                     u.created_at \
330                     FROM allowthem_users u \
331                     WHERE (u.created_at > ?1 OR (u.created_at = ?1 AND u.id > ?2)) \
332                     ORDER BY u.created_at ASC, u.id ASC \
333                     LIMIT ?3",
334                )
335                .bind(&ca)
336                .bind(c.id)
337                .bind(limit)
338                .fetch_all(self.pool())
339                .await
340                .map_err(AuthError::Database)
341            }
342        }
343    }
344
345    /// Search and filter users with pagination.
346    ///
347    /// Builds a dynamic query with optional search term (matched against
348    /// email and username via LIKE), status filter, and MFA filter.
349    /// Returns matching users with their MFA enrollment status.
350    pub async fn search_users(
351        &self,
352        params: SearchUsersParams<'_>,
353    ) -> Result<SearchUsersResult, AuthError> {
354        let mut where_clauses: Vec<String> = Vec::new();
355        let mut bind_values: Vec<String> = Vec::new();
356
357        if let Some(q) = params.query {
358            let trimmed = q.trim();
359            if !trimmed.is_empty() {
360                let escaped = trimmed
361                    .replace('\\', "\\\\")
362                    .replace('%', "\\%")
363                    .replace('_', "\\_");
364                let pattern = format!("%{escaped}%");
365                where_clauses
366                    .push("(u.email LIKE ? ESCAPE '\\' OR u.username LIKE ? ESCAPE '\\')".into());
367                bind_values.push(pattern.clone());
368                bind_values.push(pattern);
369            }
370        }
371
372        if let Some(active) = params.is_active {
373            where_clauses.push("u.is_active = ?".into());
374            bind_values.push(if active { "1".into() } else { "0".into() });
375        }
376
377        if let Some(has_mfa) = params.has_mfa {
378            let exists = if has_mfa { "EXISTS" } else { "NOT EXISTS" };
379            where_clauses.push(format!(
380                "{exists} (SELECT 1 FROM allowthem_mfa_secrets WHERE user_id = u.id AND enabled = 1)"
381            ));
382        }
383
384        let where_sql = if where_clauses.is_empty() {
385            String::new()
386        } else {
387            format!("WHERE {}", where_clauses.join(" AND "))
388        };
389
390        let count_sql: &'static str = Box::leak(
391            format!("SELECT COUNT(*) FROM allowthem_users u {where_sql}").into_boxed_str(),
392        );
393        let mut count_query = sqlx::query_scalar::<_, i64>(count_sql);
394        for val in &bind_values {
395            count_query = count_query.bind(val);
396        }
397        let total = count_query
398            .fetch_one(self.pool())
399            .await
400            .map_err(AuthError::Database)? as u32;
401
402        let data_sql: &'static str = Box::leak(
403            format!(
404                "SELECT u.id, u.email, u.username, u.is_active, \
405                 EXISTS (SELECT 1 FROM allowthem_mfa_secrets \
406                         WHERE user_id = u.id AND enabled = 1) as has_mfa, \
407                 u.created_at \
408                 FROM allowthem_users u {where_sql} \
409                 ORDER BY u.created_at ASC \
410                 LIMIT ? OFFSET ?"
411            )
412            .into_boxed_str(),
413        );
414        let mut data_query = sqlx::query_as::<_, UserListEntry>(data_sql);
415        for val in &bind_values {
416            data_query = data_query.bind(val);
417        }
418        data_query = data_query.bind(params.limit).bind(params.offset);
419
420        let users = data_query
421            .fetch_all(self.pool())
422            .await
423            .map_err(AuthError::Database)?;
424
425        Ok(SearchUsersResult { users, total })
426    }
427
428    /// Update a user's password. Hashes `new_password` with Argon2id and stores it.
429    ///
430    /// Returns `AuthError::NotFound` if no user with `id` exists.
431    pub async fn update_user_password(
432        &self,
433        id: UserId,
434        new_password: &str,
435    ) -> Result<(), AuthError> {
436        let pw_hash = hash_password(new_password)?;
437        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
438        let result = sqlx::query(
439            "UPDATE allowthem_users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3",
440        )
441        .bind(&pw_hash)
442        .bind(&now)
443        .bind(id)
444        .execute(self.pool())
445        .await?;
446
447        if result.rows_affected() == 0 {
448            return Err(AuthError::NotFound);
449        }
450        Ok(())
451    }
452
453    /// Set a user's password hash to NULL.
454    ///
455    /// Used by admin force-password-reset to invalidate the current password.
456    /// The login flow falls back to a dummy hash when `password_hash` is NULL,
457    /// so `verify_password` will always fail.
458    pub async fn clear_password_hash(&self, id: UserId) -> Result<(), AuthError> {
459        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
460        let result = sqlx::query(
461            "UPDATE allowthem_users SET password_hash = NULL, updated_at = ? WHERE id = ?",
462        )
463        .bind(&now)
464        .bind(id)
465        .execute(self.pool())
466        .await?;
467
468        if result.rows_affected() == 0 {
469            return Err(AuthError::NotFound);
470        }
471        Ok(())
472    }
473
474    /// Get a user's custom data.
475    ///
476    /// Returns `Err(NotFound)` if no user with `id` exists.
477    /// Returns `Ok(None)` if the user exists but has no custom data.
478    pub async fn get_custom_data(&self, id: &UserId) -> Result<Option<Value>, AuthError> {
479        let row: Option<(Option<Value>,)> =
480            sqlx::query_as("SELECT custom_data FROM allowthem_users WHERE id = ?")
481                .bind(id)
482                .fetch_optional(self.pool())
483                .await?;
484
485        match row {
486            None => Err(AuthError::NotFound),
487            Some((data,)) => Ok(data),
488        }
489    }
490
491    /// Set a user's custom data. Also updates `updated_at`.
492    ///
493    /// Returns `Err(NotFound)` if no user with `id` exists.
494    pub async fn set_custom_data(&self, id: &UserId, data: &Value) -> Result<(), AuthError> {
495        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
496        let result = sqlx::query(
497            "UPDATE allowthem_users SET custom_data = ?1, updated_at = ?2 WHERE id = ?3",
498        )
499        .bind(sqlx::types::Json(data))
500        .bind(&now)
501        .bind(id)
502        .execute(self.pool())
503        .await?;
504
505        if result.rows_affected() == 0 {
506            return Err(AuthError::NotFound);
507        }
508        Ok(())
509    }
510
511    /// Delete (clear) a user's custom data by setting it to NULL. Also updates `updated_at`.
512    ///
513    /// Idempotent -- succeeds even if custom data is already NULL.
514    pub async fn delete_custom_data(&self, id: &UserId) -> Result<(), AuthError> {
515        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
516        sqlx::query("UPDATE allowthem_users SET custom_data = NULL, updated_at = ?1 WHERE id = ?2")
517            .bind(&now)
518            .bind(id)
519            .execute(self.pool())
520            .await?;
521
522        Ok(())
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::{AllowThem, AllowThemBuilder};

    async fn setup() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap()
    }

    async fn make_user(db: &Db, tag: u32) -> crate::types::User {
        let email = Email::new(format!("user{tag}@example.com")).unwrap();
        db.create_user(email, "pw123456", None, None).await.unwrap()
    }

    /// Search parameters with no filters other than an optional free-text query.
    fn query_params(query: Option<&str>) -> SearchUsersParams<'_> {
        SearchUsersParams {
            query,
            is_active: None,
            has_mfa: None,
            limit: 50,
            offset: 0,
        }
    }

    #[tokio::test]
    async fn user_cursor_encode_decode_roundtrip() {
        let ath = setup().await;
        let db = ath.db();
        let user = make_user(db, 1).await;
        let entries = db.list_users_paginated(10, None).await.unwrap();
        assert_eq!(entries.len(), 1);
        let cursor = UserCursor::from_entry(&entries[0]);
        let encoded = cursor.encode();
        let decoded = UserCursor::decode(&encoded).unwrap();
        assert_eq!(decoded.id, user.id);
    }

    #[test]
    fn user_cursor_decode_rejects_malformed_tokens() {
        // Empty input decodes to zero bytes, which is not valid JSON.
        assert!(UserCursor::decode("").is_none());
        // Not base64url at all.
        assert!(UserCursor::decode("not base64!").is_none());
        // Valid base64url and JSON shape, but bad timestamp and id.
        let bogus = Base64UrlUnpadded::encode_string(br#"{"ca":"yesterday","id":"x"}"#);
        assert!(UserCursor::decode(&bogus).is_none());
    }

    #[tokio::test]
    async fn list_users_paginated_returns_first_page() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            make_user(db, i).await;
        }
        let page = db.list_users_paginated(3, None).await.unwrap();
        assert_eq!(page.len(), 3);
    }

    #[tokio::test]
    async fn list_users_paginated_cursor_advances() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            make_user(db, i + 10).await;
        }
        let page1 = db.list_users_paginated(3, None).await.unwrap();
        assert_eq!(page1.len(), 3);
        let cursor = UserCursor::from_entry(page1.last().unwrap());
        let page2 = db.list_users_paginated(3, Some(&cursor)).await.unwrap();
        assert_eq!(page2.len(), 2);
        assert!(!page2.iter().any(|u| page1.iter().any(|v| v.id == u.id)));
    }

    #[tokio::test]
    async fn search_users_matches_trimmed_substring() {
        let ath = setup().await;
        let db = ath.db();
        let mut users = Vec::new();
        for i in 1..=3 {
            users.push(make_user(db, i).await);
        }

        let all = db.search_users(query_params(None)).await.unwrap();
        assert_eq!(all.total, 3);
        assert_eq!(all.users.len(), 3);

        let one = db
            .search_users(query_params(Some("  user2@  ")))
            .await
            .unwrap();
        assert_eq!(one.total, 1);
        assert_eq!(one.users.len(), 1);
        assert_eq!(one.users[0].id, users[1].id);

        let domain = db
            .search_users(query_params(Some("example.com")))
            .await
            .unwrap();
        assert_eq!(domain.total, 3);
    }

    #[tokio::test]
    async fn search_users_treats_like_wildcards_literally() {
        let ath = setup().await;
        let db = ath.db();
        make_user(db, 1).await;

        // Neither character appears in "user1@example.com"; unescaped they
        // would act as LIKE wildcards and match everything.
        for term in ["_", "%"] {
            let res = db.search_users(query_params(Some(term))).await.unwrap();
            assert_eq!(res.total, 0, "term {term:?} should match literally");
            assert!(res.users.is_empty());
        }
    }

    #[tokio::test]
    async fn search_users_filters_by_status_and_mfa() {
        let ath = setup().await;
        let db = ath.db();
        let mut users = Vec::new();
        for i in 1..=3 {
            users.push(make_user(db, i).await);
        }
        db.update_user_active(users[0].id, false).await.unwrap();

        let inactive = db
            .search_users(SearchUsersParams {
                is_active: Some(false),
                ..query_params(None)
            })
            .await
            .unwrap();
        assert_eq!(inactive.total, 1);
        assert_eq!(inactive.users[0].id, users[0].id);

        let active = db
            .search_users(SearchUsersParams {
                is_active: Some(true),
                ..query_params(None)
            })
            .await
            .unwrap();
        assert_eq!(active.total, 2);

        let with_mfa = db
            .search_users(SearchUsersParams {
                has_mfa: Some(true),
                ..query_params(None)
            })
            .await
            .unwrap();
        assert_eq!(with_mfa.total, 0);

        let without_mfa = db
            .search_users(SearchUsersParams {
                has_mfa: Some(false),
                ..query_params(None)
            })
            .await
            .unwrap();
        assert_eq!(without_mfa.total, 3);
    }
}