1use crate::{
2 DataManager,
3 model::{
4 AuditLogEntry, Error, Result, Token, User, UserBadge, UserLinkedAccounts, UserPermission,
5 UserSettings, organizations::Organization,
6 },
7};
8use oiseau::cache::Cache;
9use oiseau::{PostgresRow, execute, get, params, query_row};
10use tetratto_core::auto_method;
11use tetratto_shared::hash::{hash_salted, salt};
12
13impl DataManager {
14 pub(crate) fn get_user_from_row(x: &PostgresRow) -> User {
16 User {
17 id: get!(x->0(i64)) as usize,
18 created: get!(x->1(i64)) as usize,
19 username: get!(x->2(String)),
20 password: get!(x->3(String)),
21 salt: get!(x->4(String)),
22 settings: serde_json::from_str(&get!(x->5(String)).to_string()).unwrap(),
23 tokens: serde_json::from_str(&get!(x->6(String)).to_string()).unwrap(),
24 permissions: serde_json::from_str(&get!(x->7(String)).to_string()).unwrap(),
25 is_verified: get!(x->8(i32)) as i8 == 1,
26 notification_count: {
27 let x = get!(x->9(i32)) as usize;
28 if x > usize::MAX - 1000 { 0 } else { x }
30 },
31 totp: get!(x->10(String)),
32 recovery_codes: serde_json::from_str(&get!(x->11(String)).to_string()).unwrap(),
33 stripe_id: get!(x->12(String)),
34 ban_reason: get!(x->13(String)),
35 ban_expire: get!(x->14(i64)) as usize,
36 is_deactivated: get!(x->15(i32)) as i8 == 1,
37 checkouts: serde_json::from_str(&get!(x->16(String)).to_string()).unwrap(),
38 last_policy_consent: get!(x->17(i64)) as usize,
39 linked_accounts: serde_json::from_str(&get!(x->18(String)).to_string()).unwrap(),
40 badges: serde_json::from_str(&get!(x->19(String)).to_string()).unwrap(),
41 principal_org: get!(x->20(i64)) as usize,
42 org_as_tenant: get!(x->21(i32)) as i8 == 1,
43 org_creation_credits: get!(x->22(i32)),
44 }
45 }
46
47 auto_method!(get_user_by_id(usize as i64)@get_user_from_row -> "SELECT * FROM a_users WHERE id = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
48 auto_method!(get_user_by_username(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User --cache-key-tmpl="srmp.user:{}");
49 auto_method!(get_user_by_username_no_cache(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE username = $1" --name="user" --returns=User);
50
51 pub async fn get_user_by_id_with_void(&self, id: usize) -> Result<User> {
56 let conn = match self.0.connect().await {
57 Ok(c) => c,
58 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
59 };
60
61 let res = query_row!(
62 &conn,
63 "SELECT * FROM a_users WHERE id = $1",
64 &[&(id as i64)],
65 |x| Ok(Self::get_user_from_row(x))
66 );
67
68 if res.is_err() {
69 return Ok(User::deleted());
70 }
72
73 Ok(res.unwrap())
74 }
75
76 pub async fn get_user_by_token(&self, token: &str) -> Result<User> {
81 let conn = match self.0.connect().await {
82 Ok(c) => c,
83 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
84 };
85
86 let res = query_row!(
87 &conn,
88 "SELECT * FROM a_users WHERE tokens LIKE $1",
89 &[&format!("%\"{token}\"%")],
90 |x| Ok(Self::get_user_from_row(x))
91 );
92
93 if res.is_err() {
94 return Err(Error::UserNotFound);
95 }
96
97 Ok(res.unwrap())
98 }
99
100 pub async fn create_user(&self, mut data: User) -> Result<User> {
105 if !self.0.0.security.registration_enabled {
106 return Err(Error::RegistrationDisabled);
107 }
108
109 data.username = data.username.to_lowercase();
110
111 if data.username.len() < 2 {
113 return Err(Error::DataTooShort("username".to_string()));
114 } else if data.username.len() > 32 {
115 return Err(Error::DataTooLong("username".to_string()));
116 }
117
118 if data.password.len() < 6 {
119 return Err(Error::DataTooShort("password".to_string()));
120 }
121
122 if self.0.0.banned_usernames.contains(&data.username) {
123 return Err(Error::MiscError("This username cannot be used".to_string()));
124 }
125
126 if self.get_user_by_username(&data.username).await.is_ok() {
128 return Err(Error::UsernameInUse);
129 }
130
131 let conn = match self.0.connect().await {
133 Ok(c) => c,
134 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
135 };
136
137 let res = execute!(
138 &conn,
139 "INSERT INTO a_users VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23)",
140 params![
141 &(data.id as i64),
142 &(data.created as i64),
143 &data.username.to_lowercase(),
144 &data.password,
145 &data.salt,
146 &serde_json::to_string(&data.settings).unwrap(),
147 &serde_json::to_string(&data.tokens).unwrap(),
148 &serde_json::to_string(&data.permissions).unwrap(),
149 &(data.is_verified as i32),
150 &0_i32,
151 &String::new(),
152 "[]",
153 &data.stripe_id,
154 &data.ban_reason,
155 &(data.ban_expire as i64),
156 &(data.is_deactivated as i32),
157 &serde_json::to_string(&data.checkouts).unwrap(),
158 &(data.last_policy_consent as i64),
159 &serde_json::to_string(&data.linked_accounts).unwrap(),
160 &serde_json::to_string(&data.badges).unwrap(),
161 &(data.principal_org as i64),
162 &((data.principal_org > 0) as i32),
163 &data.org_creation_credits
164 ]
165 );
166
167 if let Err(e) = res {
168 return Err(Error::DatabaseError(e.to_string()));
169 }
170
171 Ok(data)
172 }
173
174 pub async fn delete_user(&self, id: usize, password: &str, force: bool) -> Result<User> {
181 let user = self.get_user_by_id(id).await?;
182
183 if (hash_salted(password.to_string(), user.salt.clone()) != user.password) && !force {
184 return Err(Error::IncorrectPassword);
185 }
186
187 let conn = match self.0.connect().await {
188 Ok(c) => c,
189 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
190 };
191
192 let res = execute!(&conn, "DELETE FROM a_users WHERE id = $1", &[&(id as i64)]);
193
194 if let Err(e) = res {
195 return Err(Error::DatabaseError(e.to_string()));
196 }
197
198 self.cache_clear_user(&user).await;
199
200 let res = execute!(
202 &conn,
203 "DELETE FROM a_notifications WHERE owner = $1",
204 &[&(id as i64)]
205 );
206
207 if let Err(e) = res {
208 return Err(Error::DatabaseError(e.to_string()));
209 }
210
211 let res = execute!(
213 &conn,
214 "DELETE FROM a_user_warnings WHERE receiver = $1",
215 &[&(id as i64)]
216 );
217
218 if let Err(e) = res {
219 return Err(Error::DatabaseError(e.to_string()));
220 }
221
222 for upload in match self.1.get_uploads_by_owner_all(user.id).await {
224 Ok(x) => x,
225 Err(e) => return Err(Error::MiscError(e.to_string())),
226 } {
227 if let Err(e) = self.1.delete_upload(upload.id).await {
228 return Err(Error::MiscError(e.to_string()));
229 }
230 }
231
232 Ok(user)
234 }
235
236 pub async fn update_user_verified_status(&self, id: usize, x: bool, user: User) -> Result<()> {
237 if !user.permissions.contains(&UserPermission::ManageVerified) {
238 return Err(Error::NotAllowed);
239 }
240
241 let other_user = self.get_user_by_id(id).await?;
242
243 let conn = match self.0.connect().await {
244 Ok(c) => c,
245 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
246 };
247
248 let res = execute!(
249 &conn,
250 "UPDATE a_users SET verified = $1 WHERE id = $2",
251 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
252 );
253
254 if let Err(e) = res {
255 return Err(Error::DatabaseError(e.to_string()));
256 }
257
258 self.cache_clear_user(&other_user).await;
259
260 self.create_audit_log_entry(AuditLogEntry::new(
262 user.id,
263 format!(
264 "invoked `update_user_verified_status` with x value `{}` and y value `{}`",
265 other_user.id, x
266 ),
267 ))
268 .await?;
269
270 Ok(())
272 }
273
274 pub async fn update_user_is_deactivated(&self, id: usize, x: bool, user: User) -> Result<()> {
275 if id != user.id && !user.permissions.contains(&UserPermission::ManageUsers) {
276 return Err(Error::NotAllowed);
277 }
278
279 let other_user = self.get_user_by_id(id).await?;
280
281 let conn = match self.0.connect().await {
282 Ok(c) => c,
283 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
284 };
285
286 let res = execute!(
287 &conn,
288 "UPDATE a_users SET is_deactivated = $1 WHERE id = $2",
289 params![&{ if x { 1 } else { 0 } }, &(id as i64)]
290 );
291
292 if let Err(e) = res {
293 return Err(Error::DatabaseError(e.to_string()));
294 }
295
296 self.cache_clear_user(&other_user).await;
297
298 if user.id != other_user.id {
300 self.create_audit_log_entry(AuditLogEntry::new(
301 user.id,
302 format!(
303 "invoked `update_user_is_deactivated` with x value `{}` and y value `{}`",
304 other_user.id, x
305 ),
306 ))
307 .await?;
308 }
309
310 Ok(())
312 }
313
314 pub async fn update_user_password(
315 &self,
316 id: usize,
317 from: String,
318 to: String,
319 user: User,
320 force: bool,
321 ) -> Result<()> {
322 if !user.check_password(from.clone()) && !force {
324 return Err(Error::MiscError("Password does not match".to_string()));
325 }
326
327 let conn = match self.0.connect().await {
329 Ok(c) => c,
330 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
331 };
332
333 let new_salt = salt();
334 let new_password = hash_salted(to, new_salt.clone());
335 let res = execute!(
336 &conn,
337 "UPDATE a_users SET password = $1, salt = $2 WHERE id = $3",
338 params![&new_password.as_str(), &new_salt.as_str(), &(id as i64)]
339 );
340
341 if let Err(e) = res {
342 return Err(Error::DatabaseError(e.to_string()));
343 }
344
345 self.cache_clear_user(&user).await;
346 Ok(())
347 }
348
349 pub async fn update_user_username(&self, id: usize, to: String, user: User) -> Result<()> {
350 if to.len() < 2 {
352 return Err(Error::DataTooShort("username".to_string()));
353 } else if to.len() > 32 {
354 return Err(Error::DataTooLong("username".to_string()));
355 }
356
357 if self.0.0.banned_usernames.contains(&to) {
358 return Err(Error::MiscError("This username cannot be used".to_string()));
359 }
360
361 let regex = regex::RegexBuilder::new(r"[^\w_\-\.!]+")
362 .multi_line(true)
363 .build()
364 .unwrap();
365
366 if regex.captures(&to).is_some() {
367 return Err(Error::MiscError(
368 "This username contains invalid characters".to_string(),
369 ));
370 }
371
372 let conn = match self.0.connect().await {
374 Ok(c) => c,
375 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
376 };
377
378 let res = execute!(
379 &conn,
380 "UPDATE a_users SET username = $1 WHERE id = $2",
381 params![&to.to_lowercase(), &(id as i64)]
382 );
383
384 if let Err(e) = res {
385 return Err(Error::DatabaseError(e.to_string()));
386 }
387
388 self.cache_clear_user(&user).await;
389 Ok(())
390 }
391
392 pub fn check_totp(&self, ua: &User, code: &str) -> bool {
394 let totp = ua.totp(Some(
395 self.0
396 .0
397 .host
398 .replace("http://", "")
399 .replace("https://", "")
400 .replace(":", "_"),
401 ));
402
403 if let Some(totp) = totp {
404 return !code.is_empty()
405 && (totp.check_current(code).unwrap()
406 | ua.recovery_codes.contains(&code.to_string()));
407 }
408
409 true
410 }
411
412 pub fn generate_totp_recovery_codes() -> Vec<String> {
414 let mut out: Vec<String> = Vec::new();
415
416 for _ in 0..9 {
417 out.push(salt())
418 }
419
420 out
421 }
422
423 pub async fn update_user_totp(
430 &self,
431 id: usize,
432 secret: &str,
433 recovery: &Vec<String>,
434 ) -> Result<()> {
435 let user = self.get_user_by_id(id).await?;
436
437 let conn = match self.0.connect().await {
439 Ok(c) => c,
440 Err(e) => return Err(Error::DatabaseConnection(e.to_string())),
441 };
442
443 let res = execute!(
444 &conn,
445 "UPDATE a_users SET totp = $1, recovery_codes = $2 WHERE id = $3",
446 params![
447 &secret,
448 &serde_json::to_string(recovery).unwrap(),
449 &(id as i64)
450 ]
451 );
452
453 if let Err(e) = res {
454 return Err(Error::DatabaseError(e.to_string()));
455 }
456
457 self.cache_clear_user(&user).await;
458 Ok(())
459 }
460
461 pub async fn enable_totp(
470 &self,
471 id: usize,
472 user: User,
473 ) -> Result<(String, String, Vec<String>)> {
474 let other_user = self.get_user_by_id(id).await?;
475
476 if other_user.id != user.id {
477 if other_user
478 .permissions
479 .contains(&UserPermission::ManageUsers)
480 {
481 self.create_audit_log_entry(AuditLogEntry::new(
483 user.id,
484 format!("invoked `enable_totp` with x value `{}`", other_user.id,),
485 ))
486 .await?;
487 } else {
488 return Err(Error::NotAllowed);
489 }
490 }
491
492 let secret = totp_rs::Secret::default().to_string();
493 let recovery = Self::generate_totp_recovery_codes();
494 self.update_user_totp(id, &secret, &recovery).await?;
495
496 let other_user = self.get_user_by_id(id).await?;
498
499 let totp = other_user.totp(Some(
501 self.0
502 .0
503 .host
504 .replace("http://", "")
505 .replace("https://", "")
506 .replace(":", "_"),
507 ));
508
509 if totp.is_none() {
510 return Err(Error::MiscError("Failed to get TOTP code".to_string()));
511 }
512
513 let totp = totp.unwrap();
514
515 let qr = match totp.get_qr_base64() {
517 Ok(q) => q,
518 Err(e) => return Err(Error::MiscError(e.to_string())),
519 };
520
521 Ok((totp.get_secret_base32(), qr, recovery))
523 }
524
525 pub async fn get_principal_org(&self, user: &User) -> Option<Organization> {
527 if user.principal_org == 0 {
528 return None;
529 }
530
531 if let Ok(x) = self.get_organization_by_id(user.principal_org).await {
532 Some(x)
533 } else {
534 self.update_user_principal_org(user.id, 0)
535 .await
536 .expect("failed to clear user principal org");
537
538 None
539 }
540 }
541
542 pub async fn cache_clear_user(&self, user: &User) {
543 self.0.1.remove(format!("srmp.user:{}", user.id)).await;
544 self.0
545 .1
546 .remove(format!("srmp.user:{}", user.username))
547 .await;
548 }
549
550 auto_method!(update_user_permissions(Vec<UserPermission>)@get_user_by_id -> "UPDATE a_users SET permissions = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
551 auto_method!(update_user_tokens(Vec<Token>)@get_user_by_id -> "UPDATE a_users SET tokens = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
552 auto_method!(update_user_settings(UserSettings)@get_user_by_id -> "UPDATE a_users SET settings = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
553 auto_method!(update_user_ban_reason(&str)@get_user_by_id -> "UPDATE a_users SET ban_reason = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
554 auto_method!(update_user_ban_expire(i64)@get_user_by_id -> "UPDATE a_users SET ban_expire = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
555 auto_method!(update_user_checkouts(Vec<String>)@get_user_by_id -> "UPDATE a_users SET checkouts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
556 auto_method!(update_user_last_policy_consent(i64)@get_user_by_id -> "UPDATE a_users SET last_policy_consent = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
557 auto_method!(update_user_linked_accounts(UserLinkedAccounts)@get_user_by_id -> "UPDATE a_users SET linked_accounts = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
558 auto_method!(update_user_badges(Vec<UserBadge>)@get_user_by_id -> "UPDATE a_users SET badges = $1 WHERE id = $2" --serde --cache-key-tmpl=cache_clear_user);
559 auto_method!(update_user_principal_org(i64)@get_user_by_id -> "UPDATE a_users SET principal_org = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
560 auto_method!(update_user_org_as_tenant(i32)@get_user_by_id -> "UPDATE a_users SET org_as_tenant = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
561
562 auto_method!(get_user_by_stripe_id(&str)@get_user_from_row -> "SELECT * FROM a_users WHERE stripe_id = $1" --name="user" --returns=User);
563 auto_method!(update_user_stripe_id(&str)@get_user_by_id -> "UPDATE a_users SET stripe_id = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
564
565 auto_method!(update_user_notification_count(i32)@get_user_by_id -> "UPDATE a_users SET notification_count = $1 WHERE id = $2" --cache-key-tmpl=cache_clear_user);
566 auto_method!(incr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
567 auto_method!(decr_user_notifications()@get_user_by_id -> "UPDATE a_users SET notification_count = notification_count - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=notification_count);
568
569 auto_method!(incr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits + 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --incr);
570 auto_method!(decr_user_org_creation_credits()@get_user_by_id -> "UPDATE a_users SET org_creation_credits = org_creation_credits - 1 WHERE id = $1" --cache-key-tmpl=cache_clear_user --decr=org_creation_credits);
571}