autter_core/database/users.rs

1use crate::{
2    DataManager,
3    model::{
4        AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5        UserSettings, organizations::Organization,
6    },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14    /// Get a [`User`] from an SQL row.
15    pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16        User {
17            id: get!(x->0(i64)) as usize,
18            created: get!(x->1(i64)) as usize,
19            username: get!(x->2(String)),
20            password: get!(x->3(String)),
21            salt: get!(x->4(String)),
22            settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23            tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24            permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25            is_verified: get!(x->8(i32)) as i8 == 1,
26            notification_count: {
27                let x = get!(x->9(i32)) as usize;
28                // we're a little too close to the maximum count, clearly something's gone wrong
29                if x > usize::MAX - 1000 { 0 } else { x }
30            },
31            totp: get!(x->10(String)),
32            recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33            stripe_id: get!(x->12(String)),
34            ban_reason: get!(x->13(String)),
35            ban_expire: get!(x->14(i64)) as usize,
36            is_deactivated: get!(x->15(i32)) as i8 == 1,
37            checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38            last_policy_consent: get!(x->17(i64)) as usize,
39            linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40            badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41            principal_org: get!(x->20(i64)) as usize,
42            org_as_tenant: get!(x->21(i32)) as i8 == 1,
43            org_creation_credits: get!(x->22(i32)),
44            org_user_register_credits: get!(x->23(i32)),
45        }
46    }
47
48    auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
49    auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
50    auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
51
52    /// Get a user given just their ID. Returns the void user if the user doesn't exist.
53    ///
54    /// # Arguments
55    /// * `id` - the ID of the user
56    pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
57        let conn = match self.0.connect().await {
58            Ok(c) => c,
59            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
60        };
61
62        let res = query_row!(
63            &conn,
64            "SELECT * FROM a_users WHERE id = $1",
65            &[&(id as i64)],
66            |x| Ok(Self::get_user_from_row(x))
67        );
68
69        if res.is_err() {
70            return Ok(User::deleted());
71            // return Err(Error::UserNotFound);
72        }
73
74        Ok(res.unwrap())
75    }
76
77    /// Get a user given just their auth token.
78    ///
79    /// # Arguments
80    /// * `token` - the token of the user
81    pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
82        let conn = match self.0.connect().await {
83            Ok(c) => c,
84            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
85        };
86
87        let res = query_row!(
88            &conn,
89            "SELECT * FROM a_users WHERE tokens LIKE $1",
90            &[&format!("%\"{token}\"%")],
91            |x| Ok(Self::get_user_from_row(x))
92        );
93
94        if res.is_err() {
95            return Err(Error::UserNotFound);
96        }
97
98        Ok(res.unwrap())
99    }
100
101    /// Create a new user in the database.
102    ///
103    /// # Arguments
104    /// * `data` - a mock [`User`] object to insert
105    pub async fn create_user(&self, mut data: User) -> Result<User> {
106        if !self.0.0.security.registration_enabled {
107            return Err(Error::RegistrationDisabled);
108        }
109
110        data.username = data.username.to_lowercase();
111
112        // check values
113        if data.username.len() < 2 {
114            return Err(Error::DataTooShort("username".to_string()));
115        } else if data.username.len() > 32 {
116            return Err(Error::DataTooLong("username".to_string()));
117        }
118
119        if data.password.len() < 6 {
120            return Err(Error::DataTooShort("password".to_string()));
121        }
122
123        if self.0.0.banned_usernames.contains(&data.username) {
124            return Err(Error::MiscError("This username cannot be used".to_string()));
125        }
126
127        // make sure username isn't taken
128        if self.get_user_by_username(&data.username).await.is_ok() {
129            return Err(Error::UsernameInUse);
130        }
131
132        // ...
133        let conn = match self.0.connect().await {
134            Ok(c) => c,
135            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
136        };
137
138        let res = execute!(
139            &conn,
140            "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)",
141            params![
142                &(data.id as i64),
143                &(data.created as i64),
144                &data.username.to_lowercase(),
145                &data.password,
146                &data.salt,
147                &serde_json::to_string(&data.settings).unwrap(),
148                &serde_json::to_string(&data.tokens).unwrap(),
149                &serde_json::to_string(&data.permissions).unwrap(),
150                &(data.is_verified as i32),
151                &0_i32,
152                &String::new(),
153                "[]",
154                &data.stripe_id,
155                &data.ban_reason,
156                &(data.ban_expire as i64),
157                &(data.is_deactivated as i32),
158                &serde_json::to_string(&data.checkouts).unwrap(),
159                &(data.last_policy_consent as i64),
160                &serde_json::to_string(&data.linked_accounts).unwrap(),
161                &serde_json::to_string(&data.badges).unwrap(),
162                &(data.principal_org as i64),
163                &((data.principal_org > 0) as i32),
164                &data.org_creation_credits,
165                &data.org_user_register_credits,
166            ]
167        );
168
169        if let Err(e) = res {
170            return Err(Error::DatabaseError(e.to_string()));
171        }
172
173        Ok(data)
174    }
175
176    /// Delete an existing user in the database.
177    ///
178    /// # Arguments
179    /// * `id` - the ID of the user
180    /// * `password` - the current password of the user
181    /// * `force` - if we should delete even if the given password is incorrect
182    pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
183        let user = self.get_user_by_id(id).await?;
184
185        if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
186            return Err(Error::IncorrectPassword);
187        }
188
189        let conn = match self.0.connect().await {
190            Ok(c) => c,
191            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
192        };
193
194        let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
195
196        if let Err(e) = res {
197            return Err(Error::DatabaseError(e.to_string()));
198        }
199
200        self.cache_clear_user(&user).await;
201
202        // delete notifications
203        let res = execute!(
204            &conn,
205            "DELETE FROM a_notifications WHERE owner = $1",
206            &[&(id as i64)]
207        );
208
209        if let Err(e) = res {
210            return Err(Error::DatabaseError(e.to_string()));
211        }
212
213        // delete warnings
214        let res = execute!(
215            &conn,
216            "DELETE FROM a_user_warnings WHERE receiver = $1",
217            &[&(id as i64)]
218        );
219
220        if let Err(e) = res {
221            return Err(Error::DatabaseError(e.to_string()));
222        }
223
224        // delete uploads
225        for upload in match self.1.get_uploads_by_owner_all(user.id).await {
226            Ok(x) => x,
227            Err(e) => return Err(Error::MiscError(e.to_string())),
228        } {
229            if let Err(e) = self.1.delete_upload(upload.id).await {
230                return Err(Error::MiscError(e.to_string()));
231            }
232        }
233
234        // ...
235        Ok(user)
236    }
237
238    pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
239        if !user.permissions.contains(&UserPermission::ManageVerified) {
240            return Err(Error::NotAllowed);
241        }
242
243        let other_user = self.get_user_by_id(id).await?;
244
245        let conn = match self.0.connect().await {
246            Ok(c) => c,
247            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
248        };
249
250        let res = execute!(
251            &conn,
252            "UPDATE a_users SET verified = $1 WHERE id = $2",
253            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
254        );
255
256        if let Err(e) = res {
257            return Err(Error::DatabaseError(e.to_string()));
258        }
259
260        self.cache_clear_user(&other_user).await;
261
262        // create audit log entry
263        self.create_audit_log_entry(AuditLogEntry::new(
264            user.id,
265            format!(
266                "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
267                other_user.id, x
268            ),
269        ))
270        .await?;
271
272        // ...
273        Ok(())
274    }
275
276    pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
277        if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
278            return Err(Error::NotAllowed);
279        }
280
281        let other_user = self.get_user_by_id(id).await?;
282
283        let conn = match self.0.connect().await {
284            Ok(c) => c,
285            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
286        };
287
288        let res = execute!(
289            &conn,
290            "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
291            params![&{ if x { 1 } else { 0 } }, &(id as i64)]
292        );
293
294        if let Err(e) = res {
295            return Err(Error::DatabaseError(e.to_string()));
296        }
297
298        self.cache_clear_user(&other_user).await;
299
300        // create audit log entry (if we aren't the user that is being updated)
301        if user.id != other_user.id {
302            self.create_audit_log_entry(AuditLogEntry::new(
303                user.id,
304                format!(
305                    "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
306                    other_user.id, x
307                ),
308            ))
309            .await?;
310        }
311
312        // ...
313        Ok(())
314    }
315
316    pub async fn update_user_password(
317        &self,
318        id: usize,
319        from: String,
320        to: String,
321        user: User,
322        force: bool,
323    ) -> Result<()> {
324        // verify password
325        if !user.check_password(from.clone()) && !force {
326            return Err(Error::MiscError("Password does not match".to_string()));
327        }
328
329        // ...
330        let conn = match self.0.connect().await {
331            Ok(c) => c,
332            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
333        };
334
335        let new_salt = salt();
336        let new_password = hash_salted(to, new_salt.clone());
337        let res = execute!(
338            &conn,
339            "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
340            params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
341        );
342
343        if let Err(e) = res {
344            return Err(Error::DatabaseError(e.to_string()));
345        }
346
347        self.cache_clear_user(&user).await;
348        Ok(())
349    }
350
351    pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
352        // check value
353        if to.len() < 2 {
354            return Err(Error::DataTooShort("username".to_string()));
355        } else if to.len() > 32 {
356            return Err(Error::DataTooLong("username".to_string()));
357        }
358
359        if self.0.0.banned_usernames.contains(&to) {
360            return Err(Error::MiscError("This username cannot be used".to_string()));
361        }
362
363        let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
364            .multi_line(true)
365            .build()
366            .unwrap();
367
368        if regex.captures(&to).is_some() {
369            return Err(Error::MiscError(
370                "This username contains invalid characters".to_string(),
371            ));
372        }
373
374        // ...
375        let conn = match self.0.connect().await {
376            Ok(c) => c,
377            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
378        };
379
380        let res = execute!(
381            &conn,
382            "UPDATE a_users SET username = $1 WHERE id = $2",
383            params![&to.to_lowercase(), &(id as i64)]
384        );
385
386        if let Err(e) = res {
387            return Err(Error::DatabaseError(e.to_string()));
388        }
389
390        self.cache_clear_user(&user).await;
391        Ok(())
392    }
393
394    /// Validate a given TOTP code for the given profile.
395    pub fn check_totp(&self, ua: &User, code: &str) -> bool {
396        let totp = ua.totp(Some(
397            self.0
398                .0
399                .host
400                .replace("http://", "")
401                .replace("https://", "")
402                .replace(":", "_"),
403        ));
404
405        if let Some(totp) = totp {
406            return !code.is_empty()
407                && (totp.check_current(code).unwrap()
408                    | ua.recovery_codes.contains(&code.to_string()));
409        }
410
411        true
412    }
413
414    /// Generate 8 random recovery codes for TOTP.
415    pub fn generate_totp_recovery_codes() -> Vec<String> {
416        let mut out: Vec<String> = Vec::new();
417
418        for _ in 0..9 {
419            out.push(salt())
420        }
421
422        out
423    }
424
425    /// Update the profile's TOTP secret.
426    ///
427    /// # Arguments
428    /// * `id` - the ID of the user
429    /// * `secret` - the TOTP secret
430    /// * `recovery` - the TOTP recovery codes
431    pub async fn update_user_totp(
432        &self,
433        id: usize,
434        secret: &str,
435        recovery: &Vec<String>,
436    ) -> Result<()> {
437        let user = self.get_user_by_id(id).await?;
438
439        // update
440        let conn = match self.0.connect().await {
441            Ok(c) => c,
442            Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
443        };
444
445        let res = execute!(
446            &conn,
447            "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
448            params![
449                &secret,
450                &serde_json::to_string(recovery).unwrap(),
451                &(id as i64)
452            ]
453        );
454
455        if let Err(e) = res {
456            return Err(Error::DatabaseError(e.to_string()));
457        }
458
459        self.cache_clear_user(&user).await;
460        Ok(())
461    }
462
463    /// Enable TOTP for a profile.
464    ///
465    /// # Arguments
466    /// * `id` - the ID of the user to enable TOTP for
467    /// * `user` - the user doing this
468    ///
469    /// # Returns
470    /// `Result<(secret, qr base64)>`
471    pub async fn enable_totp(
472        &self,
473        id: usize,
474        user: User,
475    ) -> Result<(String, String, Vec<String>)> {
476        let other_user = self.get_user_by_id(id).await?;
477
478        if other_user.id != user.id {
479            if other_user
480                .permissions
481                .contains(&UserPermission::ManageUsers)
482            {
483                // create audit log entry
484                self.create_audit_log_entry(AuditLogEntry::new(
485                    user.id,
486                    format!("invoked `enable_totp` with x value `{}`", other_user.id,),
487                ))
488                .await?;
489            } else {
490                return Err(Error::NotAllowed);
491            }
492        }
493
494        let secret = totp_rs::Secret::default().to_string();
495        let recovery = Self::generate_totp_recovery_codes();
496        self.update_user_totp(id, &secret, &recovery).await?;
497
498        // fetch profile again (with totp information)
499        let other_user = self.get_user_by_id(id).await?;
500
501        // get totp
502        let totp = other_user.totp(Some(
503            self.0
504                .0
505                .host
506                .replace("http://", "")
507                .replace("https://", "")
508                .replace(":", "_"),
509        ));
510
511        if totp.is_none() {
512            return Err(Error::MiscError("Failed to get TOTP code".to_string()));
513        }
514
515        let totp = totp.unwrap();
516
517        // generate qr
518        let qr = match totp.get_qr_base64() {
519            Ok(q) => q,
520            Err(e) => return Err(Error::MiscError(e.to_string())),
521        };
522
523        // return
524        Ok((totp.get_secret_base32(), qr, recovery))
525    }
526
527    /// Get the given user's principal organization.
528    pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
529        if user.principal_org == 0 {
530            return None;
531        }
532
533        if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
534            Some(x)
535        } else {
536            self.update_user_principal_org(user.id, 0)
537                .await
538                .expect("failed to clear user principal org");
539
540            None
541        }
542    }
543
544    pub async fn cache_clear_user(&self, user: &User) {
545        self.0.1.remove(format!("srmp.user:{}", user.id)).await;
546        self.0
547            .1
548            .remove(format!("srmp.user:{}", user.username))
549            .await;
550    }
551
552    auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
553    auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
554    auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
555    auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
556    auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
557    auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
558    auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
559    auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
560    auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
561    auto_method!(update_user_principal_org(i64)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
562    auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
563
564    auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
565    auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
566
567    auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
568    auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
569    auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);
570
571    auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
572    auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
573
574    auto_method!(incr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
575    auto_method!(decr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
576}