// atuin_server_postgres/lib.rs

use std::collections::HashMap;
use std::ops::Range;

use rand::Rng;

use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult, DbSettings};
use futures_util::TryStreamExt;
use sqlx::Row;
use sqlx::postgres::PgPoolOptions;

use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::instrument;
use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};

mod wrappers;

const MIN_PG_VERSION: u32 = 14;

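/// PostgreSQL-backed implementation of the server [`Database`] trait, holding
/// a primary connection pool and an optional read replica pool.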
#[derive(Clone)]
pub struct Postgres {
    pool: sqlx::Pool<sqlx::postgres::Postgres>,
    /// Optional read replica pool for read-only queries
    read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>,
}

impl Postgres {
    /// Returns the appropriate pool for read operations.
    /// Uses read_pool if available, otherwise falls back to the primary pool.
    fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> {
        self.read_pool.as_ref().unwrap_or(&self.pool)
    }
}

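/// Map a [`sqlx::Error`] onto the backend-agnostic [`DbError`], keeping
/// "row not found" as a distinct case.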
fn fix_error(error: sqlx::Error) -> DbError {
    match error {
        sqlx::Error::RowNotFound => DbError::NotFound,
        error => DbError::Other(error.into()),
    }
}

#[async_trait]
impl Database for Postgres {
    async fn new(settings: &DbSettings) -> DbResult<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(100)
            .connect(settings.db_uri.as_str())
            .await
            .map_err(fix_error)?;

        // Call server_version_num to get the DB server's major version number.
        // The call returns None for servers older than 8.x.
        let pg_major_version: u32 = pool
            .acquire()
            .await
            .map_err(fix_error)?
            .server_version_num()
            .ok_or(DbError::Other(eyre::Report::msg(
                "could not get PostgreSQL version",
            )))?
            / 10000;

        if pg_major_version < MIN_PG_VERSION {
            return Err(DbError::Other(eyre::Report::msg(format!(
                "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}"
            ))));
        }

        sqlx::migrate!("./migrations")
            .run(&pool)
            .await
            .map_err(|error| DbError::Other(error.into()))?;

        // Create read replica pool if configured
        let read_pool = if let Some(read_db_uri) = &settings.read_db_uri {
            tracing::info!("Connecting to read replica database");
            let read_pool = PgPoolOptions::new()
                .max_connections(100)
                .connect(read_db_uri.as_str())
                .await
                .map_err(fix_error)?;

            // Verify the read replica is also a supported PostgreSQL version
            let read_pg_major_version: u32 = read_pool
                .acquire()
                .await
                .map_err(fix_error)?
                .server_version_num()
                .ok_or(DbError::Other(eyre::Report::msg(
                    "could not get PostgreSQL version from read replica",
                )))?
                / 10000;

            if read_pg_major_version < MIN_PG_VERSION {
                return Err(DbError::Other(eyre::Report::msg(format!(
                    "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}"
                ))));
            }

            Some(read_pool)
        } else {
            None
        };

        Ok(Self { pool, read_pool })
    }

    #[instrument(skip_all)]
    async fn get_session(&self, token: &str) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where token = $1")
            .bind(token)
            .fetch_one(self.read_pool())
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn get_user(&self, username: &str) -> DbResult<User> {
        sqlx::query_as("select id, username, email, password from users where username = $1")
            .bind(username)
            .fetch_one(self.read_pool())
            .await
            .map_err(fix_error)
            .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn get_session_user(&self, token: &str) -> DbResult<User> {
        sqlx::query_as(
            "select users.id, users.username, users.email, users.password from users
            inner join sessions
            on users.id = sessions.user_id
            and sessions.token = $1",
        )
        .bind(token)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)
        .map(|DbUser(user)| user)
    }

    #[instrument(skip_all)]
    async fn count_history(&self, user: &User) -> DbResult<i64> {
        // The count cache is new, and the user might not yet have a cache value.
        // They will have one as soon as they post up some new history; this
        // query counts directly from the history table to handle that edge case.

        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
        let res: (i32,) = sqlx::query_as(
            "select total from total_history_count_user
            where user_id = $1",
        )
        .bind(user.id)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)?;

        Ok(res.0 as i64)
    }

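    /// Delete a user's entire record store, including its idx cache rows,
    /// in a single transaction.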
    async fn delete_store(&self, user: &User) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        sqlx::query(
            "delete from store
            where user_id = $1",
        )
        .bind(user.id)
        .execute(&mut *tx)
        .await
        .map_err(fix_error)?;

        sqlx::query(
            "delete from store_idx_cache
            where user_id = $1",
        )
        .bind(user.id)
        .execute(&mut *tx)
        .await
        .map_err(fix_error)?;

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

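    /// Soft-delete one history entry by stamping `deleted_at`; the row is
    /// kept so other clients can pick up the deletion when they sync.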
    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
        sqlx::query(
            "update history
            set deleted_at = $3
            where user_id = $1
            and client_id = $2
            and deleted_at is null", // don't just keep setting it
        )
        .bind(user.id)
        .bind(id)
        .bind(OffsetDateTime::now_utc())
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
        // List the client ids of all history entries that have been
        // soft-deleted for this user.

        let res = sqlx::query(
            "select client_id from history
            where user_id = $1
            and deleted_at is not null",
        )
        .bind(user.id)
        .fetch_all(self.read_pool())
        .await
        .map_err(fix_error)?;

        let res = res
            .iter()
            .map(|row| row.get::<String, _>("client_id"))
            .collect();

        Ok(res)
    }

    #[instrument(skip_all)]
    async fn count_history_range(
        &self,
        user: &User,
        range: Range<OffsetDateTime>,
    ) -> DbResult<i64> {
        let res: (i64,) = sqlx::query_as(
            "select count(1) from history
            where user_id = $1
            and timestamp >= $2::date
            and timestamp < $3::date",
        )
        .bind(user.id)
        .bind(into_utc(range.start))
        .bind(into_utc(range.end))
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn list_history(
        &self,
        user: &User,
        created_after: OffsetDateTime,
        since: OffsetDateTime,
        host: &str,
        page_size: i64,
    ) -> DbResult<Vec<History>> {
        let res = sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            and hostname != $2
            and created_at >= $3
            and timestamp >= $4
            order by timestamp asc
            limit $5",
        )
        .bind(user.id)
        .bind(host)
        .bind(into_utc(created_after))
        .bind(into_utc(since))
        .bind(page_size)
        .fetch(self.read_pool())
        .map_ok(|DbHistory(h)| h)
        .try_collect()
        .await
        .map_err(fix_error)?;

        Ok(res)
    }

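    /// Insert a batch of history entries in one transaction, ignoring any
    /// rows that already exist (`on conflict do nothing`).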
    #[instrument(skip_all)]
    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        for i in history {
            let client_id: &str = &i.client_id;
            let hostname: &str = &i.hostname;
            let data: &str = &i.data;

            sqlx::query(
                "insert into history
                    (client_id, user_id, hostname, timestamp, data)
                values ($1, $2, $3, $4, $5)
                on conflict do nothing
                ",
            )
            .bind(client_id)
            .bind(i.user_id)
            .bind(hostname)
            .bind(i.timestamp)
            .bind(data)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn delete_user(&self, u: &User) -> DbResult<()> {
        sqlx::query("delete from sessions where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from history where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from store where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from total_history_count_user where user_id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        sqlx::query("delete from users where id = $1")
            .bind(u.id)
            .execute(&self.pool)
            .await
            .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn update_user_password(&self, user: &User) -> DbResult<()> {
        sqlx::query(
            "update users
            set password = $1
            where id = $2",
        )
        .bind(&user.password)
        .bind(user.id)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
        let email: &str = &user.email;
        let username: &str = &user.username;
        let password: &str = &user.password;

        let res: (i64,) = sqlx::query_as(
            "insert into users
                (username, email, password)
            values($1, $2, $3)
            returning id",
        )
        .bind(username)
        .bind(email)
        .bind(password)
        .fetch_one(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(res.0)
    }

    #[instrument(skip_all)]
    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
        let token: &str = &session.token;

        sqlx::query(
            "insert into sessions
                (user_id, token)
            values($1, $2)",
        )
        .bind(session.user_id)
        .bind(token)
        .execute(&self.pool)
        .await
        .map_err(fix_error)?;

        Ok(())
    }

    #[instrument(skip_all)]
    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
            .bind(u.id)
            .fetch_one(self.read_pool())
            .await
            .map_err(fix_error)
            .map(|DbSession(session)| session)
    }

    #[instrument(skip_all)]
    async fn oldest_history(&self, user: &User) -> DbResult<History> {
        sqlx::query_as(
            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
            where user_id = $1
            order by timestamp asc
            limit 1",
        )
        .bind(user.id)
        .fetch_one(self.read_pool())
        .await
        .map_err(fix_error)
        .map(|DbHistory(h)| h)
    }

    #[instrument(skip_all)]
    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
        let mut tx = self.pool.begin().await.map_err(fix_error)?;

        // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max
        // idx without having to make further database queries. Doing the query on this small
        // amount of data should be much, much faster.
        //
        // Worst case, say we get this wrong. We end up caching data that isn't actually the max
        // idx, so clients upload again. The cache logic can be verified with a sql query anyway :)
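        //
        // For example: if this push inserts records with idx 3, 4 and 5 for a
        // given (host, tag) pair, heads[(host, tag)] ends up as 5, which is
        // the only value the idx cache below needs.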

        let mut heads = HashMap::<(HostId, &str), u64>::new();

        for i in records {
            let id = atuin_common::utils::uuid_v7();

            let result = sqlx::query(
                "insert into store
                    (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                on conflict do nothing
                ",
            )
            .bind(id)
            .bind(i.id)
            .bind(i.host.id)
            .bind(i.idx as i64)
            .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
            .bind(&i.version)
            .bind(&i.tag)
            .bind(&i.data.data)
            .bind(&i.data.content_encryption_key)
            .bind(user.id)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;

            // Only update heads if we actually inserted the record
            if result.rows_affected() > 0 {
                heads
                    .entry((i.host.id, &i.tag))
                    .and_modify(|e| {
                        if i.idx > *e {
                            *e = i.idx
                        }
                    })
                    .or_insert(i.idx);
            }
        }

        // we've built the map of heads for this push, so commit it to the database
        for ((host, tag), idx) in heads {
            sqlx::query(
                "insert into store_idx_cache
                    (user_id, host, tag, idx)
                values ($1, $2, $3, $4)
                on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
                ",
            )
            .bind(user.id)
            .bind(host)
            .bind(tag)
            .bind(idx as i64)
            .execute(&mut *tx)
            .await
            .map_err(fix_error)?;
        }

        tx.commit().await.map_err(fix_error)?;

        Ok(())
    }

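    /// Fetch up to `count` records for one (host, tag) pair, ordered by idx,
    /// starting at `start` (or from idx 0 when `start` is `None`).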
    #[instrument(skip_all)]
    async fn next_records(
        &self,
        user: &User,
        host: HostId,
        tag: String,
        start: Option<RecordIdx>,
        count: u64,
    ) -> DbResult<Vec<Record<EncryptedData>>> {
        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
        let start = start.unwrap_or(0);

        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
            "select client_id, host, idx, timestamp, version, tag, data, cek from store
                    where user_id = $1
                    and tag = $2
                    and host = $3
                    and idx >= $4
                    order by idx asc
                    limit $5",
        )
        .bind(user.id)
        .bind(tag.clone())
        .bind(host)
        .bind(start as i64)
        .bind(count as i64)
        .fetch_all(self.read_pool())
        .await
        .map_err(fix_error);

        let ret = match records {
            Ok(records) => {
                let records: Vec<Record<EncryptedData>> = records
                    .into_iter()
                    .map(|f| {
                        let record: Record<EncryptedData> = f.into();
                        record
                    })
                    .collect();

                records
            }
            Err(DbError::NotFound) => {
                tracing::debug!("no records found in store: {:?}/{}", host, tag);
                return Ok(vec![]);
            }
            Err(e) => return Err(e),
        };

        Ok(ret)
    }

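    /// Report the latest known idx for each (host, tag) belonging to this
    /// user, from either the aggregate query or the idx cache (see below).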
    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
        const STATUS_SQL: &str =
            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";

        // If IDX_CACHE_ROLLOUT is set, then we
        // 1. Read the value of the var, use it as a % chance of using the cache
        // 2. If we use the cache, just read from the cache table
        // 3. If we don't use the cache, read from the store table
        // IDX_CACHE_ROLLOUT should be between 0 and 100.
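        //
        // For example, IDX_CACHE_ROLLOUT=25 gives each status call a roughly
        // 25% chance of being served from the cache table.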

        let idx_cache_rollout =
            std::env::var("IDX_CACHE_ROLLOUT").unwrap_or_else(|_| "0".to_string());
        let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0);
        // gen_bool panics if its argument is outside [0, 1], so clamp the
        // configured percentage to the documented 0..=100 range first.
        let idx_cache_rollout = idx_cache_rollout.clamp(0.0, 100.0);
        let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0);

        let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache {
            tracing::debug!("using idx cache for user {}", user.id);
            sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
                .bind(user.id)
                .fetch_all(self.read_pool())
                .await
                .map_err(fix_error)?
        } else {
            tracing::debug!("using aggregate query for user {}", user.id);
            sqlx::query_as(STATUS_SQL)
                .bind(user.id)
                .fetch_all(self.read_pool())
                .await
                .map_err(fix_error)?
        };

        res.sort();

        let mut status = RecordStatus::new();

        for i in res.iter() {
            status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
        }

        Ok(status)
    }
}

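/// Convert an [`OffsetDateTime`] to UTC and drop the offset, yielding the
/// [`PrimitiveDateTime`] form this module binds into timestamp queries.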
fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
    let x = x.to_offset(UtcOffset::UTC);
    PrimitiveDateTime::new(x.date(), x.time())
}

#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    #[test]
    fn utc() {
        let dt = datetime!(2023-09-26 15:11:02 +05:30);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 -07:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);

        let dt = datetime!(2023-09-26 15:11:02 +00:00);
        assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02));
        assert_eq!(into_utc(dt).assume_utc(), dt);
    }
}
639}