1use crate::{
2 DataManager,
3 model::{
4 AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5 UserSettings, organizations::Organization,
6 },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core2::{auto_method, model::id::Id};
11use tritools::encoding::{hash_salted, salt};
12
13#[inline]
14fn transform_uid(uid: &Id) -> usize {
15 uid.as_usize() / 2
16}
17
18impl DataManager {
19 pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
21 User {
22 id: Id::Legacy(get!(x->0(i64)) as usize),
23 created: get!(x->1(i64)) as u128,
24 username: get!(x->2(String)),
25 password: get!(x->3(String)),
26 salt: get!(x->4(String)),
27 settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
28 tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
29 permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
30 is_verified: get!(x->8(i32)) as i8 == 1,
31 notification_count: {
32 let x = get!(x->9(i32)) as usize;
33 if x > usize::MAX - 1000 { 0 } else { x }
35 },
36 totp: get!(x->10(String)),
37 recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
38 stripe_id: get!(x->12(String)),
39 ban_reason: get!(x->13(String)),
40 ban_expire: get!(x->14(i64)) as usize,
41 is_deactivated: get!(x->15(i32)) as i8 == 1,
42 checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
43 last_policy_consent: get!(x->17(i64)) as u128,
44 linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
45 badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
46 principal_org: Id::deserialize(
47 &get!(x->20(String) default (Id::default().printable())),
48 ),
49 org_as_tenant: get!(x->21(i32)) as i8 == 1,
50 org_creation_credits: get!(x->22(i32)),
51 org_user_register_credits: get!(x->23(i32)),
52 is_super_verified: get!(x->24(i32)) as i8 == 1,
53 }
54 }
55
56 auto_method!(get_user_by_id(Id :: usize)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1::bigint" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
57 auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
58 auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
59
60 pub async fn get_user_by_id_with_void(&self, id: &Id) -> Result<User> {
65 let conn = match self.0.connect().await {
66 Ok(c) => c,
67 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
68 };
69
70 let res = query_row!(
71 &conn,
72 "SELECT * FROM a_users WHERE id = $1::bigint",
73 &[&id.printable()],
74 |x| Ok(Self::get_user_from_row(x))
75 );
76
77 if res.is_err() {
78 return Ok(User::deleted());
79 }
81
82 Ok(res.unwrap())
83 }
84
85 pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
90 let conn = match self.0.connect().await {
91 Ok(c) => c,
92 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
93 };
94
95 let res = query_row!(
96 &conn,
97 "SELECT * FROM a_users WHERE tokens LIKE $1",
98 &[&format!("%\"{token}\"%")],
99 |x| Ok(Self::get_user_from_row(x))
100 );
101
102 if res.is_err() {
103 return Err(Error::UserNotFound);
104 }
105
106 Ok(res.unwrap())
107 }
108
109 pub async fn create_user(&self, mut data: User) -> Result<User> {
114 if !self.0.0.security.registration_enabled {
115 return Err(Error::RegistrationDisabled);
116 }
117
118 data.username = data.username.to_lowercase();
119
120 if data.username.len() < 2 {
122 return Err(Error::DataTooShort("username".to_string()));
123 } else if data.username.len() > 32 {
124 return Err(Error::DataTooLong("username".to_string()));
125 }
126
127 if data.password.len() < 6 {
128 return Err(Error::DataTooShort("password".to_string()));
129 }
130
131 if self.0.0.banned_usernames.contains(&data.username) {
132 return Err(Error::MiscError("This username cannot be used".to_string()));
133 }
134
135 if self.get_user_by_username(&data.username).await.is_ok() {
137 return Err(Error::UsernameInUse);
138 }
139
140 let conn = match self.0.connect().await {
142 Ok(c) => c,
143 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
144 };
145
146 let res = execute!(
147 &conn,
148 "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)",
149 params![
150 &(transform_uid(&data.id) as i64),
151 &(data.created as i64),
152 &data.username.to_lowercase(),
153 &data.password,
154 &data.salt,
155 &serde_json::to_string(&data.settings).unwrap(),
156 &serde_json::to_string(&data.tokens).unwrap(),
157 &serde_json::to_string(&data.permissions).unwrap(),
158 &(data.is_verified as i32),
159 &0_i32,
160 &String::new(),
161 "[]",
162 &data.stripe_id,
163 &data.ban_reason,
164 &(data.ban_expire as i64),
165 &(data.is_deactivated as i32),
166 &serde_json::to_string(&data.checkouts).unwrap(),
167 &(data.last_policy_consent as i64),
168 &serde_json::to_string(&data.linked_accounts).unwrap(),
169 &serde_json::to_string(&data.badges).unwrap(),
170 &if data.principal_org != Id::default() {
171 Some(data.principal_org.printable())
172 } else {
173 None
174 },
175 &data.org_creation_credits,
176 &data.org_user_register_credits,
177 &(data.is_super_verified as i32),
178 ]
179 );
180
181 if let Err(e) = res {
182 return Err(Error::DatabaseError(e.to_string()));
183 }
184
185 Ok(data)
186 }
187
188 pub async fn delete_user(&self, id: &Id, password: &str, force: bool) -> Result<User> {
195 let user = self.get_user_by_id(&id).await?;
196
197 if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
198 return Err(Error::IncorrectPassword);
199 }
200
201 let conn = match self.0.connect().await {
202 Ok(c) => c,
203 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
204 };
205
206 let res = execute!(
207 &conn,
208 "DELETE FROM a_users WHERE id = $1::bigint",
209 &[&id.printable()]
210 );
211
212 if let Err(e) = res {
213 return Err(Error::DatabaseError(e.to_string()));
214 }
215
216 self.cache_clear_user(&user).await;
217
218 for upload in match self.1.get_uploads_by_owner_all(user.id.as_usize()).await {
220 Ok(x) => x,
221 Err(e) => return Err(Error::MiscError(e.to_string())),
222 } {
223 if let Err(e) = self.1.delete_upload(upload.id).await {
224 return Err(Error::MiscError(e.to_string()));
225 }
226 }
227
228 Ok(user)
230 }
231
232 pub async fn update_user_verified_status(&self, id: &Id, x: bool, user: User) -> Result<()> {
233 if !user.permissions.contains(&UserPermission::ManageVerified) {
234 return Err(Error::NotAllowed);
235 }
236
237 let other_user = self.get_user_by_id(&id).await?;
238
239 let conn = match self.0.connect().await {
240 Ok(c) => c,
241 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
242 };
243
244 let res = execute!(
245 &conn,
246 "UPDATE a_users SET is_verified = $1 WHERE id = $2::bigint",
247 params![&{ if x { 1 } else { 0 } }, &id.printable()]
248 );
249
250 if let Err(e) = res {
251 return Err(Error::DatabaseError(e.to_string()));
252 }
253
254 self.cache_clear_user(&other_user).await;
255
256 self.create_audit_log_entry(AuditLogEntry::new(
258 user.id,
259 format!(
260 "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
261 other_user.id, x
262 ),
263 ))
264 .await?;
265
266 Ok(())
268 }
269
270 pub async fn update_user_super_verified_status(&self, id: &Id, x: bool) -> Result<()> {
273 let other_user = self.get_user_by_id(&id).await?;
274
275 let conn = match self.0.connect().await {
276 Ok(c) => c,
277 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
278 };
279
280 let res = execute!(
281 &conn,
282 "UPDATE a_users SET is_super_verified = $1 WHERE id = $2::bigint",
283 params![&{ if x { 1 } else { 0 } }, &id.printable()]
284 );
285
286 if let Err(e) = res {
287 return Err(Error::DatabaseError(e.to_string()));
288 }
289
290 self.cache_clear_user(&other_user).await;
291
292 self.create_audit_log_entry(AuditLogEntry::new(
294 other_user.id.clone(),
295 format!(
296 "invoked `update_user_super_verified_status` with x value `{}` and y value `{}`",
297 other_user.id.printable(),
298 x
299 ),
300 ))
301 .await?;
302
303 Ok(())
305 }
306
307 pub async fn update_user_is_deactivated(&self, id: &Id, x: bool, user: User) -> Result<()> {
308 if id != &user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
309 return Err(Error::NotAllowed);
310 }
311
312 let other_user = self.get_user_by_id(&id).await?;
313
314 let conn = match self.0.connect().await {
315 Ok(c) => c,
316 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
317 };
318
319 let res = execute!(
320 &conn,
321 "UPDATE a_users SET is_deactivated = $1 WHERE id = $2::bigint",
322 params![&{ if x { 1 } else { 0 } }, &id.printable()]
323 );
324
325 if let Err(e) = res {
326 return Err(Error::DatabaseError(e.to_string()));
327 }
328
329 self.cache_clear_user(&other_user).await;
330
331 if user.id != other_user.id {
333 self.create_audit_log_entry(AuditLogEntry::new(
334 user.id,
335 format!(
336 "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
337 other_user.id, x
338 ),
339 ))
340 .await?;
341 }
342
343 Ok(())
345 }
346
347 pub async fn update_user_password(
348 &self,
349 id: &Id,
350 from: String,
351 to: String,
352 user: User,
353 force: bool,
354 ) -> Result<()> {
355 if !user.check_password(from.clone()) && !force {
357 return Err(Error::MiscError("Password does not match".to_string()));
358 }
359
360 let conn = match self.0.connect().await {
362 Ok(c) => c,
363 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
364 };
365
366 let new_salt = salt();
367 let new_password = hash_salted(to, new_salt.clone());
368 let res = execute!(
369 &conn,
370 "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
371 params![&new_password.as_str(), &new_salt.as_str(), &id.printable()]
372 );
373
374 if let Err(e) = res {
375 return Err(Error::DatabaseError(e.to_string()));
376 }
377
378 self.cache_clear_user(&user).await;
379 Ok(())
380 }
381
382 pub async fn update_user_username(&self, id: &Id, to: String, user: User) -> Result<()> {
383 if to.len() < 2 {
385 return Err(Error::DataTooShort("username".to_string()));
386 } else if to.len() > 32 {
387 return Err(Error::DataTooLong("username".to_string()));
388 }
389
390 if self.0.0.banned_usernames.contains(&to) {
391 return Err(Error::MiscError("This username cannot be used".to_string()));
392 }
393
394 let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
395 .multi_line(true)
396 .build()
397 .unwrap();
398
399 if regex.captures(&to).is_some() {
400 return Err(Error::MiscError(
401 "This username contains invalid characters".to_string(),
402 ));
403 }
404
405 let conn = match self.0.connect().await {
407 Ok(c) => c,
408 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
409 };
410
411 let res = execute!(
412 &conn,
413 "UPDATE a_users SET username = $1 WHERE id = $2::bigint",
414 params![&to.to_lowercase(), &id.printable()]
415 );
416
417 if let Err(e) = res {
418 return Err(Error::DatabaseError(e.to_string()));
419 }
420
421 self.cache_clear_user(&user).await;
422 Ok(())
423 }
424
425 pub fn check_totp(&self, ua: &User, code: &str) -> bool {
427 let totp = ua.totp(Some(
428 self.0
429 .0
430 .host
431 .replace("http://", "")
432 .replace("https://", "")
433 .replace(":", "_"),
434 ));
435
436 if let Some(totp) = totp {
437 return !code.is_empty()
438 && (totp.check_current(code).unwrap()
439 | ua.recovery_codes.contains(&code.to_string()));
440 }
441
442 true
443 }
444
445 pub fn generate_totp_recovery_codes() -> Vec<String> {
447 let mut out: Vec<String> = Vec::new();
448
449 for _ in 0..9 {
450 out.push(salt())
451 }
452
453 out
454 }
455
456 pub async fn update_user_totp(
463 &self,
464 id: &Id,
465 secret: &str,
466 recovery: &Vec<String>,
467 ) -> Result<()> {
468 let user = self.get_user_by_id(&id).await?;
469
470 let conn = match self.0.connect().await {
472 Ok(c) => c,
473 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
474 };
475
476 let res = execute!(
477 &conn,
478 "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
479 params![
480 &secret,
481 &serde_json::to_string(recovery).unwrap(),
482 &id.printable()
483 ]
484 );
485
486 if let Err(e) = res {
487 return Err(Error::DatabaseError(e.to_string()));
488 }
489
490 self.cache_clear_user(&user).await;
491 Ok(())
492 }
493
494 pub async fn enable_totp(&self, id: &Id, user: User) -> Result<(String, String, Vec<String>)> {
503 let other_user = self.get_user_by_id(&id).await?;
504
505 if other_user.id != user.id {
506 if other_user
507 .permissions
508 .contains(&UserPermission::ManageUsers)
509 {
510 self.create_audit_log_entry(AuditLogEntry::new(
512 user.id,
513 format!("invoked `enable_totp` with x value `{}`", other_user.id,),
514 ))
515 .await?;
516 } else {
517 return Err(Error::NotAllowed);
518 }
519 }
520
521 let secret = totp_rs::Secret::default().to_string();
522 let recovery = Self::generate_totp_recovery_codes();
523 self.update_user_totp(id, &secret, &recovery).await?;
524
525 let other_user = self.get_user_by_id(&id).await?;
527
528 let totp = other_user.totp(Some(
530 self.0
531 .0
532 .host
533 .replace("http://", "")
534 .replace("https://", "")
535 .replace(":", "_"),
536 ));
537
538 if totp.is_none() {
539 return Err(Error::MiscError("Failed to get TOTP code".to_string()));
540 }
541
542 let totp = totp.unwrap();
543
544 let qr = match totp.get_qr_base64() {
546 Ok(q) => q,
547 Err(e) => return Err(Error::MiscError(e.to_string())),
548 };
549
550 Ok((totp.get_secret_base32(), qr, recovery))
552 }
553
554 pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
556 if user.principal_org == Id::default() {
557 return None;
558 }
559
560 if let Ok(x) = self.get_organization_by_id(&user.principal_org).await {
561 Some(x)
562 } else {
563 self.update_user_principal_org(&user.id, None)
564 .await
565 .expect("failed to clear user principal org");
566
567 None
568 }
569 }
570
571 pub async fn cache_clear_user(&self, user: &User) {
572 self.0.1.remove(format!("srmp.user:{}", user.id)).await;
573 self.0
574 .1
575 .remove(format!("srmp.user:{}", user.username))
576 .await;
577 }
578
579 auto_method!(update_user_permissions(Id :: usize, Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
580 auto_method!(update_user_tokens(Id :: usize, Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
581 auto_method!(update_user_settings(Id :: usize, UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
582 auto_method!(update_user_ban_reason(Id :: usize, &str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
583 auto_method!(update_user_ban_expire(Id :: usize, i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
584 auto_method!(update_user_checkouts(Id :: usize, Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
585 auto_method!(update_user_last_policy_consent(Id :: usize, i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
586 auto_method!(update_user_linked_accounts(Id :: usize, UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
587 auto_method!(update_user_badges(Id :: usize, Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2::bigint" --serde --cache-key-tmpl=cache_clear_user);
588 auto_method!(update_user_principal_org(Id :: usize, Option<String>)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
589 auto_method!(update_user_org_as_tenant(Id :: usize, i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
590
591 auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1::bigint" --name="user" --returns=User);
592 auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
593
594 auto_method!(update_user_notification_count(Id :: usize, i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2::bigint" --cache-key-tmpl=cache_clear_user);
595 auto_method!(incr_user_notifications(Id :: usize)@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --incr);
596 auto_method!(decr_user_notifications(Id :: usize)@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --decr=notification_count);
597
598 auto_method!(incr_user_org_creation_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --incr);
599 auto_method!(decr_user_org_creation_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
600
601 auto_method!(incr_user_org_user_register_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --incr);
602 auto_method!(decr_user_org_user_register_credits(Id :: usize)@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1::bigint" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
603}