1use std::collections::HashMap;
2use std::ops::Range;
3
4use rand::Rng;
5
6use async_trait::async_trait;
7use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
8use atuin_common::utils::crypto_random_string;
9use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
10use atuin_server_database::{Database, DbError, DbResult, DbSettings};
11use futures_util::TryStreamExt;
12use sqlx::Row;
13use sqlx::postgres::PgPoolOptions;
14
15use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
16use tracing::{instrument, trace};
17use uuid::Uuid;
18use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
19
20mod wrappers;
21
/// Oldest PostgreSQL major version this backend accepts; `Postgres::new` returns an
/// error when connected to an older server.
const MIN_PG_VERSION: u32 = 14;
23
24#[derive(Clone)]
25pub struct Postgres {
26 pool: sqlx::Pool<sqlx::postgres::Postgres>,
27}
28
29fn fix_error(error: sqlx::Error) -> DbError {
30 match error {
31 sqlx::Error::RowNotFound => DbError::NotFound,
32 error => DbError::Other(error.into()),
33 }
34}
35
36#[async_trait]
37impl Database for Postgres {
38 async fn new(settings: &DbSettings) -> DbResult<Self> {
39 let pool = PgPoolOptions::new()
40 .max_connections(100)
41 .connect(settings.db_uri.as_str())
42 .await
43 .map_err(fix_error)?;
44
45 let pg_major_version: u32 = pool
48 .acquire()
49 .await
50 .map_err(fix_error)?
51 .server_version_num()
52 .ok_or(DbError::Other(eyre::Report::msg(
53 "could not get PostgreSQL version",
54 )))?
55 / 10000;
56
57 if pg_major_version < MIN_PG_VERSION {
58 return Err(DbError::Other(eyre::Report::msg(format!(
59 "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
60 ))));
61 }
62
63 sqlx::migrate!("./migrations")
64 .run(&pool)
65 .await
66 .map_err(|error| DbError::Other(error.into()))?;
67
68 Ok(Self { pool })
69 }
70
71 #[instrument(skip_all)]
72 async fn get_session(&self, token: &str) -> DbResult<Session> {
73 sqlx::query_as("select id, user_id, token from sessions where token = $1")
74 .bind(token)
75 .fetch_one(&self.pool)
76 .await
77 .map_err(fix_error)
78 .map(|DbSession(session)| session)
79 }
80
81 #[instrument(skip_all)]
82 async fn get_user(&self, username: &str) -> DbResult<User> {
83 sqlx::query_as(
84 "select id, username, email, password, verified_at from users where username = $1",
85 )
86 .bind(username)
87 .fetch_one(&self.pool)
88 .await
89 .map_err(fix_error)
90 .map(|DbUser(user)| user)
91 }
92
93 #[instrument(skip_all)]
94 async fn user_verified(&self, id: i64) -> DbResult<bool> {
95 let res: (bool,) =
96 sqlx::query_as("select verified_at is not null from users where id = $1")
97 .bind(id)
98 .fetch_one(&self.pool)
99 .await
100 .map_err(fix_error)?;
101
102 Ok(res.0)
103 }
104
105 #[instrument(skip_all)]
106 async fn verify_user(&self, id: i64) -> DbResult<()> {
107 sqlx::query(
108 "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
109 )
110 .bind(id)
111 .execute(&self.pool)
112 .await
113 .map_err(fix_error)?;
114
115 Ok(())
116 }
117
118 #[instrument(skip_all)]
123 async fn user_verification_token(&self, id: i64) -> DbResult<String> {
124 const TOKEN_VALID_MINUTES: i64 = 15;
125
126 let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
128 "select token, valid_until from user_verification_token where user_id = $1",
129 )
130 .bind(id)
131 .fetch_optional(&self.pool)
132 .await
133 .map_err(fix_error)?;
134
135 let token = if let Some((token, valid_until)) = token {
136 trace!("Token for user {id} valid until {valid_until}");
137
138 if valid_until > time::OffsetDateTime::now_utc() {
140 token
141 } else {
142 let token = crypto_random_string::<24>();
144
145 sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
146 .bind(id)
147 .bind(&token)
148 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
149 .execute(&self.pool)
150 .await
151 .map_err(fix_error)?;
152
153 token
154 }
155 } else {
156 let token = crypto_random_string::<24>();
158
159 sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
160 .bind(id)
161 .bind(&token)
162 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
163 .execute(&self.pool)
164 .await
165 .map_err(fix_error)?;
166
167 token
168 };
169
170 Ok(token)
171 }
172
173 #[instrument(skip_all)]
174 async fn get_session_user(&self, token: &str) -> DbResult<User> {
175 sqlx::query_as(
176 "select users.id, users.username, users.email, users.password, users.verified_at from users
177 inner join sessions
178 on users.id = sessions.user_id
179 and sessions.token = $1",
180 )
181 .bind(token)
182 .fetch_one(&self.pool)
183 .await
184 .map_err(fix_error)
185 .map(|DbUser(user)| user)
186 }
187
188 #[instrument(skip_all)]
189 async fn count_history(&self, user: &User) -> DbResult<i64> {
190 let res: (i64,) = sqlx::query_as(
195 "select count(1) from history
196 where user_id = $1",
197 )
198 .bind(user.id)
199 .fetch_one(&self.pool)
200 .await
201 .map_err(fix_error)?;
202
203 Ok(res.0)
204 }
205
206 #[instrument(skip_all)]
207 async fn total_history(&self) -> DbResult<i64> {
208 let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
213 .fetch_optional(&self.pool)
214 .await
215 .map_err(fix_error)?
216 .unwrap_or((0,));
217
218 Ok(res.0)
219 }
220
221 #[instrument(skip_all)]
222 async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
223 let res: (i32,) = sqlx::query_as(
224 "select total from total_history_count_user
225 where user_id = $1",
226 )
227 .bind(user.id)
228 .fetch_one(&self.pool)
229 .await
230 .map_err(fix_error)?;
231
232 Ok(res.0 as i64)
233 }
234
235 async fn delete_store(&self, user: &User) -> DbResult<()> {
236 let mut tx = self.pool.begin().await.map_err(fix_error)?;
237
238 sqlx::query(
239 "delete from store
240 where user_id = $1",
241 )
242 .bind(user.id)
243 .execute(&mut *tx)
244 .await
245 .map_err(fix_error)?;
246
247 sqlx::query(
248 "delete from store_idx_cache
249 where user_id = $1",
250 )
251 .bind(user.id)
252 .execute(&mut *tx)
253 .await
254 .map_err(fix_error)?;
255
256 tx.commit().await.map_err(fix_error)?;
257
258 Ok(())
259 }
260
261 async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
262 sqlx::query(
263 "update history
264 set deleted_at = $3
265 where user_id = $1
266 and client_id = $2
267 and deleted_at is null", )
269 .bind(user.id)
270 .bind(id)
271 .bind(OffsetDateTime::now_utc())
272 .fetch_all(&self.pool)
273 .await
274 .map_err(fix_error)?;
275
276 Ok(())
277 }
278
279 #[instrument(skip_all)]
280 async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
281 let res = sqlx::query(
286 "select client_id from history
287 where user_id = $1
288 and deleted_at is not null",
289 )
290 .bind(user.id)
291 .fetch_all(&self.pool)
292 .await
293 .map_err(fix_error)?;
294
295 let res = res
296 .iter()
297 .map(|row| row.get::<String, _>("client_id"))
298 .collect();
299
300 Ok(res)
301 }
302
303 #[instrument(skip_all)]
304 async fn count_history_range(
305 &self,
306 user: &User,
307 range: Range<OffsetDateTime>,
308 ) -> DbResult<i64> {
309 let res: (i64,) = sqlx::query_as(
310 "select count(1) from history
311 where user_id = $1
312 and timestamp >= $2::date
313 and timestamp < $3::date",
314 )
315 .bind(user.id)
316 .bind(into_utc(range.start))
317 .bind(into_utc(range.end))
318 .fetch_one(&self.pool)
319 .await
320 .map_err(fix_error)?;
321
322 Ok(res.0)
323 }
324
325 #[instrument(skip_all)]
326 async fn list_history(
327 &self,
328 user: &User,
329 created_after: OffsetDateTime,
330 since: OffsetDateTime,
331 host: &str,
332 page_size: i64,
333 ) -> DbResult<Vec<History>> {
334 let res = sqlx::query_as(
335 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
336 where user_id = $1
337 and hostname != $2
338 and created_at >= $3
339 and timestamp >= $4
340 order by timestamp asc
341 limit $5",
342 )
343 .bind(user.id)
344 .bind(host)
345 .bind(into_utc(created_after))
346 .bind(into_utc(since))
347 .bind(page_size)
348 .fetch(&self.pool)
349 .map_ok(|DbHistory(h)| h)
350 .try_collect()
351 .await
352 .map_err(fix_error)?;
353
354 Ok(res)
355 }
356
357 #[instrument(skip_all)]
358 async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
359 let mut tx = self.pool.begin().await.map_err(fix_error)?;
360
361 for i in history {
362 let client_id: &str = &i.client_id;
363 let hostname: &str = &i.hostname;
364 let data: &str = &i.data;
365
366 sqlx::query(
367 "insert into history
368 (client_id, user_id, hostname, timestamp, data)
369 values ($1, $2, $3, $4, $5)
370 on conflict do nothing
371 ",
372 )
373 .bind(client_id)
374 .bind(i.user_id)
375 .bind(hostname)
376 .bind(i.timestamp)
377 .bind(data)
378 .execute(&mut *tx)
379 .await
380 .map_err(fix_error)?;
381 }
382
383 tx.commit().await.map_err(fix_error)?;
384
385 Ok(())
386 }
387
388 #[instrument(skip_all)]
389 async fn delete_user(&self, u: &User) -> DbResult<()> {
390 sqlx::query("delete from sessions where user_id = $1")
391 .bind(u.id)
392 .execute(&self.pool)
393 .await
394 .map_err(fix_error)?;
395
396 sqlx::query("delete from history where user_id = $1")
397 .bind(u.id)
398 .execute(&self.pool)
399 .await
400 .map_err(fix_error)?;
401
402 sqlx::query("delete from store where user_id = $1")
403 .bind(u.id)
404 .execute(&self.pool)
405 .await
406 .map_err(fix_error)?;
407
408 sqlx::query("delete from user_verification_token where user_id = $1")
409 .bind(u.id)
410 .execute(&self.pool)
411 .await
412 .map_err(fix_error)?;
413
414 sqlx::query("delete from total_history_count_user where user_id = $1")
415 .bind(u.id)
416 .execute(&self.pool)
417 .await
418 .map_err(fix_error)?;
419
420 sqlx::query("delete from users where id = $1")
421 .bind(u.id)
422 .execute(&self.pool)
423 .await
424 .map_err(fix_error)?;
425
426 Ok(())
427 }
428
429 #[instrument(skip_all)]
430 async fn update_user_password(&self, user: &User) -> DbResult<()> {
431 sqlx::query(
432 "update users
433 set password = $1
434 where id = $2",
435 )
436 .bind(&user.password)
437 .bind(user.id)
438 .execute(&self.pool)
439 .await
440 .map_err(fix_error)?;
441
442 Ok(())
443 }
444
445 #[instrument(skip_all)]
446 async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
447 let email: &str = &user.email;
448 let username: &str = &user.username;
449 let password: &str = &user.password;
450
451 let res: (i64,) = sqlx::query_as(
452 "insert into users
453 (username, email, password)
454 values($1, $2, $3)
455 returning id",
456 )
457 .bind(username)
458 .bind(email)
459 .bind(password)
460 .fetch_one(&self.pool)
461 .await
462 .map_err(fix_error)?;
463
464 Ok(res.0)
465 }
466
467 #[instrument(skip_all)]
468 async fn add_session(&self, session: &NewSession) -> DbResult<()> {
469 let token: &str = &session.token;
470
471 sqlx::query(
472 "insert into sessions
473 (user_id, token)
474 values($1, $2)",
475 )
476 .bind(session.user_id)
477 .bind(token)
478 .execute(&self.pool)
479 .await
480 .map_err(fix_error)?;
481
482 Ok(())
483 }
484
485 #[instrument(skip_all)]
486 async fn get_user_session(&self, u: &User) -> DbResult<Session> {
487 sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
488 .bind(u.id)
489 .fetch_one(&self.pool)
490 .await
491 .map_err(fix_error)
492 .map(|DbSession(session)| session)
493 }
494
495 #[instrument(skip_all)]
496 async fn oldest_history(&self, user: &User) -> DbResult<History> {
497 sqlx::query_as(
498 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
499 where user_id = $1
500 order by timestamp asc
501 limit 1",
502 )
503 .bind(user.id)
504 .fetch_one(&self.pool)
505 .await
506 .map_err(fix_error)
507 .map(|DbHistory(h)| h)
508 }
509
510 #[instrument(skip_all)]
511 async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
512 let mut tx = self.pool.begin().await.map_err(fix_error)?;
513
514 let mut heads = HashMap::<(HostId, &str), u64>::new();
522
523 for i in records {
524 let id = atuin_common::utils::uuid_v7();
525
526 let result = sqlx::query(
527 "insert into store
528 (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
529 values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
530 on conflict do nothing
531 ",
532 )
533 .bind(id)
534 .bind(i.id)
535 .bind(i.host.id)
536 .bind(i.idx as i64)
537 .bind(i.timestamp as i64) .bind(&i.version)
539 .bind(&i.tag)
540 .bind(&i.data.data)
541 .bind(&i.data.content_encryption_key)
542 .bind(user.id)
543 .execute(&mut *tx)
544 .await
545 .map_err(fix_error)?;
546
547 if result.rows_affected() > 0 {
549 heads
550 .entry((i.host.id, &i.tag))
551 .and_modify(|e| {
552 if i.idx > *e {
553 *e = i.idx
554 }
555 })
556 .or_insert(i.idx);
557 }
558 }
559
560 for ((host, tag), idx) in heads {
562 sqlx::query(
563 "insert into store_idx_cache
564 (user_id, host, tag, idx)
565 values ($1, $2, $3, $4)
566 on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
567 ",
568 )
569 .bind(user.id)
570 .bind(host)
571 .bind(tag)
572 .bind(idx as i64)
573 .execute(&mut *tx)
574 .await
575 .map_err(fix_error)?;
576 }
577
578 tx.commit().await.map_err(fix_error)?;
579
580 Ok(())
581 }
582
583 #[instrument(skip_all)]
584 async fn next_records(
585 &self,
586 user: &User,
587 host: HostId,
588 tag: String,
589 start: Option<RecordIdx>,
590 count: u64,
591 ) -> DbResult<Vec<Record<EncryptedData>>> {
592 tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
593 let start = start.unwrap_or(0);
594
595 let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
596 "select client_id, host, idx, timestamp, version, tag, data, cek from store
597 where user_id = $1
598 and tag = $2
599 and host = $3
600 and idx >= $4
601 order by idx asc
602 limit $5",
603 )
604 .bind(user.id)
605 .bind(tag.clone())
606 .bind(host)
607 .bind(start as i64)
608 .bind(count as i64)
609 .fetch_all(&self.pool)
610 .await
611 .map_err(fix_error);
612
613 let ret = match records {
614 Ok(records) => {
615 let records: Vec<Record<EncryptedData>> = records
616 .into_iter()
617 .map(|f| {
618 let record: Record<EncryptedData> = f.into();
619 record
620 })
621 .collect();
622
623 records
624 }
625 Err(DbError::NotFound) => {
626 tracing::debug!("no records found in store: {:?}/{}", host, tag);
627 return Ok(vec![]);
628 }
629 Err(e) => return Err(e),
630 };
631
632 Ok(ret)
633 }
634
635 async fn status(&self, user: &User) -> DbResult<RecordStatus> {
636 const STATUS_SQL: &str =
637 "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
638
639 let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string());
646 let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
647 let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0);
648
649 let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
650 tracing::debug!("using idx cache for user {}", user.id);
651 sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
652 .bind(user.id)
653 .fetch_all(&self.pool)
654 .await
655 .map_err(fix_error)?
656 } else {
657 tracing::debug!("using aggregate query for user {}", user.id);
658 sqlx::query_as(STATUS_SQL)
659 .bind(user.id)
660 .fetch_all(&self.pool)
661 .await
662 .map_err(fix_error)?
663 };
664
665 res.sort();
666
667 let mut status = RecordStatus::new();
668
669 for i in res.iter() {
670 status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
671 }
672
673 Ok(status)
674 }
675}
676
677fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
678 let x = x.to_offset(UtcOffset::UTC);
679 PrimitiveDateTime::new(x.date(), x.time())
680}
681
#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    /// `into_utc` must shift each offset to UTC, and re-attaching UTC must
    /// round-trip to the original instant.
    #[test]
    fn utc() {
        let cases = [
            (
                datetime!(2023-09-26 15:11:02 +05:30),
                datetime!(2023-09-26 09:41:02),
            ),
            (
                datetime!(2023-09-26 15:11:02 -07:00),
                datetime!(2023-09-26 22:11:02),
            ),
            (
                datetime!(2023-09-26 15:11:02 +00:00),
                datetime!(2023-09-26 15:11:02),
            ),
        ];

        for (with_offset, expected) in cases {
            let converted = into_utc(with_offset);
            assert_eq!(converted, expected);
            assert_eq!(converted.assume_utc(), with_offset);
        }
    }
}
702}