1use crate::{
2 DataManager,
3 model::{
4 AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5 UserSettings, organizations::Organization,
6 },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14 pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16 User {
17 id: get!(x->0(i64)) as usize,
18 created: get!(x->1(i64)) as usize,
19 username: get!(x->2(String)),
20 password: get!(x->3(String)),
21 salt: get!(x->4(String)),
22 settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23 tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24 permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25 is_verified: get!(x->8(i32)) as i8 == 1,
26 notification_count: {
27 let x = get!(x->9(i32)) as usize;
28 if x > usize::MAX - 1000 { 0 } else { x }
30 },
31 totp: get!(x->10(String)),
32 recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33 stripe_id: get!(x->12(String)),
34 ban_reason: get!(x->13(String)),
35 ban_expire: get!(x->14(i64)) as usize,
36 is_deactivated: get!(x->15(i32)) as i8 == 1,
37 checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38 last_policy_consent: get!(x->17(i64)) as usize,
39 linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40 badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41 principal_org: get!(x->20(i64)) as usize,
42 org_as_tenant: get!(x->21(i32)) as i8 == 1,
43 org_creation_credits: get!(x->22(i32)),
44 org_user_register_credits: get!(x->23(i32)),
45 }
46 }
47
48 auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
49 auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
50 auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
51
52 pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
57 let conn = match self.0.connect().await {
58 Ok(c) => c,
59 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
60 };
61
62 let res = query_row!(
63 &conn,
64 "SELECT * FROM a_users WHERE id = $1",
65 &[&(id as i64)],
66 |x| Ok(Self::get_user_from_row(x))
67 );
68
69 if res.is_err() {
70 return Ok(User::deleted());
71 }
73
74 Ok(res.unwrap())
75 }
76
77 pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
82 let conn = match self.0.connect().await {
83 Ok(c) => c,
84 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
85 };
86
87 let res = query_row!(
88 &conn,
89 "SELECT * FROM a_users WHERE tokens LIKE $1",
90 &[&format!("%\"{token}\"%")],
91 |x| Ok(Self::get_user_from_row(x))
92 );
93
94 if res.is_err() {
95 return Err(Error::UserNotFound);
96 }
97
98 Ok(res.unwrap())
99 }
100
101 pub async fn create_user(&self, mut data: User) -> Result<User> {
106 if !self.0.0.security.registration_enabled {
107 return Err(Error::RegistrationDisabled);
108 }
109
110 data.username = data.username.to_lowercase();
111
112 if data.username.len() < 2 {
114 return Err(Error::DataTooShort("username".to_string()));
115 } else if data.username.len() > 32 {
116 return Err(Error::DataTooLong("username".to_string()));
117 }
118
119 if data.password.len() < 6 {
120 return Err(Error::DataTooShort("password".to_string()));
121 }
122
123 if self.0.0.banned_usernames.contains(&data.username) {
124 return Err(Error::MiscError("This username cannot be used".to_string()));
125 }
126
127 if self.get_user_by_username(&data.username).await.is_ok() {
129 return Err(Error::UsernameInUse);
130 }
131
132 let conn = match self.0.connect().await {
134 Ok(c) => c,
135 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
136 };
137
138 let res = execute!(
139 &conn,
140 "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)",
141 params![
142 &(data.id as i64),
143 &(data.created as i64),
144 &data.username.to_lowercase(),
145 &data.password,
146 &data.salt,
147 &serde_json::to_string(&data.settings).unwrap(),
148 &serde_json::to_string(&data.tokens).unwrap(),
149 &serde_json::to_string(&data.permissions).unwrap(),
150 &(data.is_verified as i32),
151 &0_i32,
152 &String::new(),
153 "[]",
154 &data.stripe_id,
155 &data.ban_reason,
156 &(data.ban_expire as i64),
157 &(data.is_deactivated as i32),
158 &serde_json::to_string(&data.checkouts).unwrap(),
159 &(data.last_policy_consent as i64),
160 &serde_json::to_string(&data.linked_accounts).unwrap(),
161 &serde_json::to_string(&data.badges).unwrap(),
162 &(data.principal_org as i64),
163 &((data.principal_org > 0) as i32),
164 &data.org_creation_credits,
165 &data.org_user_register_credits,
166 ]
167 );
168
169 if let Err(e) = res {
170 return Err(Error::DatabaseError(e.to_string()));
171 }
172
173 Ok(data)
174 }
175
176 pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
183 let user = self.get_user_by_id(id).await?;
184
185 if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
186 return Err(Error::IncorrectPassword);
187 }
188
189 let conn = match self.0.connect().await {
190 Ok(c) => c,
191 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
192 };
193
194 let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
195
196 if let Err(e) = res {
197 return Err(Error::DatabaseError(e.to_string()));
198 }
199
200 self.cache_clear_user(&user).await;
201
202 let res = execute!(
204 &conn,
205 "DELETE FROM a_notifications WHERE owner = $1",
206 &[&(id as i64)]
207 );
208
209 if let Err(e) = res {
210 return Err(Error::DatabaseError(e.to_string()));
211 }
212
213 let res = execute!(
215 &conn,
216 "DELETE FROM a_user_warnings WHERE receiver = $1",
217 &[&(id as i64)]
218 );
219
220 if let Err(e) = res {
221 return Err(Error::DatabaseError(e.to_string()));
222 }
223
224 for upload in match self.1.get_uploads_by_owner_all(user.id).await {
226 Ok(x) => x,
227 Err(e) => return Err(Error::MiscError(e.to_string())),
228 } {
229 if let Err(e) = self.1.delete_upload(upload.id).await {
230 return Err(Error::MiscError(e.to_string()));
231 }
232 }
233
234 Ok(user)
236 }
237
238 pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
239 if !user.permissions.contains(&UserPermission::ManageVerified) {
240 return Err(Error::NotAllowed);
241 }
242
243 let other_user = self.get_user_by_id(id).await?;
244
245 let conn = match self.0.connect().await {
246 Ok(c) => c,
247 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
248 };
249
250 let res = execute!(
251 &conn,
252 "UPDATE a_users SET verified = $1 WHERE id = $2",
253 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
254 );
255
256 if let Err(e) = res {
257 return Err(Error::DatabaseError(e.to_string()));
258 }
259
260 self.cache_clear_user(&other_user).await;
261
262 self.create_audit_log_entry(AuditLogEntry::new(
264 user.id,
265 format!(
266 "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
267 other_user.id, x
268 ),
269 ))
270 .await?;
271
272 Ok(())
274 }
275
276 pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
277 if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
278 return Err(Error::NotAllowed);
279 }
280
281 let other_user = self.get_user_by_id(id).await?;
282
283 let conn = match self.0.connect().await {
284 Ok(c) => c,
285 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
286 };
287
288 let res = execute!(
289 &conn,
290 "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
291 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
292 );
293
294 if let Err(e) = res {
295 return Err(Error::DatabaseError(e.to_string()));
296 }
297
298 self.cache_clear_user(&other_user).await;
299
300 if user.id != other_user.id {
302 self.create_audit_log_entry(AuditLogEntry::new(
303 user.id,
304 format!(
305 "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
306 other_user.id, x
307 ),
308 ))
309 .await?;
310 }
311
312 Ok(())
314 }
315
316 pub async fn update_user_password(
317 &self,
318 id: usize,
319 from: String,
320 to: String,
321 user: User,
322 force: bool,
323 ) -> Result<()> {
324 if !user.check_password(from.clone()) && !force {
326 return Err(Error::MiscError("Password does not match".to_string()));
327 }
328
329 let conn = match self.0.connect().await {
331 Ok(c) => c,
332 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
333 };
334
335 let new_salt = salt();
336 let new_password = hash_salted(to, new_salt.clone());
337 let res = execute!(
338 &conn,
339 "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
340 params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
341 );
342
343 if let Err(e) = res {
344 return Err(Error::DatabaseError(e.to_string()));
345 }
346
347 self.cache_clear_user(&user).await;
348 Ok(())
349 }
350
351 pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
352 if to.len() < 2 {
354 return Err(Error::DataTooShort("username".to_string()));
355 } else if to.len() > 32 {
356 return Err(Error::DataTooLong("username".to_string()));
357 }
358
359 if self.0.0.banned_usernames.contains(&to) {
360 return Err(Error::MiscError("This username cannot be used".to_string()));
361 }
362
363 let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
364 .multi_line(true)
365 .build()
366 .unwrap();
367
368 if regex.captures(&to).is_some() {
369 return Err(Error::MiscError(
370 "This username contains invalid characters".to_string(),
371 ));
372 }
373
374 let conn = match self.0.connect().await {
376 Ok(c) => c,
377 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
378 };
379
380 let res = execute!(
381 &conn,
382 "UPDATE a_users SET username = $1 WHERE id = $2",
383 params![&to.to_lowercase(), &(id as i64)]
384 );
385
386 if let Err(e) = res {
387 return Err(Error::DatabaseError(e.to_string()));
388 }
389
390 self.cache_clear_user(&user).await;
391 Ok(())
392 }
393
394 pub fn check_totp(&self, ua: &User, code: &str) -> bool {
396 let totp = ua.totp(Some(
397 self.0
398 .0
399 .host
400 .replace("http://", "")
401 .replace("https://", "")
402 .replace(":", "_"),
403 ));
404
405 if let Some(totp) = totp {
406 return !code.is_empty()
407 && (totp.check_current(code).unwrap()
408 | ua.recovery_codes.contains(&code.to_string()));
409 }
410
411 true
412 }
413
414 pub fn generate_totp_recovery_codes() -> Vec<String> {
416 let mut out: Vec<String> = Vec::new();
417
418 for _ in 0..9 {
419 out.push(salt())
420 }
421
422 out
423 }
424
425 pub async fn update_user_totp(
432 &self,
433 id: usize,
434 secret: &str,
435 recovery: &Vec<String>,
436 ) -> Result<()> {
437 let user = self.get_user_by_id(id).await?;
438
439 let conn = match self.0.connect().await {
441 Ok(c) => c,
442 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
443 };
444
445 let res = execute!(
446 &conn,
447 "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
448 params![
449 &secret,
450 &serde_json::to_string(recovery).unwrap(),
451 &(id as i64)
452 ]
453 );
454
455 if let Err(e) = res {
456 return Err(Error::DatabaseError(e.to_string()));
457 }
458
459 self.cache_clear_user(&user).await;
460 Ok(())
461 }
462
463 pub async fn enable_totp(
472 &self,
473 id: usize,
474 user: User,
475 ) -> Result<(String, String, Vec<String>)> {
476 let other_user = self.get_user_by_id(id).await?;
477
478 if other_user.id != user.id {
479 if other_user
480 .permissions
481 .contains(&UserPermission::ManageUsers)
482 {
483 self.create_audit_log_entry(AuditLogEntry::new(
485 user.id,
486 format!("invoked `enable_totp` with x value `{}`", other_user.id,),
487 ))
488 .await?;
489 } else {
490 return Err(Error::NotAllowed);
491 }
492 }
493
494 let secret = totp_rs::Secret::default().to_string();
495 let recovery = Self::generate_totp_recovery_codes();
496 self.update_user_totp(id, &secret, &recovery).await?;
497
498 let other_user = self.get_user_by_id(id).await?;
500
501 let totp = other_user.totp(Some(
503 self.0
504 .0
505 .host
506 .replace("http://", "")
507 .replace("https://", "")
508 .replace(":", "_"),
509 ));
510
511 if totp.is_none() {
512 return Err(Error::MiscError("Failed to get TOTP code".to_string()));
513 }
514
515 let totp = totp.unwrap();
516
517 let qr = match totp.get_qr_base64() {
519 Ok(q) => q,
520 Err(e) => return Err(Error::MiscError(e.to_string())),
521 };
522
523 Ok((totp.get_secret_base32(), qr, recovery))
525 }
526
527 pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
529 if user.principal_org == 0 {
530 return None;
531 }
532
533 if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
534 Some(x)
535 } else {
536 self.update_user_principal_org(user.id, 0)
537 .await
538 .expect("failed to clear user principal org");
539
540 None
541 }
542 }
543
544 pub async fn cache_clear_user(&self, user: &User) {
545 self.0.1.remove(format!("srmp.user:{}", user.id)).await;
546 self.0
547 .1
548 .remove(format!("srmp.user:{}", user.username))
549 .await;
550 }
551
552 auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
553 auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
554 auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
555 auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
556 auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
557 auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
558 auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
559 auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
560 auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
561 auto_method!(update_user_principal_org(i64)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
562 auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
563
564 auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
565 auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
566
567 auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
568 auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
569 auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);
570
571 auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
572 auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
573
574 auto_method!(incr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
575 auto_method!(decr_user_org_user_register_credits()@get_user_by_id -> "UPDATE a_users SET org_user_register_credits = org_user_register_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_user_register_credits);
576}