1use std::collections::HashMap;
2use std::ops::Range;
3
4use async_trait::async_trait;
5use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
6use atuin_common::utils::crypto_random_string;
7use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
8use atuin_server_database::{Database, DbError, DbResult, DbSettings};
9use futures_util::TryStreamExt;
10use metrics::counter;
11use sqlx::Row;
12use sqlx::postgres::PgPoolOptions;
13
14use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
15use tracing::{instrument, trace};
16use uuid::Uuid;
17use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
18
19mod wrappers;
20
/// Lowest PostgreSQL major version accepted at connect time; older servers are rejected.
const MIN_PG_VERSION: u32 = 14;
22
23#[derive(Clone)]
24pub struct Postgres {
25 pool: sqlx::Pool<sqlx::postgres::Postgres>,
26}
27
28fn fix_error(error: sqlx::Error) -> DbError {
29 match error {
30 sqlx::Error::RowNotFound => DbError::NotFound,
31 error => DbError::Other(error.into()),
32 }
33}
34
35#[async_trait]
36impl Database for Postgres {
37 async fn new(settings: &DbSettings) -> DbResult<Self> {
38 let pool = PgPoolOptions::new()
39 .max_connections(100)
40 .connect(settings.db_uri.as_str())
41 .await
42 .map_err(fix_error)?;
43
44 let pg_major_version: u32 = pool
47 .acquire()
48 .await
49 .map_err(fix_error)?
50 .server_version_num()
51 .ok_or(DbError::Other(eyre::Report::msg(
52 "could not get PostgreSQL version",
53 )))?
54 / 10000;
55
56 if pg_major_version < MIN_PG_VERSION {
57 return Err(DbError::Other(eyre::Report::msg(format!(
58 "unsupported PostgreSQL version {}, minimum required is {}",
59 pg_major_version, MIN_PG_VERSION
60 ))));
61 }
62
63 sqlx::migrate!("./migrations")
64 .run(&pool)
65 .await
66 .map_err(|error| DbError::Other(error.into()))?;
67
68 Ok(Self { pool })
69 }
70
71 #[instrument(skip_all)]
72 async fn get_session(&self, token: &str) -> DbResult<Session> {
73 sqlx::query_as("select id, user_id, token from sessions where token = $1")
74 .bind(token)
75 .fetch_one(&self.pool)
76 .await
77 .map_err(fix_error)
78 .map(|DbSession(session)| session)
79 }
80
81 #[instrument(skip_all)]
82 async fn get_user(&self, username: &str) -> DbResult<User> {
83 sqlx::query_as(
84 "select id, username, email, password, verified_at from users where username = $1",
85 )
86 .bind(username)
87 .fetch_one(&self.pool)
88 .await
89 .map_err(fix_error)
90 .map(|DbUser(user)| user)
91 }
92
93 #[instrument(skip_all)]
94 async fn user_verified(&self, id: i64) -> DbResult<bool> {
95 let res: (bool,) =
96 sqlx::query_as("select verified_at is not null from users where id = $1")
97 .bind(id)
98 .fetch_one(&self.pool)
99 .await
100 .map_err(fix_error)?;
101
102 Ok(res.0)
103 }
104
105 #[instrument(skip_all)]
106 async fn verify_user(&self, id: i64) -> DbResult<()> {
107 sqlx::query(
108 "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
109 )
110 .bind(id)
111 .execute(&self.pool)
112 .await
113 .map_err(fix_error)?;
114
115 Ok(())
116 }
117
118 #[instrument(skip_all)]
123 async fn user_verification_token(&self, id: i64) -> DbResult<String> {
124 const TOKEN_VALID_MINUTES: i64 = 15;
125
126 let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
128 "select token, valid_until from user_verification_token where user_id = $1",
129 )
130 .bind(id)
131 .fetch_optional(&self.pool)
132 .await
133 .map_err(fix_error)?;
134
135 let token = if let Some((token, valid_until)) = token {
136 trace!("Token for user {id} valid until {valid_until}");
137
138 if valid_until > time::OffsetDateTime::now_utc() {
140 token
141 } else {
142 let token = crypto_random_string::<24>();
144
145 sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
146 .bind(id)
147 .bind(&token)
148 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
149 .execute(&self.pool)
150 .await
151 .map_err(fix_error)?;
152
153 token
154 }
155 } else {
156 let token = crypto_random_string::<24>();
158
159 sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
160 .bind(id)
161 .bind(&token)
162 .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
163 .execute(&self.pool)
164 .await
165 .map_err(fix_error)?;
166
167 token
168 };
169
170 Ok(token)
171 }
172
173 #[instrument(skip_all)]
174 async fn get_session_user(&self, token: &str) -> DbResult<User> {
175 sqlx::query_as(
176 "select users.id, users.username, users.email, users.password, users.verified_at from users
177 inner join sessions
178 on users.id = sessions.user_id
179 and sessions.token = $1",
180 )
181 .bind(token)
182 .fetch_one(&self.pool)
183 .await
184 .map_err(fix_error)
185 .map(|DbUser(user)| user)
186 }
187
188 #[instrument(skip_all)]
189 async fn count_history(&self, user: &User) -> DbResult<i64> {
190 let res: (i64,) = sqlx::query_as(
195 "select count(1) from history
196 where user_id = $1",
197 )
198 .bind(user.id)
199 .fetch_one(&self.pool)
200 .await
201 .map_err(fix_error)?;
202
203 Ok(res.0)
204 }
205
206 #[instrument(skip_all)]
207 async fn total_history(&self) -> DbResult<i64> {
208 let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
213 .fetch_optional(&self.pool)
214 .await
215 .map_err(fix_error)?
216 .unwrap_or((0,));
217
218 Ok(res.0)
219 }
220
221 #[instrument(skip_all)]
222 async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
223 let res: (i32,) = sqlx::query_as(
224 "select total from total_history_count_user
225 where user_id = $1",
226 )
227 .bind(user.id)
228 .fetch_one(&self.pool)
229 .await
230 .map_err(fix_error)?;
231
232 Ok(res.0 as i64)
233 }
234
235 async fn delete_store(&self, user: &User) -> DbResult<()> {
236 sqlx::query(
237 "delete from store
238 where user_id = $1",
239 )
240 .bind(user.id)
241 .execute(&self.pool)
242 .await
243 .map_err(fix_error)?;
244
245 Ok(())
246 }
247
248 async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
249 sqlx::query(
250 "update history
251 set deleted_at = $3
252 where user_id = $1
253 and client_id = $2
254 and deleted_at is null", )
256 .bind(user.id)
257 .bind(id)
258 .bind(OffsetDateTime::now_utc())
259 .fetch_all(&self.pool)
260 .await
261 .map_err(fix_error)?;
262
263 Ok(())
264 }
265
266 #[instrument(skip_all)]
267 async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
268 let res = sqlx::query(
273 "select client_id from history
274 where user_id = $1
275 and deleted_at is not null",
276 )
277 .bind(user.id)
278 .fetch_all(&self.pool)
279 .await
280 .map_err(fix_error)?;
281
282 let res = res
283 .iter()
284 .map(|row| row.get::<String, _>("client_id"))
285 .collect();
286
287 Ok(res)
288 }
289
290 #[instrument(skip_all)]
291 async fn count_history_range(
292 &self,
293 user: &User,
294 range: Range<OffsetDateTime>,
295 ) -> DbResult<i64> {
296 let res: (i64,) = sqlx::query_as(
297 "select count(1) from history
298 where user_id = $1
299 and timestamp >= $2::date
300 and timestamp < $3::date",
301 )
302 .bind(user.id)
303 .bind(into_utc(range.start))
304 .bind(into_utc(range.end))
305 .fetch_one(&self.pool)
306 .await
307 .map_err(fix_error)?;
308
309 Ok(res.0)
310 }
311
312 #[instrument(skip_all)]
313 async fn list_history(
314 &self,
315 user: &User,
316 created_after: OffsetDateTime,
317 since: OffsetDateTime,
318 host: &str,
319 page_size: i64,
320 ) -> DbResult<Vec<History>> {
321 let res = sqlx::query_as(
322 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
323 where user_id = $1
324 and hostname != $2
325 and created_at >= $3
326 and timestamp >= $4
327 order by timestamp asc
328 limit $5",
329 )
330 .bind(user.id)
331 .bind(host)
332 .bind(into_utc(created_after))
333 .bind(into_utc(since))
334 .bind(page_size)
335 .fetch(&self.pool)
336 .map_ok(|DbHistory(h)| h)
337 .try_collect()
338 .await
339 .map_err(fix_error)?;
340
341 Ok(res)
342 }
343
344 #[instrument(skip_all)]
345 async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
346 let mut tx = self.pool.begin().await.map_err(fix_error)?;
347
348 for i in history {
349 let client_id: &str = &i.client_id;
350 let hostname: &str = &i.hostname;
351 let data: &str = &i.data;
352
353 sqlx::query(
354 "insert into history
355 (client_id, user_id, hostname, timestamp, data)
356 values ($1, $2, $3, $4, $5)
357 on conflict do nothing
358 ",
359 )
360 .bind(client_id)
361 .bind(i.user_id)
362 .bind(hostname)
363 .bind(i.timestamp)
364 .bind(data)
365 .execute(&mut *tx)
366 .await
367 .map_err(fix_error)?;
368 }
369
370 tx.commit().await.map_err(fix_error)?;
371
372 Ok(())
373 }
374
375 #[instrument(skip_all)]
376 async fn delete_user(&self, u: &User) -> DbResult<()> {
377 sqlx::query("delete from sessions where user_id = $1")
378 .bind(u.id)
379 .execute(&self.pool)
380 .await
381 .map_err(fix_error)?;
382
383 sqlx::query("delete from history where user_id = $1")
384 .bind(u.id)
385 .execute(&self.pool)
386 .await
387 .map_err(fix_error)?;
388
389 sqlx::query("delete from store where user_id = $1")
390 .bind(u.id)
391 .execute(&self.pool)
392 .await
393 .map_err(fix_error)?;
394
395 sqlx::query("delete from user_verification_token where user_id = $1")
396 .bind(u.id)
397 .execute(&self.pool)
398 .await
399 .map_err(fix_error)?;
400
401 sqlx::query("delete from total_history_count_user where user_id = $1")
402 .bind(u.id)
403 .execute(&self.pool)
404 .await
405 .map_err(fix_error)?;
406
407 sqlx::query("delete from users where id = $1")
408 .bind(u.id)
409 .execute(&self.pool)
410 .await
411 .map_err(fix_error)?;
412
413 Ok(())
414 }
415
416 #[instrument(skip_all)]
417 async fn update_user_password(&self, user: &User) -> DbResult<()> {
418 sqlx::query(
419 "update users
420 set password = $1
421 where id = $2",
422 )
423 .bind(&user.password)
424 .bind(user.id)
425 .execute(&self.pool)
426 .await
427 .map_err(fix_error)?;
428
429 Ok(())
430 }
431
432 #[instrument(skip_all)]
433 async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
434 let email: &str = &user.email;
435 let username: &str = &user.username;
436 let password: &str = &user.password;
437
438 let res: (i64,) = sqlx::query_as(
439 "insert into users
440 (username, email, password)
441 values($1, $2, $3)
442 returning id",
443 )
444 .bind(username)
445 .bind(email)
446 .bind(password)
447 .fetch_one(&self.pool)
448 .await
449 .map_err(fix_error)?;
450
451 Ok(res.0)
452 }
453
454 #[instrument(skip_all)]
455 async fn add_session(&self, session: &NewSession) -> DbResult<()> {
456 let token: &str = &session.token;
457
458 sqlx::query(
459 "insert into sessions
460 (user_id, token)
461 values($1, $2)",
462 )
463 .bind(session.user_id)
464 .bind(token)
465 .execute(&self.pool)
466 .await
467 .map_err(fix_error)?;
468
469 Ok(())
470 }
471
472 #[instrument(skip_all)]
473 async fn get_user_session(&self, u: &User) -> DbResult<Session> {
474 sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
475 .bind(u.id)
476 .fetch_one(&self.pool)
477 .await
478 .map_err(fix_error)
479 .map(|DbSession(session)| session)
480 }
481
482 #[instrument(skip_all)]
483 async fn oldest_history(&self, user: &User) -> DbResult<History> {
484 sqlx::query_as(
485 "select id, client_id, user_id, hostname, timestamp, data, created_at from history
486 where user_id = $1
487 order by timestamp asc
488 limit 1",
489 )
490 .bind(user.id)
491 .fetch_one(&self.pool)
492 .await
493 .map_err(fix_error)
494 .map(|DbHistory(h)| h)
495 }
496
497 #[instrument(skip_all)]
498 async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
499 let mut tx = self.pool.begin().await.map_err(fix_error)?;
500
501 let mut heads = HashMap::<(HostId, &str), u64>::new();
509
510 for i in records {
511 let id = atuin_common::utils::uuid_v7();
512
513 sqlx::query(
514 "insert into store
515 (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
516 values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
517 on conflict do nothing
518 ",
519 )
520 .bind(id)
521 .bind(i.id)
522 .bind(i.host.id)
523 .bind(i.idx as i64)
524 .bind(i.timestamp as i64) .bind(&i.version)
526 .bind(&i.tag)
527 .bind(&i.data.data)
528 .bind(&i.data.content_encryption_key)
529 .bind(user.id)
530 .execute(&mut *tx)
531 .await
532 .map_err(fix_error)?;
533
534 heads
536 .entry((i.host.id, &i.tag))
537 .and_modify(|e| {
538 if i.idx > *e {
539 *e = i.idx
540 }
541 })
542 .or_insert(i.idx);
543 }
544
545 for ((host, tag), idx) in heads {
547 sqlx::query(
548 "insert into store_idx_cache
549 (user_id, host, tag, idx)
550 values ($1, $2, $3, $4)
551 on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
552 ",
553 )
554 .bind(user.id)
555 .bind(host)
556 .bind(tag)
557 .bind(idx as i64)
558 .execute(&mut *tx)
559 .await
560 .map_err(fix_error)?;
561 }
562
563 tx.commit().await.map_err(fix_error)?;
564
565 Ok(())
566 }
567
568 #[instrument(skip_all)]
569 async fn next_records(
570 &self,
571 user: &User,
572 host: HostId,
573 tag: String,
574 start: Option<RecordIdx>,
575 count: u64,
576 ) -> DbResult<Vec<Record<EncryptedData>>> {
577 tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
578 let start = start.unwrap_or(0);
579
580 let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
581 "select client_id, host, idx, timestamp, version, tag, data, cek from store
582 where user_id = $1
583 and tag = $2
584 and host = $3
585 and idx >= $4
586 order by idx asc
587 limit $5",
588 )
589 .bind(user.id)
590 .bind(tag.clone())
591 .bind(host)
592 .bind(start as i64)
593 .bind(count as i64)
594 .fetch_all(&self.pool)
595 .await
596 .map_err(fix_error);
597
598 let ret = match records {
599 Ok(records) => {
600 let records: Vec<Record<EncryptedData>> = records
601 .into_iter()
602 .map(|f| {
603 let record: Record<EncryptedData> = f.into();
604 record
605 })
606 .collect();
607
608 records
609 }
610 Err(DbError::NotFound) => {
611 tracing::debug!("no records found in store: {:?}/{}", host, tag);
612 return Ok(vec![]);
613 }
614 Err(e) => return Err(e),
615 };
616
617 Ok(ret)
618 }
619
620 async fn status(&self, user: &User) -> DbResult<RecordStatus> {
621 const STATUS_SQL: &str =
622 "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
623
624 let mut res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
625 .bind(user.id)
626 .fetch_all(&self.pool)
627 .await
628 .map_err(fix_error)?;
629 res.sort();
630
631 let mut cached_res: Vec<(Uuid, String, i64)> =
638 sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
639 .bind(user.id)
640 .fetch_all(&self.pool)
641 .await
642 .map_err(fix_error)?;
643 cached_res.sort();
644
645 let mut status = RecordStatus::new();
646
647 let equal = res == cached_res;
648
649 if equal {
650 counter!("atuin_store_idx_cache_consistent", 1);
651 } else {
652 tracing::debug!(user = user.username, cache_match = equal, res = ?res, cached = ?cached_res, "record store index request");
654 counter!("atuin_store_idx_cache_inconsistent", 1);
655 };
656
657 for i in res.iter() {
658 status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
659 }
660
661 Ok(status)
662 }
663}
664
665fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
666 let x = x.to_offset(UtcOffset::UTC);
667 PrimitiveDateTime::new(x.date(), x.time())
668}
669
#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    /// `into_utc` shifts to UTC and drops the offset; re-attaching UTC must give back an
    /// instant equal to the input.
    #[test]
    fn utc() {
        let cases = [
            // positive, fractional-hour offset
            (
                datetime!(2023-09-26 15:11:02 +05:30),
                datetime!(2023-09-26 09:41:02),
            ),
            // negative offset
            (
                datetime!(2023-09-26 15:11:02 -07:00),
                datetime!(2023-09-26 22:11:02),
            ),
            // already UTC
            (
                datetime!(2023-09-26 15:11:02 +00:00),
                datetime!(2023-09-26 15:11:02),
            ),
        ];

        for &(offset_dt, expected) in &cases {
            let naive = into_utc(offset_dt);
            assert_eq!(naive, expected);
            assert_eq!(naive.assume_utc(), offset_dt);
        }
    }
}