1use crate::{
2 DataManager,
3 model::{
4 AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5 UserSettings, organizations::Organization,
6 },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14 pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16 User {
17 id: get!(x->0(i64)) as usize,
18 created: get!(x->1(i64)) as usize,
19 username: get!(x->2(String)),
20 password: get!(x->3(String)),
21 salt: get!(x->4(String)),
22 settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23 tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24 permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25 is_verified: get!(x->8(i32)) as i8 == 1,
26 notification_count: {
27 let x = get!(x->9(i32)) as usize;
28 if x > usize::MAX - 1000 { 0 } else { x }
30 },
31 totp: get!(x->10(String)),
32 recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33 stripe_id: get!(x->12(String)),
34 ban_reason: get!(x->13(String)),
35 ban_expire: get!(x->14(i64)) as usize,
36 is_deactivated: get!(x->15(i32)) as i8 == 1,
37 checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38 last_policy_consent: get!(x->17(i64)) as usize,
39 linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40 badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41 principal_org: get!(x->20(i64) default (0)) as usize,
42 org_as_tenant: get!(x->21(i32)) as i8 == 1,
43 org_creation_credits: get!(x->22(i32)),
44 org_user_register_credits: get!(x->23(i32)),
45 is_super_verified: get!(x->24(i32)) as i8 == 1,
46 }
47 }
48
// Cached lookups by primary key and by username. Both use the
// `srmp.user:{}` cache-key template; `cache_clear_user` removes exactly these
// two keys (one formatted with the id, one with the username).
// NOTE(review): because id and username share one key namespace, a username
// consisting only of digits that equals another user's id would map to the
// same cache key as that user's id lookup. Confirm usernames can never collide
// with ids, or give the two lookups distinct key prefixes.
auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
// Same query as `get_user_by_username`, but declared without a cache-key
// template, presumably so it always reads the database. Confirm against `auto_method!`.
auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
52
53 pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
58 let conn = match self.0.connect().await {
59 Ok(c) => c,
60 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
61 };
62
63 let res = query_row!(
64 &conn,
65 "SELECT * FROM a_users WHERE id = $1",
66 &[&(id as i64)],
67 |x| Ok(Self::get_user_from_row(x))
68 );
69
70 if res.is_err() {
71 return Ok(User::deleted());
72 }
74
75 Ok(res.unwrap())
76 }
77
78 pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
83 let conn = match self.0.connect().await {
84 Ok(c) => c,
85 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
86 };
87
88 let res = query_row!(
89 &conn,
90 "SELECT * FROM a_users WHERE tokens LIKE $1",
91 &[&format!("%\"{token}\"%")],
92 |x| Ok(Self::get_user_from_row(x))
93 );
94
95 if res.is_err() {
96 return Err(Error::UserNotFound);
97 }
98
99 Ok(res.unwrap())
100 }
101
102 pub async fn create_user(&self, mut data: User) -> Result<User> {
107 if !self.0.0.security.registration_enabled {
108 return Err(Error::RegistrationDisabled);
109 }
110
111 data.username = data.username.to_lowercase();
112
113 if data.username.len() < 2 {
115 return Err(Error::DataTooShort("username".to_string()));
116 } else if data.username.len() > 32 {
117 return Err(Error::DataTooLong("username".to_string()));
118 }
119
120 if data.password.len() < 6 {
121 return Err(Error::DataTooShort("password".to_string()));
122 }
123
124 if self.0.0.banned_usernames.contains(&data.username) {
125 return Err(Error::MiscError("This username cannot be used".to_string()));
126 }
127
128 if self.get_user_by_username(&data.username).await.is_ok() {
130 return Err(Error::UsernameInUse);
131 }
132
133 let conn = match self.0.connect().await {
135 Ok(c) => c,
136 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
137 };
138
139 let res = execute!(
140 &conn,
141 "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25)",
142 params![
143 &(data.id as i64),
144 &(data.created as i64),
145 &data.username.to_lowercase(),
146 &data.password,
147 &data.salt,
148 &serde_json::to_string(&data.settings).unwrap(),
149 &serde_json::to_string(&data.tokens).unwrap(),
150 &serde_json::to_string(&data.permissions).unwrap(),
151 &(data.is_verified as i32),
152 &0_i32,
153 &String::new(),
154 "[]",
155 &data.stripe_id,
156 &data.ban_reason,
157 &(data.ban_expire as i64),
158 &(data.is_deactivated as i32),
159 &serde_json::to_string(&data.checkouts).unwrap(),
160 &(data.last_policy_consent as i64),
161 &serde_json::to_string(&data.linked_accounts).unwrap(),
162 &serde_json::to_string(&data.badges).unwrap(),
163 &if data.principal_org != 0 {
164 Some(data.principal_org as i64)
165 } else {
166 None
167 },
168 &((data.principal_org > 0) as i32),
169 &data.org_creation_credits,
170 &data.org_user_register_credits,
171 &data.is_super_verified
172 ]
173 );
174
175 if let Err(e) = res {
176 return Err(Error::DatabaseError(e.to_string()));
177 }
178
179 Ok(data)
180 }
181
182 pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
189 let user = self.get_user_by_id(id).await?;
190
191 if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
192 return Err(Error::IncorrectPassword);
193 }
194
195 let conn = match self.0.connect().await {
196 Ok(c) => c,
197 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
198 };
199
200 let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
201
202 if let Err(e) = res {
203 return Err(Error::DatabaseError(e.to_string()));
204 }
205
206 self.cache_clear_user(&user).await;
207
208 for upload in match self.1.get_uploads_by_owner_all(user.id).await {
210 Ok(x) => x,
211 Err(e) => return Err(Error::MiscError(e.to_string())),
212 } {
213 if let Err(e) = self.1.delete_upload(upload.id).await {
214 return Err(Error::MiscError(e.to_string()));
215 }
216 }
217
218 Ok(user)
220 }
221
222 pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
223 if !user.permissions.contains(&UserPermission::ManageVerified) {
224 return Err(Error::NotAllowed);
225 }
226
227 let other_user = self.get_user_by_id(id).await?;
228
229 let conn = match self.0.connect().await {
230 Ok(c) => c,
231 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
232 };
233
234 let res = execute!(
235 &conn,
236 "UPDATE a_users SET is_verified = $1 WHERE id = $2",
237 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
238 );
239
240 if let Err(e) = res {
241 return Err(Error::DatabaseError(e.to_string()));
242 }
243
244 self.cache_clear_user(&other_user).await;
245
246 self.create_audit_log_entry(AuditLogEntry::new(
248 user.id,
249 format!(
250 "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
251 other_user.id, x
252 ),
253 ))
254 .await?;
255
256 Ok(())
258 }
259
260 pub async fn update_user_super_verified_status(&self, id: usize, x: bool) -> Result<()> {
263 let other_user = self.get_user_by_id(id).await?;
264
265 let conn = match self.0.connect().await {
266 Ok(c) => c,
267 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
268 };
269
270 let res = execute!(
271 &conn,
272 "UPDATE a_users SET is_super_verified = $1 WHERE id = $2",
273 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
274 );
275
276 if let Err(e) = res {
277 return Err(Error::DatabaseError(e.to_string()));
278 }
279
280 self.cache_clear_user(&other_user).await;
281
282 self.create_audit_log_entry(AuditLogEntry::new(
284 other_user.id,
285 format!(
286 "invoked `update_user_super_verified_status` with x value `{}` and y value `{}`",
287 other_user.id, x
288 ),
289 ))
290 .await?;
291
292 Ok(())
294 }
295
296 pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
297 if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
298 return Err(Error::NotAllowed);
299 }
300
301 let other_user = self.get_user_by_id(id).await?;
302
303 let conn = match self.0.connect().await {
304 Ok(c) => c,
305 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
306 };
307
308 let res = execute!(
309 &conn,
310 "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
311 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
312 );
313
314 if let Err(e) = res {
315 return Err(Error::DatabaseError(e.to_string()));
316 }
317
318 self.cache_clear_user(&other_user).await;
319
320 if user.id != other_user.id {
322 self.create_audit_log_entry(AuditLogEntry::new(
323 user.id,
324 format!(
325 "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
326 other_user.id, x
327 ),
328 ))
329 .await?;
330 }
331
332 Ok(())
334 }
335
336 pub async fn update_user_password(
337 &self,
338 id: usize,
339 from: String,
340 to: String,
341 user: User,
342 force: bool,
343 ) -> Result<()> {
344 if !user.check_password(from.clone()) && !force {
346 return Err(Error::MiscError("Password does not match".to_string()));
347 }
348
349 let conn = match self.0.connect().await {
351 Ok(c) => c,
352 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
353 };
354
355 let new_salt = salt();
356 let new_password = hash_salted(to, new_salt.clone());
357 let res = execute!(
358 &conn,
359 "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
360 params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
361 );
362
363 if let Err(e) = res {
364 return Err(Error::DatabaseError(e.to_string()));
365 }
366
367 self.cache_clear_user(&user).await;
368 Ok(())
369 }
370
371 pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
372 if to.len() < 2 {
374 return Err(Error::DataTooShort("username".to_string()));
375 } else if to.len() > 32 {
376 return Err(Error::DataTooLong("username".to_string()));
377 }
378
379 if self.0.0.banned_usernames.contains(&to) {
380 return Err(Error::MiscError("This username cannot be used".to_string()));
381 }
382
383 let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
384 .multi_line(true)
385 .build()
386 .unwrap();
387
388 if regex.captures(&to).is_some() {
389 return Err(Error::MiscError(
390 "This username contains invalid characters".to_string(),
391 ));
392 }
393
394 let conn = match self.0.connect().await {
396 Ok(c) => c,
397 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
398 };
399
400 let res = execute!(
401 &conn,
402 "UPDATE a_users SET username = $1 WHERE id = $2",
403 params![&to.to_lowercase(), &(id as i64)]
404 );
405
406 if let Err(e) = res {
407 return Err(Error::DatabaseError(e.to_string()));
408 }
409
410 self.cache_clear_user(&user).await;
411 Ok(())
412 }
413
414 pub fn check_totp(&self, ua: &User, code: &str) -> bool {
416 let totp = ua.totp(Some(
417 self.0
418 .0
419 .host
420 .replace("http://", "")
421 .replace("https://", "")
422 .replace(":", "_"),
423 ));
424
425 if let Some(totp) = totp {
426 return !code.is_empty()
427 && (totp.check_current(code).unwrap()
428 | ua.recovery_codes.contains(&code.to_string()));
429 }
430
431 true
432 }
433
434 pub fn generate_totp_recovery_codes() -> Vec<String> {
436 let mut out: Vec<String> = Vec::new();
437
438 for _ in 0..9 {
439 out.push(salt())
440 }
441
442 out
443 }
444
445 pub async fn update_user_totp(
452 &self,
453 id: usize,
454 secret: &str,
455 recovery: &Vec<String>,
456 ) -> Result<()> {
457 let user = self.get_user_by_id(id).await?;
458
459 let conn = match self.0.connect().await {
461 Ok(c) => c,
462 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
463 };
464
465 let res = execute!(
466 &conn,
467 "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
468 params![
469 &secret,
470 &serde_json::to_string(recovery).unwrap(),
471 &(id as i64)
472 ]
473 );
474
475 if let Err(e) = res {
476 return Err(Error::DatabaseError(e.to_string()));
477 }
478
479 self.cache_clear_user(&user).await;
480 Ok(())
481 }
482
483 pub async fn enable_totp(
492 &self,
493 id: usize,
494 user: User,
495 ) -> Result<(String, String, Vec<String>)> {
496 let other_user = self.get_user_by_id(id).await?;
497
498 if other_user.id != user.id {
499 if other_user
500 .permissions
501 .contains(&UserPermission::ManageUsers)
502 {
503 self.create_audit_log_entry(AuditLogEntry::new(
505 user.id,
506 format!("invoked `enable_totp` with x value `{}`", other_user.id,),
507 ))
508 .await?;
509 } else {
510 return Err(Error::NotAllowed);
511 }
512 }
513
514 let secret = totp_rs::Secret::default().to_string();
515 let recovery = Self::generate_totp_recovery_codes();
516 self.update_user_totp(id, &secret, &recovery).await?;
517
518 let other_user = self.get_user_by_id(id).await?;
520
521 let totp = other_user.totp(Some(
523 self.0
524 .0
525 .host
526 .replace("http://", "")
527 .replace("https://", "")
528 .replace(":", "_"),
529 ));
530
531 if totp.is_none() {
532 return Err(Error::MiscError("Failed to get TOTP code".to_string()));
533 }
534
535 let totp = totp.unwrap();
536
537 let qr = match totp.get_qr_base64() {
539 Ok(q) => q,
540 Err(e) => return Err(Error::MiscError(e.to_string())),
541 };
542
543 Ok((totp.get_secret_base32(), qr, recovery))
545 }
546
547 pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
549 if user.principal_org == 0 {
550 return None;
551 }
552
553 if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
554 Some(x)
555 } else {
556 self.update_user_principal_org(user.id, None)
557 .await
558 .expect("failed to clear user principal org");
559
560 None
561 }
562 }
563
564 pub async fn cache_clear_user(&self, user: &User) {
565 self.0.1.remove(format!("srmp.user:{}", user.id)).await;
566 self.0
567 .1
568 .remove(format!("srmp.user:{}", user.username))
569 .await;
570 }
571
// Single-column setters generated by `auto_method!`. Each runs one `UPDATE`
// on `a_users` keyed by id and names `cache_clear_user` as its cache hook
// (presumably called after the write so stale cached users are dropped;
// confirm against `auto_method!`). Columns marked `--serde` hold JSON text,
// matching how `get_user_from_row` decodes them with `serde_json::from_str`.
auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
// `None` writes SQL NULL. `get_user_from_row` reads this column with a default
// of 0, which `get_principal_org` treats as "no principal organization".
// Note: this setter does not touch `org_as_tenant`.
auto_method!(update_user_principal_org(Option<i64>)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);

// Stripe customer lookup (declared without a cache-key template) and setter.
auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);

// Counter columns: an absolute setter plus in-database +1 / -1 updates.
// NOTE(review): `--decr=<column>` presumably lets the macro guard the
// decrement (e.g. against going below zero); confirm against `auto_method!`,
// since nothing in this SQL prevents a negative value.
auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);

// Remaining credits for creating organizations.
auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);

// Remaining credits for registering users under an organization.
auto_method!(incr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
auto_method!(decr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
596}