// autter_core/database/users.rs

use std::sync::OnceLock;

use crate::{
    DataManager,
    model::{
        AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
        UserSettings, organizations::Organization,
    },
};
use oiseau::cache::Cache;
use oiseau::{PostgresRow, execute, get, params, query_row};
use tetratto_core2::{auto_method, model::id::Id};
use tritools::encoding::{hash_salted, salt};

13#[inline]
14fn transform_uid(uid: &Id) -> usize {
15    uid.as_usize() / 2
16}
17
18impl DataManager {
19    /// Get a [`User`] from an SQL row.
20    pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
21        User {
22            id: Id::Legacy(get!(x->0(i64)) as usize),
23            created: get!(x->1(i64)) as u128,
24            username: get!(x->2(String)),
25            password: get!(x->3(String)),
26            salt: get!(x->4(String)),
27            settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
28            tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
29            permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
30            is_verified: get!(x->8(i32)) as i8 == 1,
31            notification_count: {
32                let x = get!(x->9(i32)) as usize;
33                // we're a little too close to the maximum count, clearly something's gone wrong
34                if x > usize::MAX - 1000 { 0 } else { x }
35            },
36            totp: get!(x->10(String)),
37            recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
38            stripe_id: get!(x->12(String)),
39            ban_reason: get!(x->13(String)),
40            ban_expire: get!(x->14(i64)) as usize,
41            is_deactivated: get!(x->15(i32)) as i8 == 1,
42            checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
43            last_policy_consent: get!(x->17(i64)) as u128,
44            linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
45            badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
46            principal_org: Id::deserialize(
47                &get!(x->20(String) default (Id::default().printable())),
48            ),
49            org_as_tenant: get!(x->21(i32)) as i8 == 1,
50            org_creation_credits: get!(x->22(i32)),
51            org_user_register_credits: get!(x->23(i32)),
52            is_super_verified: get!(x->24(i32)) as i8 == 1,
53        }
54    }
55
56    auto_method!(get_user_by_id(Id :: usize)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1::bigint" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
57    auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
58    auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
59
60    /// Get a user given just their ID. Returns the void user if the user doesn't exist.
61    ///
62    /// # Arguments
63    /// * `id` - the ID of the user
64    pub async fn get_user_by_id_with_void(&self, id: &Id) -> Result<User> {
65        let conn = match self.0.connect().await {
66            Ok(c) => c,
67            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
68        };
69
70        let res = query_row!(
71            &conn,
72            "SELECT * FROM a_users WHERE id = $1::bigint",
73            &[&id.printable()],
74            |x| Ok(Self::get_user_from_row(x))
75        );
76
77        if res.is_err() {
78            return Ok(User::deleted());
79            // return Err(Error::UserNotFound);
80        }
81
82        Ok(res.unwrap())
83    }
84
85    /// Get a user given just their auth token.
86    ///
87    /// # Arguments
88    /// * `token` - the token of the user
89    pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
90        let conn = match self.0.connect().await {
91            Ok(c) => c,
92            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
93        };
94
95        let res = query_row!(
96            &conn,
97            "SELECT * FROM a_users WHERE tokens LIKE $1",
98            &[&format!("%\"{token}\"%")],
99            |x| Ok(Self::get_user_from_row(x))
100        );
101
102        if res.is_err() {
103            return Err(Error::UserNotFound);
104        }
105
106        Ok(res.unwrap())
107    }
108
109    /// Create a new user in the database.
110    ///
111    /// # Arguments
112    /// * `data` - a mock [`User`] object to insert
113    pub async fn create_user(&self, mut data: User) -> Result<User> {
114        if !self.0.0.security.registration_enabled {
115            return Err(Error::RegistrationDisabled);
116        }
117
118        data.username = data.username.to_lowercase();
119
120        // check values
121        if data.username.len() < 2 {
122            return Err(Error::DataTooShort("username".to_string()));
123        } else if data.username.len() > 32 {
124            return Err(Error::DataTooLong("username".to_string()));
125        }
126
127        if data.password.len() < 6 {
128            return Err(Error::DataTooShort("password".to_string()));
129        }
130
131        if self.0.0.banned_usernames.contains(&data.username) {
132            return Err(Error::MiscError("This username cannot be used".to_string()));
133        }
134
135        // make sure username isn't taken
136        if self.get_user_by_username(&data.username).await.is_ok() {
137            return Err(Error::UsernameInUse);
138        }
139
140        // ...
141        let conn = match self.0.connect().await {
142            Ok(c) => c,
143            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
144        };
145
146        let res = execute!(
147            &conn,
148            "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)",
149            params![
150                &(transform_uid(&data.id) as i64),
151                &(data.created as i64),
152                &data.username.to_lowercase(),
153                &data.password,
154                &data.salt,
155                &serde_json::to_string(&data.settings).unwrap(),
156                &serde_json::to_string(&data.tokens).unwrap(),
157                &serde_json::to_string(&data.permissions).unwrap(),
158                &(data.is_verified as i32),
159                &0_i32,
160                &String::new(),
161                "[]",
162                &data.stripe_id,
163                &data.ban_reason,
164                &(data.ban_expire as i64),
165                &(data.is_deactivated as i32),
166                &serde_json::to_string(&data.checkouts).unwrap(),
167                &(data.last_policy_consent as i64),
168                &serde_json::to_string(&data.linked_accounts).unwrap(),
169                &serde_json::to_string(&data.badges).unwrap(),
170                &if data.principal_org != Id::default() {
171                    Some(data.principal_org.printable())
172                } else {
173                    None
174                },
175                &data.org_creation_credits,
176                &data.org_user_register_credits,
177                &(data.is_super_verified as i32),
178            ]
179        );
180
181        if let Err(e) = res {
182            return Err(Error::DatabaseError(e.to_string()));
183        }
184
185        Ok(data)
186    }
187
188    /// Delete an existing user in the database.
189    ///
190    /// # Arguments
191    /// * `id` - the ID of the user
192    /// * `password` - the current password of the user
193    /// * `force` - if we should delete even if the given password is incorrect
194    pub async fn delete_user(&self, id: &Id, password: &str, force: bool) -> Result<User> {
195        let user = self.get_user_by_id(&id).await?;
196
197        if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
198            return Err(Error::IncorrectPassword);
199        }
200
201        let conn = match self.0.connect().await {
202            Ok(c) => c,
203            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
204        };
205
206        let res = execute!(
207            &conn,
208            "DELETE FROM a_users WHERE id = $1::bigint",
209            &[&id.printable()]
210        );
211
212        if let Err(e) = res {
213            return Err(Error::DatabaseError(e.to_string()));
214        }
215
216        self.cache_clear_user(&user).await;
217
218        // delete uploads
219        for upload in match self.1.get_uploads_by_owner_all(user.id.as_usize()).await {
220            Ok(x) => x,
221            Err(e) => return Err(Error::MiscError(e.to_string())),
222        } {
223            if let Err(e) = self.1.delete_upload(upload.id).await {
224                return Err(Error::MiscError(e.to_string()));
225            }
226        }
227
228        // ...
229        Ok(user)
230    }
231
232    pub async fn update_user_verified_status(&self, id: &Id, x: bool, user: User) -> Result<()> {
233        if !user.permissions.contains(&UserPermission::ManageVerified) {
234            return Err(Error::NotAllowed);
235        }
236
237        let other_user = self.get_user_by_id(&id).await?;
238
239        let conn = match self.0.connect().await {
240            Ok(c) => c,
241            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
242        };
243
244        let res = execute!(
245            &conn,
246            "UPDATE a_users SET is_verified = $1 WHERE id = $2::bigint",
247            params![&{ if x { 1 } else { 0 } }, &id.printable()]
248        );
249
250        if let Err(e) = res {
251            return Err(Error::DatabaseError(e.to_string()));
252        }
253
254        self.cache_clear_user(&other_user).await;
255
256        // create audit log entry
257        self.create_audit_log_entry(AuditLogEntry::new(
258            user.id,
259            format!(
260                "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
261                other_user.id, x
262            ),
263        ))
264        .await?;
265
266        // ...
267        Ok(())
268    }
269
270    /// Update a user's super verified status. This is intended to be called after they
271    /// successfully pay for verification.
272    pub async fn update_user_super_verified_status(&self, id: &Id, x: bool) -> Result<()> {
273        let other_user = self.get_user_by_id(&id).await?;
274
275        let conn = match self.0.connect().await {
276            Ok(c) => c,
277            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
278        };
279
280        let res = execute!(
281            &conn,
282            "UPDATE a_users SET is_super_verified = $1 WHERE id = $2::bigint",
283            params![&{ if x { 1 } else { 0 } }, &id.printable()]
284        );
285
286        if let Err(e) = res {
287            return Err(Error::DatabaseError(e.to_string()));
288        }
289
290        self.cache_clear_user(&other_user).await;
291
292        // create audit log entry
293        self.create_audit_log_entry(AuditLogEntry::new(
294            other_user.id.clone(),
295            format!(
296                "invoked `update_user_super_verified_status` with x value `{}` and y value `{}`",
297                other_user.id.printable(),
298                x
299            ),
300        ))
301        .await?;
302
303        // ...
304        Ok(())
305    }
306
307    pub async fn update_user_is_deactivated(&self, id: &Id, x: bool, user: User) -> Result<()> {
308        if id != &user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
309            return Err(Error::NotAllowed);
310        }
311
312        let other_user = self.get_user_by_id(&id).await?;
313
314        let conn = match self.0.connect().await {
315            Ok(c) => c,
316            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
317        };
318
319        let res = execute!(
320            &conn,
321            "UPDATE a_users SET is_deactivated = $1 WHERE id = $2::bigint",
322            params![&{ if x { 1 } else { 0 } }, &id.printable()]
323        );
324
325        if let Err(e) = res {
326            return Err(Error::DatabaseError(e.to_string()));
327        }
328
329        self.cache_clear_user(&other_user).await;
330
331        // create audit log entry (if we aren't the user that is being updated)
332        if user.id != other_user.id {
333            self.create_audit_log_entry(AuditLogEntry::new(
334                user.id,
335                format!(
336                    "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
337                    other_user.id, x
338                ),
339            ))
340            .await?;
341        }
342
343        // ...
344        Ok(())
345    }
346
347    pub async fn update_user_password(
348        &self,
349        id: &Id,
350        from: String,
351        to: String,
352        user: User,
353        force: bool,
354    ) -> Result<()> {
355        // verify password
356        if !user.check_password(from.clone()) && !force {
357            return Err(Error::MiscError("Password does not match".to_string()));
358        }
359
360        // ...
361        let conn = match self.0.connect().await {
362            Ok(c) => c,
363            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
364        };
365
366        let new_salt = salt();
367        let new_password = hash_salted(to, new_salt.clone());
368        let res = execute!(
369            &conn,
370            "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
371            params![&new_password.as_str(), &new_salt.as_str(), &id.printable()]
372        );
373
374        if let Err(e) = res {
375            return Err(Error::DatabaseError(e.to_string()));
376        }
377
378        self.cache_clear_user(&user).await;
379        Ok(())
380    }
381
382    pub async fn update_user_username(&self, id: &Id, to: String, user: User) -> Result<()> {
383        // check value
384        if to.len() < 2 {
385            return Err(Error::DataTooShort("username".to_string()));
386        } else if to.len() > 32 {
387            return Err(Error::DataTooLong("username".to_string()));
388        }
389
390        if self.0.0.banned_usernames.contains(&to) {
391            return Err(Error::MiscError("This username cannot be used".to_string()));
392        }
393
394        let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
395            .multi_line(true)
396            .build()
397            .unwrap();
398
399        if regex.captures(&to).is_some() {
400            return Err(Error::MiscError(
401                "This username contains invalid characters".to_string(),
402            ));
403        }
404
405        // ...
406        let conn = match self.0.connect().await {
407            Ok(c) => c,
408            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
409        };
410
411        let res = execute!(
412            &conn,
413            "UPDATE a_users SET username = $1 WHERE id = $2::bigint",
414            params![&to.to_lowercase(), &id.printable()]
415        );
416
417        if let Err(e) = res {
418            return Err(Error::DatabaseError(e.to_string()));
419        }
420
421        self.cache_clear_user(&user).await;
422        Ok(())
423    }
424
425    /// Validate a given TOTP code for the given profile.
426    pub fn check_totp(&self, ua: &User, code: &str) -> bool {
427        let totp = ua.totp(Some(
428            self.0
429                .0
430                .host
431                .replace("http://", "")
432                .replace("https://", "")
433                .replace(":", "_"),
434        ));
435
436        if let Some(totp) = totp {
437            return !code.is_empty()
438                && (totp.check_current(code).unwrap()
439                    | ua.recovery_codes.contains(&code.to_string()));
440        }
441
442        true
443    }
444
445    /// Generate 8 random recovery codes for TOTP.
446    pub fn generate_totp_recovery_codes() -> Vec<String> {
447        let mut out: Vec<String> = Vec::new();
448
449        for _ in 0..9 {
450            out.push(salt())
451        }
452
453        out
454    }
455
456    /// Update the profile's TOTP secret.
457    ///
458    /// # Arguments
459    /// * `id` - the ID of the user
460    /// * `secret` - the TOTP secret
461    /// * `recovery` - the TOTP recovery codes
462    pub async fn update_user_totp(
463        &self,
464        id: &Id,
465        secret: &str,
466        recovery: &Vec<String>,
467    ) -> Result<()> {
468        let user = self.get_user_by_id(&id).await?;
469
470        // update
471        let conn = match self.0.connect().await {
472            Ok(c) => c,
473            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
474        };
475
476        let res = execute!(
477            &conn,
478            "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
479            params![
480                &secret,
481                &serde_json::to_string(recovery).unwrap(),
482                &id.printable()
483            ]
484        );
485
486        if let Err(e) = res {
487            return Err(Error::DatabaseError(e.to_string()));
488        }
489
490        self.cache_clear_user(&user).await;
491        Ok(())
492    }
493
494    /// Enable TOTP for a profile.
495    ///
496    /// # Arguments
497    /// * `id` - the ID of the user to enable TOTP for
498    /// * `user` - the user doing this
499    ///
500    /// # Returns
501    /// `Result<(secret, qr base64)>`
502    pub async fn enable_totp(&self, id: &Id, user: User) -> Result<(String, String, Vec<String>)> {
503        let other_user = self.get_user_by_id(&id).await?;
504
505        if other_user.id != user.id {
506            if other_user
507                .permissions
508                .contains(&UserPermission::ManageUsers)
509            {
510                // create audit log entry
511                self.create_audit_log_entry(AuditLogEntry::new(
512                    user.id,
513                    format!("invoked `enable_totp` with x value `{}`", other_user.id,),
514                ))
515                .await?;
516            } else {
517                return Err(Error::NotAllowed);
518            }
519        }
520
521        let secret = totp_rs::Secret::default().to_string();
522        let recovery = Self::generate_totp_recovery_codes();
523        self.update_user_totp(id, &secret, &recovery).await?;
524
525        // fetch profile again (with totp information)
526        let other_user = self.get_user_by_id(&id).await?;
527
528        // get totp
529        let totp = other_user.totp(Some(
530            self.0
531                .0
532                .host
533                .replace("http://", "")
534                .replace("https://", "")
535                .replace(":", "_"),
536        ));
537
538        if totp.is_none() {
539            return Err(Error::MiscError("Failed to get TOTP code".to_string()));
540        }
541
542        let totp = totp.unwrap();
543
544        // generate qr
545        let qr = match totp.get_qr_base64() {
546            Ok(q) => q,
547            Err(e) => return Err(Error::MiscError(e.to_string())),
548        };
549
550        // return
551        Ok((totp.get_secret_base32(), qr, recovery))
552    }
553
554    /// Get the given user's principal organization.
555    pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
556        if user.principal_org == Id::default() {
557            return None;
558        }
559
560        if let Ok(x) = self.get_organization_by_id(&user.principal_org).await {
561            Some(x)
562        } else {
563            self.update_user_principal_org(&user.id, None)
564                .await
565                .expect("failed to clear user principal org");
566
567            None
568        }
569    }
570
571    pub async fn cache_clear_user(&self, user: &User) {
572        self.0.1.remove(format!("srmp.user:{}", user.id)).await;
573        self.0
574            .1
575            .remove(format!("srmp.user:{}", user.username))
576            .await;
577    }
578
579    auto_method!(update_user_permissions(Id :: usize, Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
580    auto_method!(update_user_tokens(Id :: usize, Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
581    auto_method!(update_user_settings(Id :: usize, UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
582    auto_method!(update_user_ban_reason(Id :: usize, &str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
583    auto_method!(update_user_ban_expire(Id :: usize, i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
584    auto_method!(update_user_checkouts(Id :: usize, Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
585    auto_method!(update_user_last_policy_consent(Id :: usize, i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
586    auto_method!(update_user_linked_accounts(Id :: usize, UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
587    auto_method!(update_user_badges(Id :: usize, Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
588    auto_method!(update_user_principal_org(Id :: usize, Option<String>)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
589    auto_method!(update_user_org_as_tenant(Id :: usize, i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
590
591    auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1::bigint" --name="user" --returns=User);
592    auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
593
594    auto_method!(update_user_notification_count(Id :: usize, i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
595    auto_method!(incr_user_notifications(Id :: usize)@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --incr);
596    auto_method!(decr_user_notifications(Id :: usize)@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --decr=notification_count);
597
598    auto_method!(incr_user_org_creation_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --incr);
599    auto_method!(decr_user_org_creation_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
600
601    auto_method!(incr_user_org_user_register_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --incr);
602    auto_method!(decr_user_org_user_register_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
603}