autter_core/database/
users.rs

1use crate::{
2    DataManager,
3    model::{
4        AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5        UserSettings, organizations::Organization,
6    },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14    /// Get a [`User`] from an SQL row.
15    pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16        User {
17            id: get!(x->0(i64)) as usize,
18            created: get!(x->1(i64)) as usize,
19            username: get!(x->2(String)),
20            password: get!(x->3(String)),
21            salt: get!(x->4(String)),
22            settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23            tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24            permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25            is_verified: get!(x->8(i32)) as i8 == 1,
26            notification_count: {
27                let x = get!(x->9(i32)) as usize;
28                // we're a little too close to the maximum count, clearly something's gone wrong
29                if x > usize::MAX - 1000 { 0 } else { x }
30            },
31            totp: get!(x->10(String)),
32            recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33            stripe_id: get!(x->12(String)),
34            ban_reason: get!(x->13(String)),
35            ban_expire: get!(x->14(i64)) as usize,
36            is_deactivated: get!(x->15(i32)) as i8 == 1,
37            checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38            last_policy_consent: get!(x->17(i64)) as usize,
39            linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40            badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41            principal_org: get!(x->20(i64) default (0)) as usize,
42            org_as_tenant: get!(x->21(i32)) as i8 == 1,
43            org_creation_credits: get!(x->22(i32)),
44            org_user_register_credits: get!(x->23(i32)),
45        }
46    }
47
48    auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
49    auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
50    auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
51
52    /// Get a user given just their ID. Returns the void user if the user doesn't exist.
53    ///
54    /// # Arguments
55    /// * `id` - the ID of the user
56    pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
57        let conn = match self.0.connect().await {
58            Ok(c) => c,
59            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
60        };
61
62        let res = query_row!(
63            &conn,
64            "SELECT * FROM a_users WHERE id = $1",
65            &[&(id as i64)],
66            |x| Ok(Self::get_user_from_row(x))
67        );
68
69        if res.is_err() {
70            return Ok(User::deleted());
71            // return Err(Error::UserNotFound);
72        }
73
74        Ok(res.unwrap())
75    }
76
77    /// Get a user given just their auth token.
78    ///
79    /// # Arguments
80    /// * `token` - the token of the user
81    pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
82        let conn = match self.0.connect().await {
83            Ok(c) => c,
84            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
85        };
86
87        let res = query_row!(
88            &conn,
89            "SELECT * FROM a_users WHERE tokens LIKE $1",
90            &[&format!("%\"{token}\"%")],
91            |x| Ok(Self::get_user_from_row(x))
92        );
93
94        if res.is_err() {
95            return Err(Error::UserNotFound);
96        }
97
98        Ok(res.unwrap())
99    }
100
101    /// Create a new user in the database.
102    ///
103    /// # Arguments
104    /// * `data` - a mock [`User`] object to insert
105    pub async fn create_user(&self, mut data: User) -> Result<User> {
106        if !self.0.0.security.registration_enabled {
107            return Err(Error::RegistrationDisabled);
108        }
109
110        data.username = data.username.to_lowercase();
111
112        // check values
113        if data.username.len() < 2 {
114            return Err(Error::DataTooShort("username".to_string()));
115        } else if data.username.len() > 32 {
116            return Err(Error::DataTooLong("username".to_string()));
117        }
118
119        if data.password.len() < 6 {
120            return Err(Error::DataTooShort("password".to_string()));
121        }
122
123        if self.0.0.banned_usernames.contains(&data.username) {
124            return Err(Error::MiscError("This username cannot be used".to_string()));
125        }
126
127        // make sure username isn't taken
128        if self.get_user_by_username(&data.username).await.is_ok() {
129            return Err(Error::UsernameInUse);
130        }
131
132        // ...
133        let conn = match self.0.connect().await {
134            Ok(c) => c,
135            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
136        };
137
138        let res = execute!(
139            &conn,
140            "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)",
141            params![
142                &(data.id as i64),
143                &(data.created as i64),
144                &data.username.to_lowercase(),
145                &data.password,
146                &data.salt,
147                &serde_json::to_string(&data.settings).unwrap(),
148                &serde_json::to_string(&data.tokens).unwrap(),
149                &serde_json::to_string(&data.permissions).unwrap(),
150                &(data.is_verified as i32),
151                &0_i32,
152                &String::new(),
153                "[]",
154                &data.stripe_id,
155                &data.ban_reason,
156                &(data.ban_expire as i64),
157                &(data.is_deactivated as i32),
158                &serde_json::to_string(&data.checkouts).unwrap(),
159                &(data.last_policy_consent as i64),
160                &serde_json::to_string(&data.linked_accounts).unwrap(),
161                &serde_json::to_string(&data.badges).unwrap(),
162                &if data.principal_org != 0 {
163                    Some(data.principal_org as i64)
164                } else {
165                    None
166                },
167                &((data.principal_org > 0) as i32),
168                &data.org_creation_credits,
169                &data.org_user_register_credits,
170            ]
171        );
172
173        if let Err(e) = res {
174            return Err(Error::DatabaseError(e.to_string()));
175        }
176
177        Ok(data)
178    }
179
180    /// Delete an existing user in the database.
181    ///
182    /// # Arguments
183    /// * `id` - the ID of the user
184    /// * `password` - the current password of the user
185    /// * `force` - if we should delete even if the given password is incorrect
186    pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
187        let user = self.get_user_by_id(id).await?;
188
189        if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
190            return Err(Error::IncorrectPassword);
191        }
192
193        let conn = match self.0.connect().await {
194            Ok(c) => c,
195            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
196        };
197
198        let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
199
200        if let Err(e) = res {
201            return Err(Error::DatabaseError(e.to_string()));
202        }
203
204        self.cache_clear_user(&user).await;
205
206        // delete notifications
207        let res = execute!(
208            &conn,
209            "DELETE FROM a_notifications WHERE owner = $1",
210            &[&(id as i64)]
211        );
212
213        if let Err(e) = res {
214            return Err(Error::DatabaseError(e.to_string()));
215        }
216
217        // delete warnings
218        let res = execute!(
219            &conn,
220            "DELETE FROM a_user_warnings WHERE receiver = $1",
221            &[&(id as i64)]
222        );
223
224        if let Err(e) = res {
225            return Err(Error::DatabaseError(e.to_string()));
226        }
227
228        // delete uploads
229        for upload in match self.1.get_uploads_by_owner_all(user.id).await {
230            Ok(x) => x,
231            Err(e) => return Err(Error::MiscError(e.to_string())),
232        } {
233            if let Err(e) = self.1.delete_upload(upload.id).await {
234                return Err(Error::MiscError(e.to_string()));
235            }
236        }
237
238        // ...
239        Ok(user)
240    }
241
242    pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
243        if !user.permissions.contains(&UserPermission::ManageVerified) {
244            return Err(Error::NotAllowed);
245        }
246
247        let other_user = self.get_user_by_id(id).await?;
248
249        let conn = match self.0.connect().await {
250            Ok(c) => c,
251            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
252        };
253
254        let res = execute!(
255            &conn,
256            "UPDATE a_users SET verified = $1 WHERE id = $2",
257            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
258        );
259
260        if let Err(e) = res {
261            return Err(Error::DatabaseError(e.to_string()));
262        }
263
264        self.cache_clear_user(&other_user).await;
265
266        // create audit log entry
267        self.create_audit_log_entry(AuditLogEntry::new(
268            user.id,
269            format!(
270                "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
271                other_user.id, x
272            ),
273        ))
274        .await?;
275
276        // ...
277        Ok(())
278    }
279
280    pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
281        if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
282            return Err(Error::NotAllowed);
283        }
284
285        let other_user = self.get_user_by_id(id).await?;
286
287        let conn = match self.0.connect().await {
288            Ok(c) => c,
289            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
290        };
291
292        let res = execute!(
293            &conn,
294            "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
295            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
296        );
297
298        if let Err(e) = res {
299            return Err(Error::DatabaseError(e.to_string()));
300        }
301
302        self.cache_clear_user(&other_user).await;
303
304        // create audit log entry (if we aren't the user that is being updated)
305        if user.id != other_user.id {
306            self.create_audit_log_entry(AuditLogEntry::new(
307                user.id,
308                format!(
309                    "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
310                    other_user.id, x
311                ),
312            ))
313            .await?;
314        }
315
316        // ...
317        Ok(())
318    }
319
320    pub async fn update_user_password(
321        &self,
322        id: usize,
323        from: String,
324        to: String,
325        user: User,
326        force: bool,
327    ) -> Result<()> {
328        // verify password
329        if !user.check_password(from.clone()) && !force {
330            return Err(Error::MiscError("Password does not match".to_string()));
331        }
332
333        // ...
334        let conn = match self.0.connect().await {
335            Ok(c) => c,
336            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
337        };
338
339        let new_salt = salt();
340        let new_password = hash_salted(to, new_salt.clone());
341        let res = execute!(
342            &conn,
343            "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
344            params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
345        );
346
347        if let Err(e) = res {
348            return Err(Error::DatabaseError(e.to_string()));
349        }
350
351        self.cache_clear_user(&user).await;
352        Ok(())
353    }
354
355    pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
356        // check value
357        if to.len() < 2 {
358            return Err(Error::DataTooShort("username".to_string()));
359        } else if to.len() > 32 {
360            return Err(Error::DataTooLong("username".to_string()));
361        }
362
363        if self.0.0.banned_usernames.contains(&to) {
364            return Err(Error::MiscError("This username cannot be used".to_string()));
365        }
366
367        let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
368            .multi_line(true)
369            .build()
370            .unwrap();
371
372        if regex.captures(&to).is_some() {
373            return Err(Error::MiscError(
374                "This username contains invalid characters".to_string(),
375            ));
376        }
377
378        // ...
379        let conn = match self.0.connect().await {
380            Ok(c) => c,
381            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
382        };
383
384        let res = execute!(
385            &conn,
386            "UPDATE a_users SET username = $1 WHERE id = $2",
387            params![&to.to_lowercase(), &(id as i64)]
388        );
389
390        if let Err(e) = res {
391            return Err(Error::DatabaseError(e.to_string()));
392        }
393
394        self.cache_clear_user(&user).await;
395        Ok(())
396    }
397
398    /// Validate a given TOTP code for the given profile.
399    pub fn check_totp(&self, ua: &User, code: &str) -> bool {
400        let totp = ua.totp(Some(
401            self.0
402                .0
403                .host
404                .replace("http://", "")
405                .replace("https://", "")
406                .replace(":", "_"),
407        ));
408
409        if let Some(totp) = totp {
410            return !code.is_empty()
411                && (totp.check_current(code).unwrap()
412                    | ua.recovery_codes.contains(&code.to_string()));
413        }
414
415        true
416    }
417
418    /// Generate 8 random recovery codes for TOTP.
419    pub fn generate_totp_recovery_codes() -> Vec<String> {
420        let mut out: Vec<String> = Vec::new();
421
422        for _ in 0..9 {
423            out.push(salt())
424        }
425
426        out
427    }
428
429    /// Update the profile's TOTP secret.
430    ///
431    /// # Arguments
432    /// * `id` - the ID of the user
433    /// * `secret` - the TOTP secret
434    /// * `recovery` - the TOTP recovery codes
435    pub async fn update_user_totp(
436        &self,
437        id: usize,
438        secret: &str,
439        recovery: &Vec<String>,
440    ) -> Result<()> {
441        let user = self.get_user_by_id(id).await?;
442
443        // update
444        let conn = match self.0.connect().await {
445            Ok(c) => c,
446            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
447        };
448
449        let res = execute!(
450            &conn,
451            "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
452            params![
453                &secret,
454                &serde_json::to_string(recovery).unwrap(),
455                &(id as i64)
456            ]
457        );
458
459        if let Err(e) = res {
460            return Err(Error::DatabaseError(e.to_string()));
461        }
462
463        self.cache_clear_user(&user).await;
464        Ok(())
465    }
466
467    /// Enable TOTP for a profile.
468    ///
469    /// # Arguments
470    /// * `id` - the ID of the user to enable TOTP for
471    /// * `user` - the user doing this
472    ///
473    /// # Returns
474    /// `Result<(secret, qr base64)>`
475    pub async fn enable_totp(
476        &self,
477        id: usize,
478        user: User,
479    ) -> Result<(String, String, Vec<String>)> {
480        let other_user = self.get_user_by_id(id).await?;
481
482        if other_user.id != user.id {
483            if other_user
484                .permissions
485                .contains(&UserPermission::ManageUsers)
486            {
487                // create audit log entry
488                self.create_audit_log_entry(AuditLogEntry::new(
489                    user.id,
490                    format!("invoked `enable_totp` with x value `{}`", other_user.id,),
491                ))
492                .await?;
493            } else {
494                return Err(Error::NotAllowed);
495            }
496        }
497
498        let secret = totp_rs::Secret::default().to_string();
499        let recovery = Self::generate_totp_recovery_codes();
500        self.update_user_totp(id, &secret, &recovery).await?;
501
502        // fetch profile again (with totp information)
503        let other_user = self.get_user_by_id(id).await?;
504
505        // get totp
506        let totp = other_user.totp(Some(
507            self.0
508                .0
509                .host
510                .replace("http://", "")
511                .replace("https://", "")
512                .replace(":", "_"),
513        ));
514
515        if totp.is_none() {
516            return Err(Error::MiscError("Failed to get TOTP code".to_string()));
517        }
518
519        let totp = totp.unwrap();
520
521        // generate qr
522        let qr = match totp.get_qr_base64() {
523            Ok(q) => q,
524            Err(e) => return Err(Error::MiscError(e.to_string())),
525        };
526
527        // return
528        Ok((totp.get_secret_base32(), qr, recovery))
529    }
530
531    /// Get the given user's principal organization.
532    pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
533        if user.principal_org == 0 {
534            return None;
535        }
536
537        if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
538            Some(x)
539        } else {
540            self.update_user_principal_org(user.id, None)
541                .await
542                .expect("failed to clear user principal org");
543
544            None
545        }
546    }
547
548    pub async fn cache_clear_user(&self, user: &User) {
549        self.0.1.remove(format!("srmp.user:{}", user.id)).await;
550        self.0
551            .1
552            .remove(format!("srmp.user:{}", user.username))
553            .await;
554    }
555
556    auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
557    auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
558    auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
559    auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
560    auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
561    auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
562    auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
563    auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
564    auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
565    auto_method!(update_user_principal_org(Option<i64>)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
566    auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
567
568    auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
569    auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
570
571    auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
572    auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
573    auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);
574
575    auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
576    auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
577
578    auto_method!(incr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
579    auto_method!(decr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
580}