atuin_server_postgres/
lib.rs

use std::collections::HashMap;
use std::ops::Range;

use rand::Rng;

use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use atuin_common::utils::crypto_random_string;
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult, DbSettings};
use futures_util::TryStreamExt;
use sqlx::Row;
use sqlx::postgres::PgPoolOptions;

use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::{instrument, trace};
use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};

mod wrappers;

const MIN_PG_VERSION: u32 = 14;

#[derive(Clone)]
pub struct Postgres {
    pool: sqlx::Pool<sqlx::postgres::Postgres>,
}

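/// Map a sqlx error onto the backend-agnostic `DbError`: `RowNotFound`
/// becomes `DbError::NotFound`; anything else is wrapped in `DbError::Other`.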
fn fix_error(error: sqlx::Error) -> DbError {
    match error {
        sqlx::Error::RowNotFound => DbError::NotFound,
        error => DbError::Other(error.into()),
    }
}

#[async_trait]
impl Database for Postgres {
    async fn new(settings: &DbSettings) -> DbResult<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(100)
            .connect(settings.db_uri.as_str())
            .await
            .map_err(fix_error)?;

        // Use server_version_num to get the DB server's version number, then
        // derive the major version. The call returns None for servers older than 8.x.
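        // e.g. PostgreSQL 14.5 reports 140005, and 140005 / 10000 == 14.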
        let pg_major_version: u32 = pool
            .acquire()
            .await
            .map_err(fix_error)?
            .server_version_num()
            .ok_or(DbError::Other(eyre::Report::msg(
                "could not get PostgreSQL version",
            )))?
            / 10000;

        if pg_major_version < MIN_PG_VERSION {
            return Err(DbError::Other(eyre::Report::msg(format!(
                "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
            ))));
        }

        sqlx::migrate!("./migrations")
            .run(&pool)
            .await
            .map_err(|error| DbError::Other(error.into()))?;

        Ok(Self { pool })
    }

    #[instrument(skip_all)]
    async fn get_session(&self, token: &str) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where token = $1")
            .bind(token)
            .fetch_one(&self.pool)
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn get_user(&self, username: &str) -> DbResult<User> {
        sqlx::query_as(
            "select id, username, email, password, verified_at from users where username = $1",
        )
        .bind(username)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)
        .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn user_verified(&self, id: i64) -> DbResult<bool> {
        let res: (bool,) =
            sqlx::query_as("select verified_at is not null from users where id = $1")
                .bind(id)
                .fetch_one(&self.pool)
                .await
                .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn verify_user(&self, id: i64) -> DbResult<()> {
        sqlx::query(
            "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
        )
        .bind(id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    /// Return a valid verification token for the user.
    /// If the user has no token, create one, insert it, and return it.
    /// If the user's token has expired, replace it with a new one and return that.
    /// If the user already has a valid token, return it unchanged.
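    ///
    /// A sketch of the expected call pattern (hypothetical caller, not part of
    /// this crate):
    /// ```ignore
    /// let token = db.user_verification_token(user.id).await?;
    /// // mail `token` to the user; a second call within 15 minutes returns
    /// // the same token rather than minting a new one
    /// ```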
    #[instrument(skip_all)]
    async fn user_verification_token(&self, id: i64) -> DbResult<String> {
        const TOKEN_VALID_MINUTES: i64 = 15;

        // First we check if there is a verification token
        let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
            "select token, valid_until from user_verification_token where user_id = $1",
        )
        .bind(id)
        .fetch_optional(&self.pool)
        .await
        .map_err(fix_error)?;

        let token = if let Some((token, valid_until)) = token {
            trace!("Token for user {id} valid until {valid_until}");

            // We have a token, AND it's still valid
            if valid_until > time::OffsetDateTime::now_utc() {
                token
            } else {
                // token has expired. generate a new one, return it
                let token = crypto_random_string::<24>();

                sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
                    .bind(id)
                    .bind(&token)
                    .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                    .execute(&self.pool)
                    .await
                    .map_err(fix_error)?;

                token
            }
        } else {
            // No token in the database! Generate one, insert it
            let token = crypto_random_string::<24>();

            sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
                .bind(id)
                .bind(&token)
                .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                .execute(&self.pool)
                .await
                .map_err(fix_error)?;

            token
        };

        Ok(token)
    }

    #[instrument(skip_all)]
    async fn get_session_user(&self, token: &str) -> DbResult<User> {
        sqlx::query_as(
            "select users.id, users.username, users.email, users.password, users.verified_at from users
            inner join sessions
            on users.id = sessions.user_id
            and sessions.token = $1",
        )
        .bind(token)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)
        .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn count_history(&self, user: &User) -> DbResult<i64> {
        // The cache is new, and the user might not yet have a cache value.
        // They will have one as soon as they post up some new history, but handle that
        // edge case.

        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn total_history(&self) -> DbResult<i64> {
        // sum() over an empty table yields a single row containing NULL, which
        // does not decode as i64 - coalesce it to 0 so a fresh install works.
        let res: (i64,) =
            sqlx::query_as("select coalesce(sum(total), 0) from total_history_count_user")
                .fetch_one(&self.pool)
                .await
                .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
        let res: (i32,) = sqlx::query_as(
            "select total from total_history_count_user
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0 as i64)
    }

    async fn delete_store(&self, user: &User) -> DbResult<()> {
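        // Clear the user's store and its idx cache in a single transaction so
        // the two tables cannot drift out of sync if either delete fails.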
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        sqlx::query(
            "delete from store
            where user_id = $1",
        )
        .bind(user.id)
        .execute(&mut *tx)
        .await
        .map_err(fix_error)?;

        sqlx::query(
            "delete from store_idx_cache
            where user_id = $1",
        )
        .bind(user.id)
        .execute(&mut *tx)
        .await
        .map_err(fix_error)?;

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
        sqlx::query(
            "update history
            set deleted_at = $3
            where user_id = $1
            and client_id = $2
            and deleted_at is null", // don't just keep setting it
        )
        .bind(user.id)
        .bind(id)
        .bind(OffsetDateTime::now_utc())
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
        let res = sqlx::query(
            "select client_id from history
            where user_id = $1
            and deleted_at is not null",
        )
        .bind(user.id)
        .fetch_all(&self.pool)
        .await
        .map_err(fix_error)?;

        let res = res
            .iter()
            .map(|row| row.get::<String, _>("client_id"))
            .collect();

        Ok(res)
    }

    #[instrument(skip_all)]
    async fn count_history_range(
        &self,
        user: &User,
        range: Range<OffsetDateTime>,
    ) -> DbResult<i64> {
        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1
            and timestamp >= $2::date
            and timestamp < $3::date",
        )
        .bind(user.id)
        .bind(into_utc(range.start))
        .bind(into_utc(range.end))
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn list_history(
        &self,
        user: &User,
        created_after: OffsetDateTime,
        since: OffsetDateTime,
        host: &str,
        page_size: i64,
    ) -> DbResult<Vec<History>> {
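        // Page through history from hosts other than the requesting one,
        // oldest entries first.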
        let res = sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            and hostname != $2
            and created_at >= $3
            and timestamp >= $4
            order by timestamp asc
            limit $5",
        )
        .bind(user.id)
        .bind(host)
        .bind(into_utc(created_after))
        .bind(into_utc(since))
        .bind(page_size)
        .fetch(&self.pool)
        .map_ok(|DbHistory(h)| h)
        .try_collect()
        .await
        .map_err(fix_error)?;

        Ok(res)
    }

    #[instrument(skip_all)]
    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

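        // Insert one row per item; `on conflict do nothing` makes re-uploads
        // of already-synced history idempotent.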
        for i in history {
            let client_id: &str = &i.client_id;
            let hostname: &str = &i.hostname;
            let data: &str = &i.data;

            sqlx::query(
                "insert into history
                    (client_id, user_id, hostname, timestamp, data)
                values ($1, $2, $3, $4, $5)
                on conflict do nothing
                ",
            )
            .bind(client_id)
            .bind(i.user_id)
            .bind(hostname)
            .bind(i.timestamp)
            .bind(data)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn delete_user(&self, u: &User) -> DbResult<()> {
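        // Remove the user's dependent rows first; the users row itself goes last.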
        sqlx::query("delete from sessions where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from history where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from store where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from user_verification_token where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from total_history_count_user where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from users where id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn update_user_password(&self, user: &User) -> DbResult<()> {
        sqlx::query(
            "update users
            set password = $1
            where id = $2",
        )
        .bind(&user.password)
        .bind(user.id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
        let email: &str = &user.email;
        let username: &str = &user.username;
        let password: &str = &user.password;

        let res: (i64,) = sqlx::query_as(
            "insert into users
                (username, email, password)
            values($1, $2, $3)
            returning id",
        )
        .bind(username)
        .bind(email)
        .bind(password)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
        let token: &str = &session.token;

        sqlx::query(
            "insert into sessions
                (user_id, token)
            values($1, $2)",
        )
        .bind(session.user_id)
        .bind(token)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
            .bind(u.id)
            .fetch_one(&self.pool)
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn oldest_history(&self, user: &User) -> DbResult<History> {
        sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            order by timestamp asc
            limit 1",
        )
        .bind(user.id)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)
        .map(|DbHistory(h)| h)
    }

    #[instrument(skip_all)]
    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max
        // idx without having to make further database queries. Doing the query on this small
        // amount of data should be much, much faster.
        //
        // Worst case, say we get this wrong. We end up caching data that isn't actually the max
        // idx, so clients upload again. The cache logic can be verified with a sql query anyway :)
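        //
        // e.g. pushing records with idx 41 and 42 for (host A, tag "history")
        // leaves heads[(A, "history")] == 42, which is flushed to the cache below.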

        let mut heads = HashMap::<(HostId, &str), u64>::new();

        for i in records {
            let id = atuin_common::utils::uuid_v7();

            let result = sqlx::query(
                "insert into store
                    (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                on conflict do nothing
                ",
            )
            .bind(id)
            .bind(i.id)
            .bind(i.host.id)
            .bind(i.idx as i64)
            .bind(i.timestamp as i64) // the cast from u64 narrows the range, but i64 still covers timestamps far into the future
            .bind(&i.version)
            .bind(&i.tag)
            .bind(&i.data.data)
            .bind(&i.data.content_encryption_key)
            .bind(user.id)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;

            // Only update heads if we actually inserted the record
            if result.rows_affected() > 0 {
                heads
                    .entry((i.host.id, &i.tag))
                    .and_modify(|e| {
                        if i.idx > *e {
                            *e = i.idx
                        }
                    })
                    .or_insert(i.idx);
            }
        }

        // we've built the map of heads for this push, so commit it to the database
        for ((host, tag), idx) in heads {
            sqlx::query(
                "insert into store_idx_cache
                    (user_id, host, tag, idx)
                values ($1, $2, $3, $4)
                on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
                ",
            )
            .bind(user.id)
            .bind(host)
            .bind(tag)
            .bind(idx as i64)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn next_records(
        &self,
        user: &User,
        host: HostId,
        tag: String,
        start: Option<RecordIdx>,
        count: u64,
    ) -> DbResult<Vec<Record<EncryptedData>>> {
        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
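        // A missing start idx means the caller has nothing for this host/tag
        // yet, so begin from idx 0.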
        let start = start.unwrap_or(0);

        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
            "select client_id, host, idx, timestamp, version, tag, data, cek from store
                    where user_id = $1
                    and tag = $2
                    and host = $3
                    and idx >= $4
                    order by idx asc
                    limit $5",
        )
        .bind(user.id)
        .bind(tag.clone())
        .bind(host)
        .bind(start as i64)
        .bind(count as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(fix_error);

        let ret = match records {
            Ok(records) => {
                let records: Vec<Record<EncryptedData>> = records
                    .into_iter()
                    .map(|f| {
                        let record: Record<EncryptedData> = f.into();
                        record
                    })
                    .collect();

                records
            }
            Err(DbError::NotFound) => {
                tracing::debug!("no records found in store: {:?}/{}", host, tag);
                return Ok(vec![]);
            }
            Err(e) => return Err(e),
        };

        Ok(ret)
    }

    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
        const STATUS_SQL: &str =
            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";

        // If IDX_CACHE_ROLLOUT is set, then we
        // 1. Read the value of the var, use it as a % chance of using the cache
        // 2. If we use the cache, just read from the cache table
        // 3. If we don't use the cache, read from the store table
        // IDX_CACHE_ROLLOUT should be between 0 and 100.
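        // e.g. IDX_CACHE_ROLLOUT=25 routes roughly a quarter of status calls
        // to the cache table.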

        let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string());
        let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
        // gen_bool panics outside [0.0, 1.0], so clamp misconfigured values
        let use_idx_cache = rand::thread_rng().gen_bool((idx_cache_rollout / 100.0).clamp(0.0, 1.0));

        let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
            tracing::debug!("using idx cache for user {}", user.id);
            sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
                .bind(user.id)
                .fetch_all(&self.pool)
                .await
                .map_err(fix_error)?
        } else {
            tracing::debug!("using aggregate query for user {}", user.id);
            sqlx::query_as(STATUS_SQL)
                .bind(user.id)
                .fetch_all(&self.pool)
                .await
                .map_err(fix_error)?
        };

        res.sort();

        let mut status = RecordStatus::new();

        for i in res.iter() {
            status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
        }

        Ok(status)
    }
}

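/// Normalize a datetime to UTC and strip the offset, so it binds consistently
/// as a naive UTC timestamp in queries.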
fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
    let x = x.to_offset(UtcOffset::UTC);
    PrimitiveDateTime::new(x.date(), x.time())
}

#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    #[test]
    fn utc() {
        let dt = datetime!(2023-09-26 15:11:02 +05:30);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 -07:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 +00:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);
    }
}