// atuin_server_postgres/lib.rs

1use std::collections::HashMap;
2use std::ops::Range;
3
4use async_trait::async_trait;
5use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
6use atuin_common::utils::crypto_random_string;
7use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
8use atuin_server_database::{Database, DbError, DbResult, DbSettings};
9use futures_util::TryStreamExt;
10use metrics::counter;
11use sqlx::Row;
12use sqlx::postgres::PgPoolOptions;
13
14use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
15use tracing::{instrument, trace};
16use uuid::Uuid;
17use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
18
19mod wrappers;
20
/// Minimum PostgreSQL major version the server supports; `new` refuses to
/// start against anything older.
const MIN_PG_VERSION: u32 = 14;
22
/// PostgreSQL-backed implementation of the Atuin server [`Database`] trait.
#[derive(Clone)]
pub struct Postgres {
    // Shared connection pool; all trait methods borrow from it.
    pool: sqlx::Pool<sqlx::postgres::Postgres>,
}
27
28fn fix_error(error: sqlx::Error) -> DbError {
29    match error {
30        sqlx::Error::RowNotFound => DbError::NotFound,
31        error => DbError::Other(error.into()),
32    }
33}
34
35#[async_trait]
36impl Database for Postgres {
37    async fn new(settings: &DbSettings) -> DbResult<Self> {
38        let pool = PgPoolOptions::new()
39            .max_connections(100)
40            .connect(settings.db_uri.as_str())
41            .await
42            .map_err(fix_error)?;
43
44        // Call server_version_num to get the DB server's major version number
45        // The call returns None for servers older than 8.x.
46        let pg_major_version: u32 = pool
47            .acquire()
48            .await
49            .map_err(fix_error)?
50            .server_version_num()
51            .ok_or(DbError::Other(eyre::Report::msg(
52                "could not get PostgreSQL version",
53            )))?
54            / 10000;
55
56        if pg_major_version < MIN_PG_VERSION {
57            return Err(DbError::Other(eyre::Report::msg(format!(
58                "unsupported PostgreSQL version {}, minimum required is {}",
59                pg_major_version, MIN_PG_VERSION
60            ))));
61        }
62
63        sqlx::migrate!("./migrations")
64            .run(&pool)
65            .await
66            .map_err(|error| DbError::Other(error.into()))?;
67
68        Ok(Self { pool })
69    }
70
71    #[instrument(skip_all)]
72    async fn get_session(&self, token: &str) -> DbResult<Session> {
73        sqlx::query_as("select id, user_id, token from sessions where token = $1")
74            .bind(token)
75            .fetch_one(&self.pool)
76            .await
77            .map_err(fix_error)
78            .map(|DbSession(session)| session)
79    }
80
81    #[instrument(skip_all)]
82    async fn get_user(&self, username: &str) -> DbResult<User> {
83        sqlx::query_as(
84            "select id, username, email, password, verified_at from users where username = $1",
85        )
86        .bind(username)
87        .fetch_one(&self.pool)
88        .await
89        .map_err(fix_error)
90        .map(|DbUser(user)| user)
91    }
92
93    #[instrument(skip_all)]
94    async fn user_verified(&self, id: i64) -> DbResult<bool> {
95        let res: (bool,) =
96            sqlx::query_as("select verified_at is not null from users where id = $1")
97                .bind(id)
98                .fetch_one(&self.pool)
99                .await
100                .map_err(fix_error)?;
101
102        Ok(res.0)
103    }
104
105    #[instrument(skip_all)]
106    async fn verify_user(&self, id: i64) -> DbResult<()> {
107        sqlx::query(
108            "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
109        )
110        .bind(id)
111        .execute(&self.pool)
112        .await
113        .map_err(fix_error)?;
114
115        Ok(())
116    }
117
118    /// Return a valid verification token for the user
119    /// If the user does not have any token, create one, insert it, and return
120    /// If the user has a token, but it's invalid, delete it, create a new one, return
121    /// If the user already has a valid token, return it
122    #[instrument(skip_all)]
123    async fn user_verification_token(&self, id: i64) -> DbResult<String> {
124        const TOKEN_VALID_MINUTES: i64 = 15;
125
126        // First we check if there is a verification token
127        let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
128            "select token, valid_until from user_verification_token where user_id = $1",
129        )
130        .bind(id)
131        .fetch_optional(&self.pool)
132        .await
133        .map_err(fix_error)?;
134
135        let token = if let Some((token, valid_until)) = token {
136            trace!("Token for user {id} valid until {valid_until}");
137
138            // We have a token, AND it's still valid
139            if valid_until > time::OffsetDateTime::now_utc() {
140                token
141            } else {
142                // token has expired. generate a new one, return it
143                let token = crypto_random_string::<24>();
144
145                sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
146                    .bind(id)
147                    .bind(&token)
148                    .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
149                    .execute(&self.pool)
150                    .await
151                    .map_err(fix_error)?;
152
153                token
154            }
155        } else {
156            // No token in the database! Generate one, insert it
157            let token = crypto_random_string::<24>();
158
159            sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
160                .bind(id)
161                .bind(&token)
162                .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
163                .execute(&self.pool)
164                .await
165                .map_err(fix_error)?;
166
167            token
168        };
169
170        Ok(token)
171    }
172
173    #[instrument(skip_all)]
174    async fn get_session_user(&self, token: &str) -> DbResult<User> {
175        sqlx::query_as(
176            "select users.id, users.username, users.email, users.password, users.verified_at from users 
177            inner join sessions 
178            on users.id = sessions.user_id 
179            and sessions.token = $1",
180        )
181        .bind(token)
182        .fetch_one(&self.pool)
183        .await
184        .map_err(fix_error)
185        .map(|DbUser(user)| user)
186    }
187
188    #[instrument(skip_all)]
189    async fn count_history(&self, user: &User) -> DbResult<i64> {
190        // The cache is new, and the user might not yet have a cache value.
191        // They will have one as soon as they post up some new history, but handle that
192        // edge case.
193
194        let res: (i64,) = sqlx::query_as(
195            "select count(1) from history
196            where user_id = $1",
197        )
198        .bind(user.id)
199        .fetch_one(&self.pool)
200        .await
201        .map_err(fix_error)?;
202
203        Ok(res.0)
204    }
205
206    #[instrument(skip_all)]
207    async fn total_history(&self) -> DbResult<i64> {
208        // The cache is new, and the user might not yet have a cache value.
209        // They will have one as soon as they post up some new history, but handle that
210        // edge case.
211
212        let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user")
213            .fetch_optional(&self.pool)
214            .await
215            .map_err(fix_error)?
216            .unwrap_or((0,));
217
218        Ok(res.0)
219    }
220
221    #[instrument(skip_all)]
222    async fn count_history_cached(&self, user: &User) -> DbResult<i64> {
223        let res: (i32,) = sqlx::query_as(
224            "select total from total_history_count_user
225            where user_id = $1",
226        )
227        .bind(user.id)
228        .fetch_one(&self.pool)
229        .await
230        .map_err(fix_error)?;
231
232        Ok(res.0 as i64)
233    }
234
235    async fn delete_store(&self, user: &User) -> DbResult<()> {
236        sqlx::query(
237            "delete from store
238            where user_id = $1",
239        )
240        .bind(user.id)
241        .execute(&self.pool)
242        .await
243        .map_err(fix_error)?;
244
245        Ok(())
246    }
247
248    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
249        sqlx::query(
250            "update history
251            set deleted_at = $3
252            where user_id = $1
253            and client_id = $2
254            and deleted_at is null", // don't just keep setting it
255        )
256        .bind(user.id)
257        .bind(id)
258        .bind(OffsetDateTime::now_utc())
259        .fetch_all(&self.pool)
260        .await
261        .map_err(fix_error)?;
262
263        Ok(())
264    }
265
266    #[instrument(skip_all)]
267    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
268        // The cache is new, and the user might not yet have a cache value.
269        // They will have one as soon as they post up some new history, but handle that
270        // edge case.
271
272        let res = sqlx::query(
273            "select client_id from history 
274            where user_id = $1
275            and deleted_at is not null",
276        )
277        .bind(user.id)
278        .fetch_all(&self.pool)
279        .await
280        .map_err(fix_error)?;
281
282        let res = res
283            .iter()
284            .map(|row| row.get::<String, _>("client_id"))
285            .collect();
286
287        Ok(res)
288    }
289
290    #[instrument(skip_all)]
291    async fn count_history_range(
292        &self,
293        user: &User,
294        range: Range<OffsetDateTime>,
295    ) -> DbResult<i64> {
296        let res: (i64,) = sqlx::query_as(
297            "select count(1) from history
298            where user_id = $1
299            and timestamp >= $2::date
300            and timestamp < $3::date",
301        )
302        .bind(user.id)
303        .bind(into_utc(range.start))
304        .bind(into_utc(range.end))
305        .fetch_one(&self.pool)
306        .await
307        .map_err(fix_error)?;
308
309        Ok(res.0)
310    }
311
312    #[instrument(skip_all)]
313    async fn list_history(
314        &self,
315        user: &User,
316        created_after: OffsetDateTime,
317        since: OffsetDateTime,
318        host: &str,
319        page_size: i64,
320    ) -> DbResult<Vec<History>> {
321        let res = sqlx::query_as(
322            "select id, client_id, user_id, hostname, timestamp, data, created_at from history 
323            where user_id = $1
324            and hostname != $2
325            and created_at >= $3
326            and timestamp >= $4
327            order by timestamp asc
328            limit $5",
329        )
330        .bind(user.id)
331        .bind(host)
332        .bind(into_utc(created_after))
333        .bind(into_utc(since))
334        .bind(page_size)
335        .fetch(&self.pool)
336        .map_ok(|DbHistory(h)| h)
337        .try_collect()
338        .await
339        .map_err(fix_error)?;
340
341        Ok(res)
342    }
343
344    #[instrument(skip_all)]
345    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
346        let mut tx = self.pool.begin().await.map_err(fix_error)?;
347
348        for i in history {
349            let client_id: &str = &i.client_id;
350            let hostname: &str = &i.hostname;
351            let data: &str = &i.data;
352
353            sqlx::query(
354                "insert into history
355                    (client_id, user_id, hostname, timestamp, data) 
356                values ($1, $2, $3, $4, $5)
357                on conflict do nothing
358                ",
359            )
360            .bind(client_id)
361            .bind(i.user_id)
362            .bind(hostname)
363            .bind(i.timestamp)
364            .bind(data)
365            .execute(&mut *tx)
366            .await
367            .map_err(fix_error)?;
368        }
369
370        tx.commit().await.map_err(fix_error)?;
371
372        Ok(())
373    }
374
375    #[instrument(skip_all)]
376    async fn delete_user(&self, u: &User) -> DbResult<()> {
377        sqlx::query("delete from sessions where user_id = $1")
378            .bind(u.id)
379            .execute(&self.pool)
380            .await
381            .map_err(fix_error)?;
382
383        sqlx::query("delete from history where user_id = $1")
384            .bind(u.id)
385            .execute(&self.pool)
386            .await
387            .map_err(fix_error)?;
388
389        sqlx::query("delete from store where user_id = $1")
390            .bind(u.id)
391            .execute(&self.pool)
392            .await
393            .map_err(fix_error)?;
394
395        sqlx::query("delete from user_verification_token where user_id = $1")
396            .bind(u.id)
397            .execute(&self.pool)
398            .await
399            .map_err(fix_error)?;
400
401        sqlx::query("delete from total_history_count_user where user_id = $1")
402            .bind(u.id)
403            .execute(&self.pool)
404            .await
405            .map_err(fix_error)?;
406
407        sqlx::query("delete from users where id = $1")
408            .bind(u.id)
409            .execute(&self.pool)
410            .await
411            .map_err(fix_error)?;
412
413        Ok(())
414    }
415
416    #[instrument(skip_all)]
417    async fn update_user_password(&self, user: &User) -> DbResult<()> {
418        sqlx::query(
419            "update users
420            set password = $1
421            where id = $2",
422        )
423        .bind(&user.password)
424        .bind(user.id)
425        .execute(&self.pool)
426        .await
427        .map_err(fix_error)?;
428
429        Ok(())
430    }
431
432    #[instrument(skip_all)]
433    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
434        let email: &str = &user.email;
435        let username: &str = &user.username;
436        let password: &str = &user.password;
437
438        let res: (i64,) = sqlx::query_as(
439            "insert into users
440                (username, email, password)
441            values($1, $2, $3)
442            returning id",
443        )
444        .bind(username)
445        .bind(email)
446        .bind(password)
447        .fetch_one(&self.pool)
448        .await
449        .map_err(fix_error)?;
450
451        Ok(res.0)
452    }
453
454    #[instrument(skip_all)]
455    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
456        let token: &str = &session.token;
457
458        sqlx::query(
459            "insert into sessions
460                (user_id, token)
461            values($1, $2)",
462        )
463        .bind(session.user_id)
464        .bind(token)
465        .execute(&self.pool)
466        .await
467        .map_err(fix_error)?;
468
469        Ok(())
470    }
471
472    #[instrument(skip_all)]
473    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
474        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
475            .bind(u.id)
476            .fetch_one(&self.pool)
477            .await
478            .map_err(fix_error)
479            .map(|DbSession(session)| session)
480    }
481
482    #[instrument(skip_all)]
483    async fn oldest_history(&self, user: &User) -> DbResult<History> {
484        sqlx::query_as(
485            "select id, client_id, user_id, hostname, timestamp, data, created_at from history 
486            where user_id = $1
487            order by timestamp asc
488            limit 1",
489        )
490        .bind(user.id)
491        .fetch_one(&self.pool)
492        .await
493        .map_err(fix_error)
494        .map(|DbHistory(h)| h)
495    }
496
497    #[instrument(skip_all)]
498    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
499        let mut tx = self.pool.begin().await.map_err(fix_error)?;
500
501        // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max
502        // idx without having to make further database queries. Doing the query on this small
503        // amount of data should be much, much faster.
504        //
505        // Worst case, say we get this wrong. We end up caching data that isn't actually the max
506        // idx, so clients upload again. The cache logic can be verified with a sql query anyway :)
507
508        let mut heads = HashMap::<(HostId, &str), u64>::new();
509
510        for i in records {
511            let id = atuin_common::utils::uuid_v7();
512
513            sqlx::query(
514                "insert into store
515                    (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) 
516                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
517                on conflict do nothing
518                ",
519            )
520            .bind(id)
521            .bind(i.id)
522            .bind(i.host.id)
523            .bind(i.idx as i64)
524            .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
525            .bind(&i.version)
526            .bind(&i.tag)
527            .bind(&i.data.data)
528            .bind(&i.data.content_encryption_key)
529            .bind(user.id)
530            .execute(&mut *tx)
531            .await
532            .map_err(fix_error)?;
533
534            // we're already iterating sooooo
535            heads
536                .entry((i.host.id, &i.tag))
537                .and_modify(|e| {
538                    if i.idx > *e {
539                        *e = i.idx
540                    }
541                })
542                .or_insert(i.idx);
543        }
544
545        // we've built the map of heads for this push, so commit it to the database
546        for ((host, tag), idx) in heads {
547            sqlx::query(
548                "insert into store_idx_cache
549                    (user_id, host, tag, idx) 
550                values ($1, $2, $3, $4)
551                on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4)
552                ",
553            )
554            .bind(user.id)
555            .bind(host)
556            .bind(tag)
557            .bind(idx as i64)
558            .execute(&mut *tx)
559            .await
560            .map_err(fix_error)?;
561        }
562
563        tx.commit().await.map_err(fix_error)?;
564
565        Ok(())
566    }
567
568    #[instrument(skip_all)]
569    async fn next_records(
570        &self,
571        user: &User,
572        host: HostId,
573        tag: String,
574        start: Option<RecordIdx>,
575        count: u64,
576    ) -> DbResult<Vec<Record<EncryptedData>>> {
577        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
578        let start = start.unwrap_or(0);
579
580        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
581            "select client_id, host, idx, timestamp, version, tag, data, cek from store
582                    where user_id = $1
583                    and tag = $2
584                    and host = $3
585                    and idx >= $4
586                    order by idx asc
587                    limit $5",
588        )
589        .bind(user.id)
590        .bind(tag.clone())
591        .bind(host)
592        .bind(start as i64)
593        .bind(count as i64)
594        .fetch_all(&self.pool)
595        .await
596        .map_err(fix_error);
597
598        let ret = match records {
599            Ok(records) => {
600                let records: Vec<Record<EncryptedData>> = records
601                    .into_iter()
602                    .map(|f| {
603                        let record: Record<EncryptedData> = f.into();
604                        record
605                    })
606                    .collect();
607
608                records
609            }
610            Err(DbError::NotFound) => {
611                tracing::debug!("no records found in store: {:?}/{}", host, tag);
612                return Ok(vec![]);
613            }
614            Err(e) => return Err(e),
615        };
616
617        Ok(ret)
618    }
619
620    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
621        const STATUS_SQL: &str =
622            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
623
624        let mut res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
625            .bind(user.id)
626            .fetch_all(&self.pool)
627            .await
628            .map_err(fix_error)?;
629        res.sort();
630
631        // We're temporarily increasing latency in order to improve confidence in the cache
632        // If it runs for a few days, and we confirm that cached values are equal to realtime, we
633        // can replace realtime with cached.
634        //
635        // But let's check so sync doesn't do Weird Things.
636
637        let mut cached_res: Vec<(Uuid, String, i64)> =
638            sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1")
639                .bind(user.id)
640                .fetch_all(&self.pool)
641                .await
642                .map_err(fix_error)?;
643        cached_res.sort();
644
645        let mut status = RecordStatus::new();
646
647        let equal = res == cached_res;
648
649        if equal {
650            counter!("atuin_store_idx_cache_consistent", 1);
651        } else {
652            // log the values if we have an inconsistent cache
653            tracing::debug!(user = user.username, cache_match = equal, res = ?res, cached = ?cached_res, "record store index request");
654            counter!("atuin_store_idx_cache_inconsistent", 1);
655        };
656
657        for i in res.iter() {
658            status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64);
659        }
660
661        Ok(status)
662    }
663}
664
665fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
666    let x = x.to_offset(UtcOffset::UTC);
667    PrimitiveDateTime::new(x.date(), x.time())
668}
669
#[cfg(test)]
mod tests {
    use time::macros::datetime;

    use crate::into_utc;

    #[test]
    fn utc() {
        // Positive offset: the UTC wall-clock time moves back.
        let ahead = datetime!(2023-09-26 15:11:02 +05:30);
        assert_eq!(into_utc(ahead), datetime!(2023-09-26 09:41:02));
        assert_eq!(into_utc(ahead).assume_utc(), ahead);

        // Negative offset: the UTC wall-clock time moves forward.
        let behind = datetime!(2023-09-26 15:11:02 -07:00);
        assert_eq!(into_utc(behind), datetime!(2023-09-26 22:11:02));
        assert_eq!(into_utc(behind).assume_utc(), behind);

        // Already UTC: the wall-clock time is unchanged.
        let already_utc = datetime!(2023-09-26 15:11:02 +00:00);
        assert_eq!(into_utc(already_utc), datetime!(2023-09-26 15:11:02));
        assert_eq!(into_utc(already_utc).assume_utc(), already_utc);
    }
}