1use crate::{
2 DataManager,
3 model::{
4 AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5 UserSettings, organizations::Organization,
6 },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14 pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16 User {
17 id: get!(x->0(i64)) as usize,
18 created: get!(x->1(i64)) as usize,
19 username: get!(x->2(String)),
20 password: get!(x->3(String)),
21 salt: get!(x->4(String)),
22 settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23 tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24 permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25 is_verified: get!(x->8(i32)) as i8 == 1,
26 notification_count: {
27 let x = get!(x->9(i32)) as usize;
28 if x > usize::MAX - 1000 { 0 } else { x }
30 },
31 totp: get!(x->10(String)),
32 recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33 stripe_id: get!(x->12(String)),
34 ban_reason: get!(x->13(String)),
35 ban_expire: get!(x->14(i64)) as usize,
36 is_deactivated: get!(x->15(i32)) as i8 == 1,
37 checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38 last_policy_consent: get!(x->17(i64)) as usize,
39 linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40 badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41 principal_org: get!(x->20(i64) default (0)) as usize,
42 org_as_tenant: get!(x->21(i32)) as i8 == 1,
43 org_creation_credits: get!(x->22(i32)),
44 org_user_register_credits: get!(x->23(i32)),
45 }
46 }
47
48 auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
49 auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
50 auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
51
52 pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
57 let conn = match self.0.connect().await {
58 Ok(c) => c,
59 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
60 };
61
62 let res = query_row!(
63 &conn,
64 "SELECT * FROM a_users WHERE id = $1",
65 &[&(id as i64)],
66 |x| Ok(Self::get_user_from_row(x))
67 );
68
69 if res.is_err() {
70 return Ok(User::deleted());
71 }
73
74 Ok(res.unwrap())
75 }
76
77 pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
82 let conn = match self.0.connect().await {
83 Ok(c) => c,
84 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
85 };
86
87 let res = query_row!(
88 &conn,
89 "SELECT * FROM a_users WHERE tokens LIKE $1",
90 &[&format!("%\"{token}\"%")],
91 |x| Ok(Self::get_user_from_row(x))
92 );
93
94 if res.is_err() {
95 return Err(Error::UserNotFound);
96 }
97
98 Ok(res.unwrap())
99 }
100
101 pub async fn create_user(&self, mut data: User) -> Result<User> {
106 if !self.0.0.security.registration_enabled {
107 return Err(Error::RegistrationDisabled);
108 }
109
110 data.username = data.username.to_lowercase();
111
112 if data.username.len() < 2 {
114 return Err(Error::DataTooShort("username".to_string()));
115 } else if data.username.len() > 32 {
116 return Err(Error::DataTooLong("username".to_string()));
117 }
118
119 if data.password.len() < 6 {
120 return Err(Error::DataTooShort("password".to_string()));
121 }
122
123 if self.0.0.banned_usernames.contains(&data.username) {
124 return Err(Error::MiscError("This username cannot be used".to_string()));
125 }
126
127 if self.get_user_by_username(&data.username).await.is_ok() {
129 return Err(Error::UsernameInUse);
130 }
131
132 let conn = match self.0.connect().await {
134 Ok(c) => c,
135 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
136 };
137
138 let res = execute!(
139 &conn,
140 "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)",
141 params![
142 &(data.id as i64),
143 &(data.created as i64),
144 &data.username.to_lowercase(),
145 &data.password,
146 &data.salt,
147 &serde_json::to_string(&data.settings).unwrap(),
148 &serde_json::to_string(&data.tokens).unwrap(),
149 &serde_json::to_string(&data.permissions).unwrap(),
150 &(data.is_verified as i32),
151 &0_i32,
152 &String::new(),
153 "[]",
154 &data.stripe_id,
155 &data.ban_reason,
156 &(data.ban_expire as i64),
157 &(data.is_deactivated as i32),
158 &serde_json::to_string(&data.checkouts).unwrap(),
159 &(data.last_policy_consent as i64),
160 &serde_json::to_string(&data.linked_accounts).unwrap(),
161 &serde_json::to_string(&data.badges).unwrap(),
162 &if data.principal_org != 0 {
163 Some(data.principal_org as i64)
164 } else {
165 None
166 },
167 &((data.principal_org > 0) as i32),
168 &data.org_creation_credits,
169 &data.org_user_register_credits,
170 ]
171 );
172
173 if let Err(e) = res {
174 return Err(Error::DatabaseError(e.to_string()));
175 }
176
177 Ok(data)
178 }
179
180 pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
187 let user = self.get_user_by_id(id).await?;
188
189 if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
190 return Err(Error::IncorrectPassword);
191 }
192
193 let conn = match self.0.connect().await {
194 Ok(c) => c,
195 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
196 };
197
198 let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
199
200 if let Err(e) = res {
201 return Err(Error::DatabaseError(e.to_string()));
202 }
203
204 self.cache_clear_user(&user).await;
205
206 let res = execute!(
208 &conn,
209 "DELETE FROM a_notifications WHERE owner = $1",
210 &[&(id as i64)]
211 );
212
213 if let Err(e) = res {
214 return Err(Error::DatabaseError(e.to_string()));
215 }
216
217 let res = execute!(
219 &conn,
220 "DELETE FROM a_user_warnings WHERE receiver = $1",
221 &[&(id as i64)]
222 );
223
224 if let Err(e) = res {
225 return Err(Error::DatabaseError(e.to_string()));
226 }
227
228 for upload in match self.1.get_uploads_by_owner_all(user.id).await {
230 Ok(x) => x,
231 Err(e) => return Err(Error::MiscError(e.to_string())),
232 } {
233 if let Err(e) = self.1.delete_upload(upload.id).await {
234 return Err(Error::MiscError(e.to_string()));
235 }
236 }
237
238 Ok(user)
240 }
241
242 pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
243 if !user.permissions.contains(&UserPermission::ManageVerified) {
244 return Err(Error::NotAllowed);
245 }
246
247 let other_user = self.get_user_by_id(id).await?;
248
249 let conn = match self.0.connect().await {
250 Ok(c) => c,
251 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
252 };
253
254 let res = execute!(
255 &conn,
256 "UPDATE a_users SET verified = $1 WHERE id = $2",
257 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
258 );
259
260 if let Err(e) = res {
261 return Err(Error::DatabaseError(e.to_string()));
262 }
263
264 self.cache_clear_user(&other_user).await;
265
266 self.create_audit_log_entry(AuditLogEntry::new(
268 user.id,
269 format!(
270 "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
271 other_user.id, x
272 ),
273 ))
274 .await?;
275
276 Ok(())
278 }
279
280 pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
281 if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
282 return Err(Error::NotAllowed);
283 }
284
285 let other_user = self.get_user_by_id(id).await?;
286
287 let conn = match self.0.connect().await {
288 Ok(c) => c,
289 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
290 };
291
292 let res = execute!(
293 &conn,
294 "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
295 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
296 );
297
298 if let Err(e) = res {
299 return Err(Error::DatabaseError(e.to_string()));
300 }
301
302 self.cache_clear_user(&other_user).await;
303
304 if user.id != other_user.id {
306 self.create_audit_log_entry(AuditLogEntry::new(
307 user.id,
308 format!(
309 "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
310 other_user.id, x
311 ),
312 ))
313 .await?;
314 }
315
316 Ok(())
318 }
319
320 pub async fn update_user_password(
321 &self,
322 id: usize,
323 from: String,
324 to: String,
325 user: User,
326 force: bool,
327 ) -> Result<()> {
328 if !user.check_password(from.clone()) && !force {
330 return Err(Error::MiscError("Password does not match".to_string()));
331 }
332
333 let conn = match self.0.connect().await {
335 Ok(c) => c,
336 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
337 };
338
339 let new_salt = salt();
340 let new_password = hash_salted(to, new_salt.clone());
341 let res = execute!(
342 &conn,
343 "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
344 params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
345 );
346
347 if let Err(e) = res {
348 return Err(Error::DatabaseError(e.to_string()));
349 }
350
351 self.cache_clear_user(&user).await;
352 Ok(())
353 }
354
355 pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
356 if to.len() < 2 {
358 return Err(Error::DataTooShort("username".to_string()));
359 } else if to.len() > 32 {
360 return Err(Error::DataTooLong("username".to_string()));
361 }
362
363 if self.0.0.banned_usernames.contains(&to) {
364 return Err(Error::MiscError("This username cannot be used".to_string()));
365 }
366
367 let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
368 .multi_line(true)
369 .build()
370 .unwrap();
371
372 if regex.captures(&to).is_some() {
373 return Err(Error::MiscError(
374 "This username contains invalid characters".to_string(),
375 ));
376 }
377
378 let conn = match self.0.connect().await {
380 Ok(c) => c,
381 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
382 };
383
384 let res = execute!(
385 &conn,
386 "UPDATE a_users SET username = $1 WHERE id = $2",
387 params![&to.to_lowercase(), &(id as i64)]
388 );
389
390 if let Err(e) = res {
391 return Err(Error::DatabaseError(e.to_string()));
392 }
393
394 self.cache_clear_user(&user).await;
395 Ok(())
396 }
397
398 pub fn check_totp(&self, ua: &User, code: &str) -> bool {
400 let totp = ua.totp(Some(
401 self.0
402 .0
403 .host
404 .replace("http://", "")
405 .replace("https://", "")
406 .replace(":", "_"),
407 ));
408
409 if let Some(totp) = totp {
410 return !code.is_empty()
411 && (totp.check_current(code).unwrap()
412 | ua.recovery_codes.contains(&code.to_string()));
413 }
414
415 true
416 }
417
418 pub fn generate_totp_recovery_codes() -> Vec<String> {
420 let mut out: Vec<String> = Vec::new();
421
422 for _ in 0..9 {
423 out.push(salt())
424 }
425
426 out
427 }
428
429 pub async fn update_user_totp(
436 &self,
437 id: usize,
438 secret: &str,
439 recovery: &Vec<String>,
440 ) -> Result<()> {
441 let user = self.get_user_by_id(id).await?;
442
443 let conn = match self.0.connect().await {
445 Ok(c) => c,
446 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
447 };
448
449 let res = execute!(
450 &conn,
451 "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
452 params![
453 &secret,
454 &serde_json::to_string(recovery).unwrap(),
455 &(id as i64)
456 ]
457 );
458
459 if let Err(e) = res {
460 return Err(Error::DatabaseError(e.to_string()));
461 }
462
463 self.cache_clear_user(&user).await;
464 Ok(())
465 }
466
467 pub async fn enable_totp(
476 &self,
477 id: usize,
478 user: User,
479 ) -> Result<(String, String, Vec<String>)> {
480 let other_user = self.get_user_by_id(id).await?;
481
482 if other_user.id != user.id {
483 if other_user
484 .permissions
485 .contains(&UserPermission::ManageUsers)
486 {
487 self.create_audit_log_entry(AuditLogEntry::new(
489 user.id,
490 format!("invoked `enable_totp` with x value `{}`", other_user.id,),
491 ))
492 .await?;
493 } else {
494 return Err(Error::NotAllowed);
495 }
496 }
497
498 let secret = totp_rs::Secret::default().to_string();
499 let recovery = Self::generate_totp_recovery_codes();
500 self.update_user_totp(id, &secret, &recovery).await?;
501
502 let other_user = self.get_user_by_id(id).await?;
504
505 let totp = other_user.totp(Some(
507 self.0
508 .0
509 .host
510 .replace("http://", "")
511 .replace("https://", "")
512 .replace(":", "_"),
513 ));
514
515 if totp.is_none() {
516 return Err(Error::MiscError("Failed to get TOTP code".to_string()));
517 }
518
519 let totp = totp.unwrap();
520
521 let qr = match totp.get_qr_base64() {
523 Ok(q) => q,
524 Err(e) => return Err(Error::MiscError(e.to_string())),
525 };
526
527 Ok((totp.get_secret_base32(), qr, recovery))
529 }
530
531 pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
533 if user.principal_org == 0 {
534 return None;
535 }
536
537 if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
538 Some(x)
539 } else {
540 self.update_user_principal_org(user.id, None)
541 .await
542 .expect("failed to clear user principal org");
543
544 None
545 }
546 }
547
548 pub async fn cache_clear_user(&self, user: &User) {
549 self.0.1.remove(format!("srmp.user:{}", user.id)).await;
550 self.0
551 .1
552 .remove(format!("srmp.user:{}", user.username))
553 .await;
554 }
555
556 auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
557 auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
558 auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
559 auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
560 auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
561 auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
562 auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
563 auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
564 auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
565 auto_method!(update_user_principal_org(Option<i64>)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
566 auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
567
568 auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
569 auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
570
571 auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
572 auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
573 auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);
574
575 auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
576 auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
577
578 auto_method!(incr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
579 auto_method!(decr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
580}