use std::collections::HashMap;
use std::ops::Range;

use rand::Rng;

use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use atuin_common::utils::crypto_random_string;
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult, DbSettings};
use futures_util::TryStreamExt;
use sqlx::Row;
use sqlx::postgres::PgPoolOptions;

use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::{instrument, trace};
use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};

mod wrappers;

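// Oldest PostgreSQL major version the server accepts; checked at startup for both the
// primary database and any configured read replica.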
const MIN_PG_VERSION: u32 = 14;

#[derive(Clone)]
pub struct Postgres {
    pool: sqlx::Pool<sqlx::postgres::Postgres>,
    read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
}

impl Postgres {
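    // Prefer the read replica when one is configured; otherwise fall back to the primary pool.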
    fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
        self.read_pool.as_ref().unwrap_or(&self.pool)
    }
}

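// Translate sqlx errors into the crate's DbError, mapping RowNotFound to NotFound and
// wrapping everything else.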
fn fix_error(error: sqlx::Error) -> DbError {
    match error {
        sqlx::Error::RowNotFound => DbError::NotFound,
        error => DbError::Other(error.into()),
    }
}

#[async_trait]
impl Database for Postgres {
    async fn new(settings: &DbSettings) -> DbResult<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(100)
            .connect(settings.db_uri.as_str())
            .await
            .map_err(fix_error)?;

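        // server_version_num() reports e.g. 140005 for 14.5; dividing by 10000 leaves the
        // major version. It returns None when the server does not report a version, which
        // we surface as an error.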
        let pg_major_version: u32 = pool
            .acquire()
            .await
            .map_err(fix_error)?
            .server_version_num()
            .ok_or(DbError::Other(eyre::Report::msg(
                "could not get PostgreSQL version",
            )))?
            / 10000;

        if pg_major_version < MIN_PG_VERSION {
            return Err(DbError::Other(eyre::Report::msg(format!(
                "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
            ))));
        }

        sqlx::migrate!("./migrations")
            .run(&pool)
            .await
            .map_err(|error| DbError::Other(error.into()))?;

        let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
            tracing::info!("Connecting to read replica database");
            let read_pool = PgPoolOptions::new()
                .max_connections(100)
                .connect(read_db_uri.as_str())
                .await
                .map_err(fix_error)?;

            let read_pg_major_version: u32 = read_pool
                .acquire()
                .await
                .map_err(fix_error)?
                .server_version_num()
                .ok_or(DbError::Other(eyre::Report::msg(
                    "could not get PostgreSQL version from read replica",
                )))?
                / 10000;

            if read_pg_major_version < MIN_PG_VERSION {
                return Err(DbError::Other(eyre::Report::msg(format!(
                    "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
                ))));
            }

            Some(read_pool)
        } else {
            None
        };

        Ok(Self { pool, read_pool })
    }

    #[instrument(skip_all)]
    async fn get_session(&self, token: &str) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where token = $1")
            .bind(token)
            .fetch_one(self.read_pool())
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn get_user(&self, username: &str) -> DbResult<User> {
        sqlx::query_as(
            "select id, username, email, password, verified_at from users where username = $1",
        )
        .bind(username)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)
        .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn user_verified(&self, id: i64) -> DbResult<bool> {
        let res: (bool,) =
            sqlx::query_as("select verified_at is not null from users where id = $1")
                .bind(id)
                .fetch_one(self.read_pool())
                .await
                .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn verify_user(&self, id: i64) -> DbResult<()> {
        sqlx::query(
            "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
        )
        .bind(id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

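    // Return the user's current verification token, minting a new 24-character token when
    // none exists or the existing one has expired.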
    #[instrument(skip_all)]
    async fn user_verification_token(&self, id: i64) -> DbResult<String> {
        const TOKEN_VALID_MINUTES: i64 = 15;

        let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
            "select token, valid_until from user_verification_token where user_id = $1",
        )
        .bind(id)
        .fetch_optional(&self.pool)
        .await
        .map_err(fix_error)?;

        let token = if let Some((token, valid_until)) = token {
            trace!("Token for user {id} valid until {valid_until}");

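            // Reuse the token while it is still valid; otherwise rotate it in place.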
            if valid_until > time::OffsetDateTime::now_utc() {
                token
            } else {
                let token = crypto_random_string::<24>();

                sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
                    .bind(id)
                    .bind(&token)
                    .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                    .execute(&self.pool)
                    .await
                    .map_err(fix_error)?;

                token
            }
        } else {
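            // No token row yet for this user: create one and persist it.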
            let token = crypto_random_string::<24>();

            sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
                .bind(id)
                .bind(&token)
                .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                .execute(&self.pool)
                .await
                .map_err(fix_error)?;

            token
        };

        Ok(token)
    }

    #[instrument(skip_all)]
    async fn get_session_user(&self, token: &str) -> DbResult<User> {
        sqlx::query_as(
            "select users.id, users.username, users.email, users.password, users.verified_at from users
            inner join sessions
            on users.id = sessions.user_id
            and sessions.token = $1",
        )
        .bind(token)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)
        .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn count_history(&self, user: &User) -> DbResult<i64> {
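        // Count rows in the history table directly; count_history_cached below reads the
        // per-user cached total instead.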
        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn total_history(&self) -> DbResult<i64> {
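        // Sum the cached per-user totals; an empty cache table yields 0.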
        let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
            .fetch_optional(self.read_pool())
            .await
            .map_err(fix_error)?
            .unwrap_or((0,));

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
        let res: (i32,) = sqlx::query_as(
            "select total from total_history_count_user
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)?;

        Ok(res.0 as i64)
    }

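    // Remove a user's entire record store and its idx cache in a single transaction.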
    async fn delete_store(&self, user: &User) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        sqlx::query(
            "delete from store
            where user_id = $1",
        )
        .bind(user.id)
        .execute(&mut *tx)
        .await
        .map_err(fix_error)?;

        sqlx::query(
            "delete from store_idx_cache
            where user_id = $1",
        )
        .bind(user.id)
        .execute(&mut *tx)
        .await
        .map_err(fix_error)?;

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

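    // Soft-delete: set deleted_at instead of removing the row, so clients can sync the deletion.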
    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
        sqlx::query(
            "update history
            set deleted_at = $3
            where user_id = $1
            and client_id = $2
            and deleted_at is null",
        )
        .bind(user.id)
        .bind(id)
        .bind(OffsetDateTime::now_utc())
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
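        // Return the client ids of every history entry this user has soft-deleted.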
        let res = sqlx::query(
            "select client_id from history
            where user_id = $1
            and deleted_at is not null",
        )
        .bind(user.id)
        .fetch_all(self.read_pool())
        .await
        .map_err(fix_error)?;

        let res = res
            .iter()
            .map(|row| row.get::<String, _>("client_id"))
            .collect();

        Ok(res)
    }

    #[instrument(skip_all)]
    async fn count_history_range(
        &self,
        user: &User,
        range: Range<OffsetDateTime>,
    ) -> DbResult<i64> {
        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1
            and timestamp >= $2::date
            and timestamp < $3::date",
        )
        .bind(user.id)
        .bind(into_utc(range.start))
        .bind(into_utc(range.end))
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn list_history(
        &self,
        user: &User,
        created_after: OffsetDateTime,
        since: OffsetDateTime,
        host: &str,
        page_size: i64,
    ) -> DbResult<Vec<History>> {
        let res = sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            and hostname != $2
            and created_at >= $3
            and timestamp >= $4
            order by timestamp asc
            limit $5",
        )
        .bind(user.id)
        .bind(host)
        .bind(into_utc(created_after))
        .bind(into_utc(since))
        .bind(page_size)
        .fetch(self.read_pool())
        .map_ok(|DbHistory(h)| h)
        .try_collect()
        .await
        .map_err(fix_error)?;

        Ok(res)
    }

    #[instrument(skip_all)]
    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        for i in history {
            let client_id: &str = &i.client_id;
            let hostname: &str = &i.hostname;
            let data: &str = &i.data;

            sqlx::query(
                "insert into history
                (client_id, user_id, hostname, timestamp, data)
                values ($1, $2, $3, $4, $5)
                on conflict do nothing
                ",
            )
            .bind(client_id)
            .bind(i.user_id)
            .bind(hostname)
            .bind(i.timestamp)
            .bind(data)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

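    // Remove everything the user owns (sessions, history, store, tokens, cached counts)
    // before deleting the users row itself.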
    #[instrument(skip_all)]
    async fn delete_user(&self, u: &User) -> DbResult<()> {
        sqlx::query("delete from sessions where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from history where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from store where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from user_verification_token where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from total_history_count_user where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from users where id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn update_user_password(&self, user: &User) -> DbResult<()> {
        sqlx::query(
            "update users
            set password = $1
            where id = $2",
        )
        .bind(&user.password)
        .bind(user.id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
        let email: &str = &user.email;
        let username: &str = &user.username;
        let password: &str = &user.password;

        let res: (i64,) = sqlx::query_as(
            "insert into users
            (username, email, password)
            values($1, $2, $3)
            returning id",
        )
        .bind(username)
        .bind(email)
        .bind(password)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
        let token: &str = &session.token;

        sqlx::query(
            "insert into sessions
            (user_id, token)
            values($1, $2)",
        )
        .bind(session.user_id)
        .bind(token)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
            .bind(u.id)
            .fetch_one(self.read_pool())
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn oldest_history(&self, user: &User) -> DbResult<History> {
        sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            order by timestamp asc
            limit 1",
        )
        .bind(user.id)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)
        .map(|DbHistory(h)| h)
    }

    #[instrument(skip_all)]
    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

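        // Track the highest idx inserted per (host, tag) in this batch, so the idx cache can
        // be updated once per pair after the loop rather than once per record.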
        let mut heads = HashMap::<(HostId, &str), u64>::new();

        for i in records {
            let id = atuin_common::utils::uuid_v7();

            let result = sqlx::query(
                "insert into store
                (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                on conflict do nothing
                ",
            )
            .bind(id)
            .bind(i.id)
            .bind(i.host.id)
            .bind(i.idx as i64)
            .bind(i.timestamp as i64)
            .bind(&i.version)
            .bind(&i.tag)
            .bind(&i.data.data)
            .bind(&i.data.content_encryption_key)
            .bind(user.id)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;

            if result.rows_affected() > 0 {
                heads
                    .entry((i.host.id, &i.tag))
                    .and_modify(|e| {
                        if i.idx > *e {
                            *e = i.idx
                        }
                    })
                    .or_insert(i.idx);
            }
        }

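        // Flush the per-(host, tag) maxima into store_idx_cache, keeping the greatest idx seen.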
        for ((host, tag), idx) in heads {
            sqlx::query(
                "insert into store_idx_cache
                (user_id, host, tag, idx)
                values ($1, $2, $3, $4)
                on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
                ",
            )
            .bind(user.id)
            .bind(host)
            .bind(tag)
            .bind(idx as i64)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

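    // Page through one (host, tag) stream for this user, ordered by idx, starting from `start`
    // (or 0 when no start index is supplied) and returning at most `count` records.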
    #[instrument(skip_all)]
    async fn next_records(
        &self,
        user: &User,
        host: HostId,
        tag: String,
        start: Option<RecordIdx>,
        count: u64,
    ) -> DbResult<Vec<Record<EncryptedData>>> {
        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
        let start = start.unwrap_or(0);

        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
            "select client_id, host, idx, timestamp, version, tag, data, cek from store
            where user_id = $1
            and tag = $2
            and host = $3
            and idx >= $4
            order by idx asc
            limit $5",
        )
        .bind(user.id)
        .bind(tag.clone())
        .bind(host)
        .bind(start as i64)
        .bind(count as i64)
        .fetch_all(self.read_pool())
        .await
        .map_err(fix_error);

        let ret = match records {
            Ok(records) => {
                let records: Vec<Record<EncryptedData>> = records
                    .into_iter()
                    .map(|f| {
                        let record: Record<EncryptedData> = f.into();
                        record
                    })
                    .collect();

                records
            }
            Err(DbError::NotFound) => {
                tracing::debug!("no records found in store: {:?}/{}", host, tag);
                return Ok(vec![]);
            }
            Err(e) => return Err(e),
        };

        Ok(ret)
    }

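    // Report the latest known idx per (host, tag) for this user.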
    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
        const STATUS_SQL: &str =
            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";

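        // IDX_CACHE_ROLLOUT is read as a percentage (0-100) of requests that read from the
        // precomputed store_idx_cache instead of running the aggregate query above.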
        let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string());
        let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
        let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0);

        let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
            tracing::debug!("using idx cache for user {}", user.id);
            sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
                .bind(user.id)
                .fetch_all(self.read_pool())
                .await
                .map_err(fix_error)?
        } else {
            tracing::debug!("using aggregate query for user {}", user.id);
            sqlx::query_as(STATUS_SQL)
                .bind(user.id)
                .fetch_all(self.read_pool())
                .await
                .map_err(fix_error)?
        };

        res.sort();

        let mut status = RecordStatus::new();

        for i in res.iter() {
            status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
        }

        Ok(status)
    }
}

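// Normalise to UTC and drop the offset, yielding a PrimitiveDateTime suitable for binding.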
fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
    let x = x.to_offset(UtcOffset::UTC);
    PrimitiveDateTime::new(x.date(), x.time())
}

#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    #[test]
    fn utc() {
        let dt = datetime!(2023-09-26 15:11:02 +05:30);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 -07:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 +00:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);
    }
}