Skip to main content

autter_core/database/
users.rs

1use crate::{
2    DataManager,
3    model::{
4        AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5        UserSettings, organizations::Organization,
6    },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14    /// Get a [`User`] from an SQL row.
15    pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16        User {
17            id: get!(x->0(i64)) as usize,
18            created: get!(x->1(i64)) as usize,
19            username: get!(x->2(String)),
20            password: get!(x->3(String)),
21            salt: get!(x->4(String)),
22            settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23            tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24            permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25            is_verified: get!(x->8(i32)) as i8 == 1,
26            notification_count: {
27                let x = get!(x->9(i32)) as usize;
28                // we're a little too close to the maximum count, clearly something's gone wrong
29                if x > usize::MAX - 1000 { 0 } else { x }
30            },
31            totp: get!(x->10(String)),
32            recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33            stripe_id: get!(x->12(String)),
34            ban_reason: get!(x->13(String)),
35            ban_expire: get!(x->14(i64)) as usize,
36            is_deactivated: get!(x->15(i32)) as i8 == 1,
37            checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38            last_policy_consent: get!(x->17(i64)) as usize,
39            linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40            badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41            principal_org: get!(x->20(i64) default (0)) as usize,
42            org_as_tenant: get!(x->21(i32)) as i8 == 1,
43            org_creation_credits: get!(x->22(i32)),
44            org_user_register_credits: get!(x->23(i32)),
45            is_super_verified: get!(x->24(i32)) as i8 == 1,
46        }
47    }
48
49    auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
50    auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
51    auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
52
53    /// Get a user given just their ID. Returns the void user if the user doesn't exist.
54    ///
55    /// # Arguments
56    /// * `id` - the ID of the user
57    pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
58        let conn = match self.0.connect().await {
59            Ok(c) => c,
60            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
61        };
62
63        let res = query_row!(
64            &conn,
65            "SELECT * FROM a_users WHERE id = $1",
66            &[&(id as i64)],
67            |x| Ok(Self::get_user_from_row(x))
68        );
69
70        if res.is_err() {
71            return Ok(User::deleted());
72            // return Err(Error::UserNotFound);
73        }
74
75        Ok(res.unwrap())
76    }
77
78    /// Get a user given just their auth token.
79    ///
80    /// # Arguments
81    /// * `token` - the token of the user
82    pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
83        let conn = match self.0.connect().await {
84            Ok(c) => c,
85            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
86        };
87
88        let res = query_row!(
89            &conn,
90            "SELECT * FROM a_users WHERE tokens LIKE $1",
91            &[&format!("%\"{token}\"%")],
92            |x| Ok(Self::get_user_from_row(x))
93        );
94
95        if res.is_err() {
96            return Err(Error::UserNotFound);
97        }
98
99        Ok(res.unwrap())
100    }
101
102    /// Create a new user in the database.
103    ///
104    /// # Arguments
105    /// * `data` - a mock [`User`] object to insert
106    pub async fn create_user(&self, mut data: User) -> Result<User> {
107        if !self.0.0.security.registration_enabled {
108            return Err(Error::RegistrationDisabled);
109        }
110
111        data.username = data.username.to_lowercase();
112
113        // check values
114        if data.username.len() < 2 {
115            return Err(Error::DataTooShort("username".to_string()));
116        } else if data.username.len() > 32 {
117            return Err(Error::DataTooLong("username".to_string()));
118        }
119
120        if data.password.len() < 6 {
121            return Err(Error::DataTooShort("password".to_string()));
122        }
123
124        if self.0.0.banned_usernames.contains(&data.username) {
125            return Err(Error::MiscError("This username cannot be used".to_string()));
126        }
127
128        // make sure username isn't taken
129        if self.get_user_by_username(&data.username).await.is_ok() {
130            return Err(Error::UsernameInUse);
131        }
132
133        // ...
134        let conn = match self.0.connect().await {
135            Ok(c) => c,
136            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
137        };
138
139        let res = execute!(
140            &conn,
141            "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25)",
142            params![
143                &(data.id as i64),
144                &(data.created as i64),
145                &data.username.to_lowercase(),
146                &data.password,
147                &data.salt,
148                &serde_json::to_string(&data.settings).unwrap(),
149                &serde_json::to_string(&data.tokens).unwrap(),
150                &serde_json::to_string(&data.permissions).unwrap(),
151                &(data.is_verified as i32),
152                &0_i32,
153                &String::new(),
154                "[]",
155                &data.stripe_id,
156                &data.ban_reason,
157                &(data.ban_expire as i64),
158                &(data.is_deactivated as i32),
159                &serde_json::to_string(&data.checkouts).unwrap(),
160                &(data.last_policy_consent as i64),
161                &serde_json::to_string(&data.linked_accounts).unwrap(),
162                &serde_json::to_string(&data.badges).unwrap(),
163                &if data.principal_org != 0 {
164                    Some(data.principal_org as i64)
165                } else {
166                    None
167                },
168                &((data.principal_org > 0) as i32),
169                &data.org_creation_credits,
170                &data.org_user_register_credits,
171                &data.is_super_verified
172            ]
173        );
174
175        if let Err(e) = res {
176            return Err(Error::DatabaseError(e.to_string()));
177        }
178
179        Ok(data)
180    }
181
182    /// Delete an existing user in the database.
183    ///
184    /// # Arguments
185    /// * `id` - the ID of the user
186    /// * `password` - the current password of the user
187    /// * `force` - if we should delete even if the given password is incorrect
188    pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
189        let user = self.get_user_by_id(id).await?;
190
191        if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
192            return Err(Error::IncorrectPassword);
193        }
194
195        let conn = match self.0.connect().await {
196            Ok(c) => c,
197            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
198        };
199
200        let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
201
202        if let Err(e) = res {
203            return Err(Error::DatabaseError(e.to_string()));
204        }
205
206        self.cache_clear_user(&user).await;
207
208        // delete uploads
209        for upload in match self.1.get_uploads_by_owner_all(user.id).await {
210            Ok(x) => x,
211            Err(e) => return Err(Error::MiscError(e.to_string())),
212        } {
213            if let Err(e) = self.1.delete_upload(upload.id).await {
214                return Err(Error::MiscError(e.to_string()));
215            }
216        }
217
218        // ...
219        Ok(user)
220    }
221
222    pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
223        if !user.permissions.contains(&UserPermission::ManageVerified) {
224            return Err(Error::NotAllowed);
225        }
226
227        let other_user = self.get_user_by_id(id).await?;
228
229        let conn = match self.0.connect().await {
230            Ok(c) => c,
231            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
232        };
233
234        let res = execute!(
235            &conn,
236            "UPDATE a_users SET is_verified = $1 WHERE id = $2",
237            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
238        );
239
240        if let Err(e) = res {
241            return Err(Error::DatabaseError(e.to_string()));
242        }
243
244        self.cache_clear_user(&other_user).await;
245
246        // create audit log entry
247        self.create_audit_log_entry(AuditLogEntry::new(
248            user.id,
249            format!(
250                "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
251                other_user.id, x
252            ),
253        ))
254        .await?;
255
256        // ...
257        Ok(())
258    }
259
260    /// Update a user's super verified status. This is intended to be called after they
261    /// successfully pay for verification.
262    pub async fn update_user_super_verified_status(&self, id: usize, x: bool) -> Result<()> {
263        let other_user = self.get_user_by_id(id).await?;
264
265        let conn = match self.0.connect().await {
266            Ok(c) => c,
267            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
268        };
269
270        let res = execute!(
271            &conn,
272            "UPDATE a_users SET is_super_verified = $1 WHERE id = $2",
273            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
274        );
275
276        if let Err(e) = res {
277            return Err(Error::DatabaseError(e.to_string()));
278        }
279
280        self.cache_clear_user(&other_user).await;
281
282        // create audit log entry
283        self.create_audit_log_entry(AuditLogEntry::new(
284            other_user.id,
285            format!(
286                "invoked `update_user_super_verified_status` with x value `{}` and y value `{}`",
287                other_user.id, x
288            ),
289        ))
290        .await?;
291
292        // ...
293        Ok(())
294    }
295
296    pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
297        if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
298            return Err(Error::NotAllowed);
299        }
300
301        let other_user = self.get_user_by_id(id).await?;
302
303        let conn = match self.0.connect().await {
304            Ok(c) => c,
305            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
306        };
307
308        let res = execute!(
309            &conn,
310            "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
311            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
312        );
313
314        if let Err(e) = res {
315            return Err(Error::DatabaseError(e.to_string()));
316        }
317
318        self.cache_clear_user(&other_user).await;
319
320        // create audit log entry (if we aren't the user that is being updated)
321        if user.id != other_user.id {
322            self.create_audit_log_entry(AuditLogEntry::new(
323                user.id,
324                format!(
325                    "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
326                    other_user.id, x
327                ),
328            ))
329            .await?;
330        }
331
332        // ...
333        Ok(())
334    }
335
336    pub async fn update_user_password(
337        &self,
338        id: usize,
339        from: String,
340        to: String,
341        user: User,
342        force: bool,
343    ) -> Result<()> {
344        // verify password
345        if !user.check_password(from.clone()) && !force {
346            return Err(Error::MiscError("Password does not match".to_string()));
347        }
348
349        // ...
350        let conn = match self.0.connect().await {
351            Ok(c) => c,
352            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
353        };
354
355        let new_salt = salt();
356        let new_password = hash_salted(to, new_salt.clone());
357        let res = execute!(
358            &conn,
359            "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
360            params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
361        );
362
363        if let Err(e) = res {
364            return Err(Error::DatabaseError(e.to_string()));
365        }
366
367        self.cache_clear_user(&user).await;
368        Ok(())
369    }
370
371    pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
372        // check value
373        if to.len() < 2 {
374            return Err(Error::DataTooShort("username".to_string()));
375        } else if to.len() > 32 {
376            return Err(Error::DataTooLong("username".to_string()));
377        }
378
379        if self.0.0.banned_usernames.contains(&to) {
380            return Err(Error::MiscError("This username cannot be used".to_string()));
381        }
382
383        let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
384            .multi_line(true)
385            .build()
386            .unwrap();
387
388        if regex.captures(&to).is_some() {
389            return Err(Error::MiscError(
390                "This username contains invalid characters".to_string(),
391            ));
392        }
393
394        // ...
395        let conn = match self.0.connect().await {
396            Ok(c) => c,
397            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
398        };
399
400        let res = execute!(
401            &conn,
402            "UPDATE a_users SET username = $1 WHERE id = $2",
403            params![&to.to_lowercase(), &(id as i64)]
404        );
405
406        if let Err(e) = res {
407            return Err(Error::DatabaseError(e.to_string()));
408        }
409
410        self.cache_clear_user(&user).await;
411        Ok(())
412    }
413
414    /// Validate a given TOTP code for the given profile.
415    pub fn check_totp(&self, ua: &User, code: &str) -> bool {
416        let totp = ua.totp(Some(
417            self.0
418                .0
419                .host
420                .replace("http://", "")
421                .replace("https://", "")
422                .replace(":", "_"),
423        ));
424
425        if let Some(totp) = totp {
426            return !code.is_empty()
427                && (totp.check_current(code).unwrap()
428                    | ua.recovery_codes.contains(&code.to_string()));
429        }
430
431        true
432    }
433
434    /// Generate 8 random recovery codes for TOTP.
435    pub fn generate_totp_recovery_codes() -> Vec<String> {
436        let mut out: Vec<String> = Vec::new();
437
438        for _ in 0..9 {
439            out.push(salt())
440        }
441
442        out
443    }
444
445    /// Update the profile's TOTP secret.
446    ///
447    /// # Arguments
448    /// * `id` - the ID of the user
449    /// * `secret` - the TOTP secret
450    /// * `recovery` - the TOTP recovery codes
451    pub async fn update_user_totp(
452        &self,
453        id: usize,
454        secret: &str,
455        recovery: &Vec<String>,
456    ) -> Result<()> {
457        let user = self.get_user_by_id(id).await?;
458
459        // update
460        let conn = match self.0.connect().await {
461            Ok(c) => c,
462            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
463        };
464
465        let res = execute!(
466            &conn,
467            "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
468            params![
469                &secret,
470                &serde_json::to_string(recovery).unwrap(),
471                &(id as i64)
472            ]
473        );
474
475        if let Err(e) = res {
476            return Err(Error::DatabaseError(e.to_string()));
477        }
478
479        self.cache_clear_user(&user).await;
480        Ok(())
481    }
482
483    /// Enable TOTP for a profile.
484    ///
485    /// # Arguments
486    /// * `id` - the ID of the user to enable TOTP for
487    /// * `user` - the user doing this
488    ///
489    /// # Returns
490    /// `Result<(secret, qr base64)>`
491    pub async fn enable_totp(
492        &self,
493        id: usize,
494        user: User,
495    ) -> Result<(String, String, Vec<String>)> {
496        let other_user = self.get_user_by_id(id).await?;
497
498        if other_user.id != user.id {
499            if other_user
500                .permissions
501                .contains(&UserPermission::ManageUsers)
502            {
503                // create audit log entry
504                self.create_audit_log_entry(AuditLogEntry::new(
505                    user.id,
506                    format!("invoked `enable_totp` with x value `{}`", other_user.id,),
507                ))
508                .await?;
509            } else {
510                return Err(Error::NotAllowed);
511            }
512        }
513
514        let secret = totp_rs::Secret::default().to_string();
515        let recovery = Self::generate_totp_recovery_codes();
516        self.update_user_totp(id, &secret, &recovery).await?;
517
518        // fetch profile again (with totp information)
519        let other_user = self.get_user_by_id(id).await?;
520
521        // get totp
522        let totp = other_user.totp(Some(
523            self.0
524                .0
525                .host
526                .replace("http://", "")
527                .replace("https://", "")
528                .replace(":", "_"),
529        ));
530
531        if totp.is_none() {
532            return Err(Error::MiscError("Failed to get TOTP code".to_string()));
533        }
534
535        let totp = totp.unwrap();
536
537        // generate qr
538        let qr = match totp.get_qr_base64() {
539            Ok(q) => q,
540            Err(e) => return Err(Error::MiscError(e.to_string())),
541        };
542
543        // return
544        Ok((totp.get_secret_base32(), qr, recovery))
545    }
546
547    /// Get the given user's principal organization.
548    pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
549        if user.principal_org == 0 {
550            return None;
551        }
552
553        if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
554            Some(x)
555        } else {
556            self.update_user_principal_org(user.id, None)
557                .await
558                .expect("failed to clear user principal org");
559
560            None
561        }
562    }
563
564    pub async fn cache_clear_user(&self, user: &User) {
565        self.0.1.remove(format!("srmp.user:{}", user.id)).await;
566        self.0
567            .1
568            .remove(format!("srmp.user:{}", user.username))
569            .await;
570    }
571
572    auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
573    auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
574    auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
575    auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
576    auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
577    auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
578    auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
579    auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
580    auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
581    auto_method!(update_user_principal_org(Option<i64>)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
582    auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
583
584    auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
585    auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
586
587    auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
588    auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
589    auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);
590
591    auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
592    auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
593
594    auto_method!(incr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
595    auto_method!(decr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
596}