autter_core/database/
users.rs

1use crate::{
2    DataManager,
3    model::{
4        AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5        UserSettings, organizations::Organization,
6    },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14    /// Get a [`User`] from an SQL row.
15    pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16        User {
17            id: get!(x->0(i64)) as usize,
18            created: get!(x->1(i64)) as usize,
19            username: get!(x->2(String)),
20            password: get!(x->3(String)),
21            salt: get!(x->4(String)),
22            settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23            tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24            permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25            is_verified: get!(x->8(i32)) as i8 == 1,
26            notification_count: {
27                let x = get!(x->9(i32)) as usize;
28                // we're a little too close to the maximum count, clearly something's gone wrong
29                if x > usize::MAX - 1000 { 0 } else { x }
30            },
31            totp: get!(x->10(String)),
32            recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33            stripe_id: get!(x->12(String)),
34            ban_reason: get!(x->13(String)),
35            ban_expire: get!(x->14(i64)) as usize,
36            is_deactivated: get!(x->15(i32)) as i8 == 1,
37            checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38            last_policy_consent: get!(x->17(i64)) as usize,
39            linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40            badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41            principal_org: get!(x->20(i64)) as usize,
42            org_as_tenant: get!(x->21(i32)) as i8 == 1,
43            org_creation_credits: get!(x->22(i32)),
44        }
45    }
46
    // User lookups generated by `auto_method!`. Each one names `get_user_from_row` as its row
    // decoder and declares `User` as its return type.
    //
    // The id and username lookups use the same `srmp.user:{}` cache key template, which are
    // the two keys `cache_clear_user` removes. The `_no_cache` variant runs the same username
    // query but passes no `--cache-key-tmpl`.
    // NOTE(review): ids and usernames therefore share one key namespace - confirm that an
    // all-digit username cannot collide with another user's id key.
    auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
    auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
    auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
50
51    /// Get a user given just their ID. Returns the void user if the user doesn't exist.
52    ///
53    /// # Arguments
54    /// * `id` - the ID of the user
55    pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
56        let conn = match self.0.connect().await {
57            Ok(c) => c,
58            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
59        };
60
61        let res = query_row!(
62            &conn,
63            "SELECT * FROM a_users WHERE id = $1",
64            &[&(id as i64)],
65            |x| Ok(Self::get_user_from_row(x))
66        );
67
68        if res.is_err() {
69            return Ok(User::deleted());
70            // return Err(Error::UserNotFound);
71        }
72
73        Ok(res.unwrap())
74    }
75
76    /// Get a user given just their auth token.
77    ///
78    /// # Arguments
79    /// * `token` - the token of the user
80    pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
81        let conn = match self.0.connect().await {
82            Ok(c) => c,
83            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
84        };
85
86        let res = query_row!(
87            &conn,
88            "SELECT * FROM a_users WHERE tokens LIKE $1",
89            &[&format!("%\"{token}\"%")],
90            |x| Ok(Self::get_user_from_row(x))
91        );
92
93        if res.is_err() {
94            return Err(Error::UserNotFound);
95        }
96
97        Ok(res.unwrap())
98    }
99
100    /// Create a new user in the database.
101    ///
102    /// # Arguments
103    /// * `data` - a mock [`User`] object to insert
104    pub async fn create_user(&self, mut data: User) -> Result<User> {
105        if !self.0.0.security.registration_enabled {
106            return Err(Error::RegistrationDisabled);
107        }
108
109        data.username = data.username.to_lowercase();
110
111        // check values
112        if data.username.len() < 2 {
113            return Err(Error::DataTooShort("username".to_string()));
114        } else if data.username.len() > 32 {
115            return Err(Error::DataTooLong("username".to_string()));
116        }
117
118        if data.password.len() < 6 {
119            return Err(Error::DataTooShort("password".to_string()));
120        }
121
122        if self.0.0.banned_usernames.contains(&data.username) {
123            return Err(Error::MiscError("This username cannot be used".to_string()));
124        }
125
126        // make sure username isn't taken
127        if self.get_user_by_username(&data.username).await.is_ok() {
128            return Err(Error::UsernameInUse);
129        }
130
131        // ...
132        let conn = match self.0.connect().await {
133            Ok(c) => c,
134            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
135        };
136
137        let res = execute!(
138            &conn,
139            "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23)",
140            params![
141                &(data.id as i64),
142                &(data.created as i64),
143                &data.username.to_lowercase(),
144                &data.password,
145                &data.salt,
146                &serde_json::to_string(&data.settings).unwrap(),
147                &serde_json::to_string(&data.tokens).unwrap(),
148                &serde_json::to_string(&data.permissions).unwrap(),
149                &(data.is_verified as i32),
150                &0_i32,
151                &String::new(),
152                "[]",
153                &data.stripe_id,
154                &data.ban_reason,
155                &(data.ban_expire as i64),
156                &(data.is_deactivated as i32),
157                &serde_json::to_string(&data.checkouts).unwrap(),
158                &(data.last_policy_consent as i64),
159                &serde_json::to_string(&data.linked_accounts).unwrap(),
160                &serde_json::to_string(&data.badges).unwrap(),
161                &(data.principal_org as i64),
162                &((data.principal_org > 0) as i32),
163                &data.org_creation_credits
164            ]
165        );
166
167        if let Err(e) = res {
168            return Err(Error::DatabaseError(e.to_string()));
169        }
170
171        Ok(data)
172    }
173
174    /// Delete an existing user in the database.
175    ///
176    /// # Arguments
177    /// * `id` - the ID of the user
178    /// * `password` - the current password of the user
179    /// * `force` - if we should delete even if the given password is incorrect
180    pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
181        let user = self.get_user_by_id(id).await?;
182
183        if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
184            return Err(Error::IncorrectPassword);
185        }
186
187        let conn = match self.0.connect().await {
188            Ok(c) => c,
189            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
190        };
191
192        let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
193
194        if let Err(e) = res {
195            return Err(Error::DatabaseError(e.to_string()));
196        }
197
198        self.cache_clear_user(&user).await;
199
200        // delete notifications
201        let res = execute!(
202            &conn,
203            "DELETE FROM a_notifications WHERE owner = $1",
204            &[&(id as i64)]
205        );
206
207        if let Err(e) = res {
208            return Err(Error::DatabaseError(e.to_string()));
209        }
210
211        // delete warnings
212        let res = execute!(
213            &conn,
214            "DELETE FROM a_user_warnings WHERE receiver = $1",
215            &[&(id as i64)]
216        );
217
218        if let Err(e) = res {
219            return Err(Error::DatabaseError(e.to_string()));
220        }
221
222        // delete uploads
223        for upload in match self.1.get_uploads_by_owner_all(user.id).await {
224            Ok(x) => x,
225            Err(e) => return Err(Error::MiscError(e.to_string())),
226        } {
227            if let Err(e) = self.1.delete_upload(upload.id).await {
228                return Err(Error::MiscError(e.to_string()));
229            }
230        }
231
232        // ...
233        Ok(user)
234    }
235
236    pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
237        if !user.permissions.contains(&UserPermission::ManageVerified) {
238            return Err(Error::NotAllowed);
239        }
240
241        let other_user = self.get_user_by_id(id).await?;
242
243        let conn = match self.0.connect().await {
244            Ok(c) => c,
245            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
246        };
247
248        let res = execute!(
249            &conn,
250            "UPDATE a_users SET verified = $1 WHERE id = $2",
251            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
252        );
253
254        if let Err(e) = res {
255            return Err(Error::DatabaseError(e.to_string()));
256        }
257
258        self.cache_clear_user(&other_user).await;
259
260        // create audit log entry
261        self.create_audit_log_entry(AuditLogEntry::new(
262            user.id,
263            format!(
264                "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
265                other_user.id, x
266            ),
267        ))
268        .await?;
269
270        // ...
271        Ok(())
272    }
273
274    pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
275        if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
276            return Err(Error::NotAllowed);
277        }
278
279        let other_user = self.get_user_by_id(id).await?;
280
281        let conn = match self.0.connect().await {
282            Ok(c) => c,
283            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
284        };
285
286        let res = execute!(
287            &conn,
288            "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
289            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
290        );
291
292        if let Err(e) = res {
293            return Err(Error::DatabaseError(e.to_string()));
294        }
295
296        self.cache_clear_user(&other_user).await;
297
298        // create audit log entry (if we aren't the user that is being updated)
299        if user.id != other_user.id {
300            self.create_audit_log_entry(AuditLogEntry::new(
301                user.id,
302                format!(
303                    "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
304                    other_user.id, x
305                ),
306            ))
307            .await?;
308        }
309
310        // ...
311        Ok(())
312    }
313
314    pub async fn update_user_password(
315        &self,
316        id: usize,
317        from: String,
318        to: String,
319        user: User,
320        force: bool,
321    ) -> Result<()> {
322        // verify password
323        if !user.check_password(from.clone()) && !force {
324            return Err(Error::MiscError("Password does not match".to_string()));
325        }
326
327        // ...
328        let conn = match self.0.connect().await {
329            Ok(c) => c,
330            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
331        };
332
333        let new_salt = salt();
334        let new_password = hash_salted(to, new_salt.clone());
335        let res = execute!(
336            &conn,
337            "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
338            params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
339        );
340
341        if let Err(e) = res {
342            return Err(Error::DatabaseError(e.to_string()));
343        }
344
345        self.cache_clear_user(&user).await;
346        Ok(())
347    }
348
349    pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
350        // check value
351        if to.len() < 2 {
352            return Err(Error::DataTooShort("username".to_string()));
353        } else if to.len() > 32 {
354            return Err(Error::DataTooLong("username".to_string()));
355        }
356
357        if self.0.0.banned_usernames.contains(&to) {
358            return Err(Error::MiscError("This username cannot be used".to_string()));
359        }
360
361        let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
362            .multi_line(true)
363            .build()
364            .unwrap();
365
366        if regex.captures(&to).is_some() {
367            return Err(Error::MiscError(
368                "This username contains invalid characters".to_string(),
369            ));
370        }
371
372        // ...
373        let conn = match self.0.connect().await {
374            Ok(c) => c,
375            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
376        };
377
378        let res = execute!(
379            &conn,
380            "UPDATE a_users SET username = $1 WHERE id = $2",
381            params![&to.to_lowercase(), &(id as i64)]
382        );
383
384        if let Err(e) = res {
385            return Err(Error::DatabaseError(e.to_string()));
386        }
387
388        self.cache_clear_user(&user).await;
389        Ok(())
390    }
391
392    /// Validate a given TOTP code for the given profile.
393    pub fn check_totp(&self, ua: &User, code: &str) -> bool {
394        let totp = ua.totp(Some(
395            self.0
396                .0
397                .host
398                .replace("http://", "")
399                .replace("https://", "")
400                .replace(":", "_"),
401        ));
402
403        if let Some(totp) = totp {
404            return !code.is_empty()
405                && (totp.check_current(code).unwrap()
406                    | ua.recovery_codes.contains(&code.to_string()));
407        }
408
409        true
410    }
411
412    /// Generate 8 random recovery codes for TOTP.
413    pub fn generate_totp_recovery_codes() -> Vec<String> {
414        let mut out: Vec<String> = Vec::new();
415
416        for _ in 0..9 {
417            out.push(salt())
418        }
419
420        out
421    }
422
423    /// Update the profile's TOTP secret.
424    ///
425    /// # Arguments
426    /// * `id` - the ID of the user
427    /// * `secret` - the TOTP secret
428    /// * `recovery` - the TOTP recovery codes
429    pub async fn update_user_totp(
430        &self,
431        id: usize,
432        secret: &str,
433        recovery: &Vec<String>,
434    ) -> Result<()> {
435        let user = self.get_user_by_id(id).await?;
436
437        // update
438        let conn = match self.0.connect().await {
439            Ok(c) => c,
440            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
441        };
442
443        let res = execute!(
444            &conn,
445            "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
446            params![
447                &secret,
448                &serde_json::to_string(recovery).unwrap(),
449                &(id as i64)
450            ]
451        );
452
453        if let Err(e) = res {
454            return Err(Error::DatabaseError(e.to_string()));
455        }
456
457        self.cache_clear_user(&user).await;
458        Ok(())
459    }
460
461    /// Enable TOTP for a profile.
462    ///
463    /// # Arguments
464    /// * `id` - the ID of the user to enable TOTP for
465    /// * `user` - the user doing this
466    ///
467    /// # Returns
468    /// `Result<(secret, qr base64)>`
469    pub async fn enable_totp(
470        &self,
471        id: usize,
472        user: User,
473    ) -> Result<(String, String, Vec<String>)> {
474        let other_user = self.get_user_by_id(id).await?;
475
476        if other_user.id != user.id {
477            if other_user
478                .permissions
479                .contains(&UserPermission::ManageUsers)
480            {
481                // create audit log entry
482                self.create_audit_log_entry(AuditLogEntry::new(
483                    user.id,
484                    format!("invoked `enable_totp` with x value `{}`", other_user.id,),
485                ))
486                .await?;
487            } else {
488                return Err(Error::NotAllowed);
489            }
490        }
491
492        let secret = totp_rs::Secret::default().to_string();
493        let recovery = Self::generate_totp_recovery_codes();
494        self.update_user_totp(id, &secret, &recovery).await?;
495
496        // fetch profile again (with totp information)
497        let other_user = self.get_user_by_id(id).await?;
498
499        // get totp
500        let totp = other_user.totp(Some(
501            self.0
502                .0
503                .host
504                .replace("http://", "")
505                .replace("https://", "")
506                .replace(":", "_"),
507        ));
508
509        if totp.is_none() {
510            return Err(Error::MiscError("Failed to get TOTP code".to_string()));
511        }
512
513        let totp = totp.unwrap();
514
515        // generate qr
516        let qr = match totp.get_qr_base64() {
517            Ok(q) => q,
518            Err(e) => return Err(Error::MiscError(e.to_string())),
519        };
520
521        // return
522        Ok((totp.get_secret_base32(), qr, recovery))
523    }
524
525    /// Get the given user's principal organization.
526    pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
527        if user.principal_org == 0 {
528            return None;
529        }
530
531        if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
532            Some(x)
533        } else {
534            self.update_user_principal_org(user.id, 0)
535                .await
536                .expect("failed to clear user principal org");
537
538            None
539        }
540    }
541
542    pub async fn cache_clear_user(&self, user: &User) {
543        self.0.1.remove(format!("srmp.user:{}", user.id)).await;
544        self.0
545            .1
546            .remove(format!("srmp.user:{}", user.username))
547            .await;
548    }
549
    // Single-column updaters generated by `auto_method!`. Every one of them names
    // `get_user_by_id` after the `@` and `cache_clear_user` as its `--cache-key-tmpl`; the
    // SQL takes the new value as `$1` and the user ID as `$2`. The ones flagged `--serde`
    // target the columns `get_user_from_row` decodes with `serde_json`.
    // NOTE(review): presumably the generated method fetches the user first and clears its
    // cache entries after the update - confirm against the macro definition.
    auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_principal_org(i64)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
    auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);

    // Stripe linkage: a lookup by `stripe_id` (declared without a `--cache-key-tmpl`) and the
    // matching updater.
    auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
    auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);

    // Notification counter: set it outright, or step it by one in SQL.
    // NOTE(review): `--decr=notification_count` names the field being decremented; presumably
    // the macro uses it to avoid going below zero - confirm against the macro definition.
    auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
    auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
    auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);

    // Organization creation credits, stepped by one in SQL in either direction.
    auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
    auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
571}