atuin_server_postgres/
lib.rs

1use std::collections::HashMap;
2use std::ops::Range;
3
4use rand::Rng;
5
6use async_trait::async_trait;
7use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
8use atuin_common::utils::crypto_random_string;
9use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
10use atuin_server_database::{Database, DbError, DbResult, DbSettings};
11use futures_util::TryStreamExt;
12use sqlx::Row;
13use sqlx::postgres::PgPoolOptions;
14
15use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
16use tracing::{instrument, trace};
17use uuid::Uuid;
18use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
19
20mod wrappers;
21
/// Minimum supported PostgreSQL server major version.
const MIN_PG_VERSION: u32 = 14;
23
#[derive(Clone)]
pub struct Postgres {
    /// Primary pool; all writes go here, and reads too when no replica is configured
    pool: sqlx::Pool<sqlx::postgres::Postgres>,
    /// Optional read replica pool for read-only queries
    read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
}
30
31impl Postgres {
32    /// Returns the appropriate pool for read operations.
33    /// Uses read_pool if available, otherwise falls back to the primary pool.
34    fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
35        self.read_pool.as_ref().unwrap_or(&self.pool)
36    }
37}
38
39fn fix_error(error: sqlx::Error) -> DbError {
40    match error {
41        sqlx::Error::RowNotFound => DbError::NotFound,
42        error => DbError::Other(error.into()),
43    }
44}
45
46#[async_trait]
47impl Database for Postgres {
    /// Connect to the primary database, verify its version, run migrations,
    /// and optionally connect to a read replica (verifying its version too).
    async fn new(settings: &DbSettings) -> DbResult<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(100)
            .connect(settings.db_uri.as_str())
            .await
            .map_err(fix_error)?;

        // Call server_version_num to get the DB server's major version number
        // The call returns None for servers older than 8.x.
        // server_version_num encodes e.g. 14.2 as 140002, so /10000 yields the major.
        let pg_major_version: u32 = pool
            .acquire()
            .await
            .map_err(fix_error)?
            .server_version_num()
            .ok_or(DbError::Other(eyre::Report::msg(
                "could not get PostgreSQL version",
            )))?
            / 10000;

        if pg_major_version < MIN_PG_VERSION {
            return Err(DbError::Other(eyre::Report::msg(format!(
                "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
            ))));
        }

        // Migrations run against the primary only; the replica receives the
        // schema via replication.
        sqlx::migrate!("./migrations")
            .run(&pool)
            .await
            .map_err(|error| DbError::Other(error.into()))?;

        // Create read replica pool if configured
        let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
            tracing::info!("Connecting to read replica database");
            let read_pool = PgPoolOptions::new()
                .max_connections(100)
                .connect(read_db_uri.as_str())
                .await
                .map_err(fix_error)?;

            // Verify the read replica is also a supported PostgreSQL version
            let read_pg_major_version: u32 = read_pool
                .acquire()
                .await
                .map_err(fix_error)?
                .server_version_num()
                .ok_or(DbError::Other(eyre::Report::msg(
                    "could not get PostgreSQL version from read replica",
                )))?
                / 10000;

            if read_pg_major_version < MIN_PG_VERSION {
                return Err(DbError::Other(eyre::Report::msg(format!(
                    "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
                ))));
            }

            Some(read_pool)
        } else {
            None
        };

        Ok(Self { pool, read_pool })
    }
111
112    #[instrument(skip_all)]
113    async fn get_session(&self, token: &str) -> DbResult<Session> {
114        sqlx::query_as("select id, user_id, token from sessions where token = $1")
115            .bind(token)
116            .fetch_one(self.read_pool())
117            .await
118            .map_err(fix_error)
119            .map(|DbSession(session)| session)
120    }
121
122    #[instrument(skip_all)]
123    async fn get_user(&self, username: &str) -> DbResult<User> {
124        sqlx::query_as(
125            "select id, username, email, password, verified_at from users where username = $1",
126        )
127        .bind(username)
128        .fetch_one(self.read_pool())
129        .await
130        .map_err(fix_error)
131        .map(|DbUser(user)| user)
132    }
133
134    #[instrument(skip_all)]
135    async fn user_verified(&self, id: i64) -> DbResult<bool> {
136        let res: (bool,) =
137            sqlx::query_as("select verified_at is not null from users where id = $1")
138                .bind(id)
139                .fetch_one(self.read_pool())
140                .await
141                .map_err(fix_error)?;
142
143        Ok(res.0)
144    }
145
146    #[instrument(skip_all)]
147    async fn verify_user(&self, id: i64) -> DbResult<()> {
148        sqlx::query(
149            "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
150        )
151        .bind(id)
152        .execute(&self.pool)
153        .await
154        .map_err(fix_error)?;
155
156        Ok(())
157    }
158
    /// Return a valid verification token for the user
    /// If the user does not have any token, create one, insert it, and return
    /// If the user has a token, but it's expired, replace it with a fresh one and return that
    /// If the user already has a valid token, return it
    #[instrument(skip_all)]
    async fn user_verification_token(&self, id: i64) -> DbResult<String> {
        // How long a freshly-issued token remains valid.
        const TOKEN_VALID_MINUTES: i64 = 15;

        // First we check if there is a verification token
        let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
            "select token, valid_until from user_verification_token where user_id = $1",
        )
        .bind(id)
        .fetch_optional(&self.pool)
        .await
        .map_err(fix_error)?;

        let token = if let Some((token, valid_until)) = token {
            trace!("Token for user {id} valid until {valid_until}");

            // We have a token, AND it's still valid
            if valid_until > time::OffsetDateTime::now_utc() {
                token
            } else {
                // token has expired. generate a new one, return it
                // (the row is updated in place rather than deleted+reinserted)
                let token = crypto_random_string::<24>();

                sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
                    .bind(id)
                    .bind(&token)
                    .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                    .execute(&self.pool)
                    .await
                    .map_err(fix_error)?;

                token
            }
        } else {
            // No token in the database! Generate one, insert it
            let token = crypto_random_string::<24>();

            sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
                .bind(id)
                .bind(&token)
                .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
                .execute(&self.pool)
                .await
                .map_err(fix_error)?;

            token
        };

        Ok(token)
    }
213
214    #[instrument(skip_all)]
215    async fn get_session_user(&self, token: &str) -> DbResult<User> {
216        sqlx::query_as(
217            "select users.id, users.username, users.email, users.password, users.verified_at from users
218            inner join sessions
219            on users.id = sessions.user_id
220            and sessions.token = $1",
221        )
222        .bind(token)
223        .fetch_one(self.read_pool())
224        .await
225        .map_err(fix_error)
226        .map(|DbUser(user)| user)
227    }
228
229    #[instrument(skip_all)]
230    async fn count_history(&self, user: &User) -> DbResult<i64> {
231        // The cache is new, and the user might not yet have a cache value.
232        // They will have one as soon as they post up some new history, but handle that
233        // edge case.
234
235        let res: (i64,) = sqlx::query_as(
236            "select count(1) from history
237            where user_id = $1",
238        )
239        .bind(user.id)
240        .fetch_one(self.read_pool())
241        .await
242        .map_err(fix_error)?;
243
244        Ok(res.0)
245    }
246
247    #[instrument(skip_all)]
248    async fn total_history(&self) -> DbResult<i64> {
249        // The cache is new, and the user might not yet have a cache value.
250        // They will have one as soon as they post up some new history, but handle that
251        // edge case.
252
253        let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
254            .fetch_optional(self.read_pool())
255            .await
256            .map_err(fix_error)?
257            .unwrap_or((0,));
258
259        Ok(res.0)
260    }
261
262    #[instrument(skip_all)]
263    async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
264        let res: (i32,) = sqlx::query_as(
265            "select total from total_history_count_user
266            where user_id = $1",
267        )
268        .bind(user.id)
269        .fetch_one(self.read_pool())
270        .await
271        .map_err(fix_error)?;
272
273        Ok(res.0 as i64)
274    }
275
276    async fn delete_store(&self, user: &User) -> DbResult<()> {
277        let mut tx = self.pool.begin().await.map_err(fix_error)?;
278
279        sqlx::query(
280            "delete from store
281            where user_id = $1",
282        )
283        .bind(user.id)
284        .execute(&mut *tx)
285        .await
286        .map_err(fix_error)?;
287
288        sqlx::query(
289            "delete from store_idx_cache
290            where user_id = $1",
291        )
292        .bind(user.id)
293        .execute(&mut *tx)
294        .await
295        .map_err(fix_error)?;
296
297        tx.commit().await.map_err(fix_error)?;
298
299        Ok(())
300    }
301
302    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
303        sqlx::query(
304            "update history
305            set deleted_at = $3
306            where user_id = $1
307            and client_id = $2
308            and deleted_at is null", // don't just keep setting it
309        )
310        .bind(user.id)
311        .bind(id)
312        .bind(OffsetDateTime::now_utc())
313        .fetch_all(&self.pool)
314        .await
315        .map_err(fix_error)?;
316
317        Ok(())
318    }
319
320    #[instrument(skip_all)]
321    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
322        // The cache is new, and the user might not yet have a cache value.
323        // They will have one as soon as they post up some new history, but handle that
324        // edge case.
325
326        let res = sqlx::query(
327            "select client_id from history
328            where user_id = $1
329            and deleted_at is not null",
330        )
331        .bind(user.id)
332        .fetch_all(self.read_pool())
333        .await
334        .map_err(fix_error)?;
335
336        let res = res
337            .iter()
338            .map(|row| row.get::<String, _>("client_id"))
339            .collect();
340
341        Ok(res)
342    }
343
344    #[instrument(skip_all)]
345    async fn count_history_range(
346        &self,
347        user: &User,
348        range: Range<OffsetDateTime>,
349    ) -> DbResult<i64> {
350        let res: (i64,) = sqlx::query_as(
351            "select count(1) from history
352            where user_id = $1
353            and timestamp >= $2::date
354            and timestamp < $3::date",
355        )
356        .bind(user.id)
357        .bind(into_utc(range.start))
358        .bind(into_utc(range.end))
359        .fetch_one(self.read_pool())
360        .await
361        .map_err(fix_error)?;
362
363        Ok(res.0)
364    }
365
366    #[instrument(skip_all)]
367    async fn list_history(
368        &self,
369        user: &User,
370        created_after: OffsetDateTime,
371        since: OffsetDateTime,
372        host: &str,
373        page_size: i64,
374    ) -> DbResult<Vec<History>> {
375        let res = sqlx::query_as(
376            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
377            where user_id = $1
378            and hostname != $2
379            and created_at >= $3
380            and timestamp >= $4
381            order by timestamp asc
382            limit $5",
383        )
384        .bind(user.id)
385        .bind(host)
386        .bind(into_utc(created_after))
387        .bind(into_utc(since))
388        .bind(page_size)
389        .fetch(self.read_pool())
390        .map_ok(|DbHistory(h)| h)
391        .try_collect()
392        .await
393        .map_err(fix_error)?;
394
395        Ok(res)
396    }
397
398    #[instrument(skip_all)]
399    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
400        let mut tx = self.pool.begin().await.map_err(fix_error)?;
401
402        for i in history {
403            let client_id: &str = &i.client_id;
404            let hostname: &str = &i.hostname;
405            let data: &str = &i.data;
406
407            sqlx::query(
408                "insert into history
409                    (client_id, user_id, hostname, timestamp, data) 
410                values ($1, $2, $3, $4, $5)
411                on conflict do nothing
412                ",
413            )
414            .bind(client_id)
415            .bind(i.user_id)
416            .bind(hostname)
417            .bind(i.timestamp)
418            .bind(data)
419            .execute(&mut *tx)
420            .await
421            .map_err(fix_error)?;
422        }
423
424        tx.commit().await.map_err(fix_error)?;
425
426        Ok(())
427    }
428
429    #[instrument(skip_all)]
430    async fn delete_user(&self, u: &User) -> DbResult<()> {
431        sqlx::query("delete from sessions where user_id = $1")
432            .bind(u.id)
433            .execute(&self.pool)
434            .await
435            .map_err(fix_error)?;
436
437        sqlx::query("delete from history where user_id = $1")
438            .bind(u.id)
439            .execute(&self.pool)
440            .await
441            .map_err(fix_error)?;
442
443        sqlx::query("delete from store where user_id = $1")
444            .bind(u.id)
445            .execute(&self.pool)
446            .await
447            .map_err(fix_error)?;
448
449        sqlx::query("delete from user_verification_token where user_id = $1")
450            .bind(u.id)
451            .execute(&self.pool)
452            .await
453            .map_err(fix_error)?;
454
455        sqlx::query("delete from total_history_count_user where user_id = $1")
456            .bind(u.id)
457            .execute(&self.pool)
458            .await
459            .map_err(fix_error)?;
460
461        sqlx::query("delete from users where id = $1")
462            .bind(u.id)
463            .execute(&self.pool)
464            .await
465            .map_err(fix_error)?;
466
467        Ok(())
468    }
469
470    #[instrument(skip_all)]
471    async fn update_user_password(&self, user: &User) -> DbResult<()> {
472        sqlx::query(
473            "update users
474            set password = $1
475            where id = $2",
476        )
477        .bind(&user.password)
478        .bind(user.id)
479        .execute(&self.pool)
480        .await
481        .map_err(fix_error)?;
482
483        Ok(())
484    }
485
486    #[instrument(skip_all)]
487    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
488        let email: &str = &user.email;
489        let username: &str = &user.username;
490        let password: &str = &user.password;
491
492        let res: (i64,) = sqlx::query_as(
493            "insert into users
494                (username, email, password)
495            values($1, $2, $3)
496            returning id",
497        )
498        .bind(username)
499        .bind(email)
500        .bind(password)
501        .fetch_one(&self.pool)
502        .await
503        .map_err(fix_error)?;
504
505        Ok(res.0)
506    }
507
508    #[instrument(skip_all)]
509    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
510        let token: &str = &session.token;
511
512        sqlx::query(
513            "insert into sessions
514                (user_id, token)
515            values($1, $2)",
516        )
517        .bind(session.user_id)
518        .bind(token)
519        .execute(&self.pool)
520        .await
521        .map_err(fix_error)?;
522
523        Ok(())
524    }
525
526    #[instrument(skip_all)]
527    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
528        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
529            .bind(u.id)
530            .fetch_one(self.read_pool())
531            .await
532            .map_err(fix_error)
533            .map(|DbSession(session)| session)
534    }
535
536    #[instrument(skip_all)]
537    async fn oldest_history(&self, user: &User) -> DbResult<History> {
538        sqlx::query_as(
539            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
540            where user_id = $1
541            order by timestamp asc
542            limit 1",
543        )
544        .bind(user.id)
545        .fetch_one(self.read_pool())
546        .await
547        .map_err(fix_error)
548        .map(|DbHistory(h)| h)
549    }
550
    /// Insert a batch of encrypted records for the user and refresh the
    /// per-(host, tag) max-idx cache, all inside one transaction.
    #[instrument(skip_all)]
    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max
        // idx without having to make further database queries. Doing the query on this small
        // amount of data should be much, much faster.
        //
        // Worst case, say we get this wrong. We end up caching data that isn't actually the max
        // idx, so clients upload again. The cache logic can be verified with a sql query anyway :)

        // Highest idx observed per (host, tag) among the rows actually inserted.
        let mut heads = HashMap::<(HostId, &str), u64>::new();

        for i in records {
            // Server-side row id; distinct from the client-supplied record id.
            let id = atuin_common::utils::uuid_v7();

            let result = sqlx::query(
                "insert into store
                    (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) 
                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                on conflict do nothing
                ",
            )
            .bind(id)
            .bind(i.id)
            .bind(i.host.id)
            .bind(i.idx as i64)
            .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
            .bind(&i.version)
            .bind(&i.tag)
            .bind(&i.data.data)
            .bind(&i.data.content_encryption_key)
            .bind(user.id)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;

            // Only update heads if we actually inserted the record
            // (on conflict do nothing reports 0 rows affected for duplicates)
            if result.rows_affected() > 0 {
                heads
                    .entry((i.host.id, &i.tag))
                    .and_modify(|e| {
                        if i.idx > *e {
                            *e = i.idx
                        }
                    })
                    .or_insert(i.idx);
            }
        }

        // we've built the map of heads for this push, so commit it to the database
        for ((host, tag), idx) in heads {
            sqlx::query(
                "insert into store_idx_cache
                    (user_id, host, tag, idx) 
                values ($1, $2, $3, $4)
                on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
                ",
            )
            .bind(user.id)
            .bind(host)
            .bind(tag)
            .bind(idx as i64)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }
623
624    #[instrument(skip_all)]
625    async fn next_records(
626        &self,
627        user: &User,
628        host: HostId,
629        tag: String,
630        start: Option<RecordIdx>,
631        count: u64,
632    ) -> DbResult<Vec<Record<EncryptedData>>> {
633        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
634        let start = start.unwrap_or(0);
635
636        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
637            "select client_id, host, idx, timestamp, version, tag, data, cek from store
638                    where user_id = $1
639                    and tag = $2
640                    and host = $3
641                    and idx >= $4
642                    order by idx asc
643                    limit $5",
644        )
645        .bind(user.id)
646        .bind(tag.clone())
647        .bind(host)
648        .bind(start as i64)
649        .bind(count as i64)
650        .fetch_all(self.read_pool())
651        .await
652        .map_err(fix_error);
653
654        let ret = match records {
655            Ok(records) => {
656                let records: Vec<Record<EncryptedData>> = records
657                    .into_iter()
658                    .map(|f| {
659                        let record: Record<EncryptedData> = f.into();
660                        record
661                    })
662                    .collect();
663
664                records
665            }
666            Err(DbError::NotFound) => {
667                tracing::debug!("no records found in store: {:?}/{}", host, tag);
668                return Ok(vec![]);
669            }
670            Err(e) => return Err(e),
671        };
672
673        Ok(ret)
674    }
675
676    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
677        const STATUS_SQL: &str =
678            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
679
680        // If IDX_CACHE_ROLLOUT is set, then we
681        // 1. Read the value of the var, use it as a % chance of using the cache
682        // 2. If we use the cache, just read from the cache table
683        // 3. If we don't use the cache, read from the store table
684        // IDX_CACHE_ROLLOUT should be between 0 and 100.
685
686        let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string());
687        let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
688        let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0);
689
690        let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
691            tracing::debug!("using idx cache for user {}", user.id);
692            sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
693                .bind(user.id)
694                .fetch_all(self.read_pool())
695                .await
696                .map_err(fix_error)?
697        } else {
698            tracing::debug!("using aggregate query for user {}", user.id);
699            sqlx::query_as(STATUS_SQL)
700                .bind(user.id)
701                .fetch_all(self.read_pool())
702                .await
703                .map_err(fix_error)?
704        };
705
706        res.sort();
707
708        let mut status = RecordStatus::new();
709
710        for i in res.iter() {
711            status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
712        }
713
714        Ok(status)
715    }
716}
717
718fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
719    let x = x.to_offset(UtcOffset::UTC);
720    PrimitiveDateTime::new(x.date(), x.time())
721}
722
#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    /// `into_utc` must shift wall-clock values to UTC and round-trip back to
    /// the original instant when the offset is re-assumed.
    #[test]
    fn utc() {
        let cases = [
            (
                datetime!(2023-09-26 15:11:02 +05:30),
                datetime!(2023-09-26 09:41:02),
            ),
            (
                datetime!(2023-09-26 15:11:02 -07:00),
                datetime!(2023-09-26 22:11:02),
            ),
            (
                datetime!(2023-09-26 15:11:02 +00:00),
                datetime!(2023-09-26 15:11:02),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(into_utc(input), expected);
            assert_eq!(into_utc(input).assume_utc(), input);
        }
    }
}
743}