atuin_server_sqlite/
lib.rs

1use std::str::FromStr;
2
3use async_trait::async_trait;
4use atuin_common::{
5    record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus},
6    utils::crypto_random_string,
7};
8use atuin_server_database::{
9    Database, DbError, DbResult, DbSettings,
10    models::{History, NewHistory, NewSession, NewUser, Session, User},
11};
12use futures_util::TryStreamExt;
13use sqlx::{
14    Row,
15    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
16    types::Uuid,
17};
18use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
19use tracing::instrument;
20use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
21
22mod wrappers;
23
24#[derive(Clone)]
25pub struct Sqlite {
26    pool: sqlx::Pool<sqlx::sqlite::Sqlite>,
27}
28
29fn fix_error(error: sqlx::Error) -> DbError {
30    match error {
31        sqlx::Error::RowNotFound => DbError::NotFound,
32        error => DbError::Other(error.into()),
33    }
34}
35
36#[async_trait]
37impl Database for Sqlite {
38    async fn new(settings: &DbSettings) -> DbResult<Self> {
39        let opts = SqliteConnectOptions::from_str(&settings.db_uri)
40            .map_err(fix_error)?
41            .journal_mode(SqliteJournalMode::Wal)
42            .create_if_missing(true);
43
44        let pool = SqlitePoolOptions::new()
45            .connect_with(opts)
46            .await
47            .map_err(fix_error)?;
48
49        sqlx::migrate!("./migrations")
50            .run(&pool)
51            .await
52            .map_err(|error| DbError::Other(error.into()))?;
53
54        Ok(Self { pool })
55    }
56
57    #[instrument(skip_all)]
58    async fn get_session(&self, token: &str) -> DbResult<Session> {
59        sqlx::query_as("select id, user_id, token from sessions where token = $1")
60            .bind(token)
61            .fetch_one(&self.pool)
62            .await
63            .map_err(fix_error)
64            .map(|DbSession(session)| session)
65    }
66
67    #[instrument(skip_all)]
68    async fn get_session_user(&self, token: &str) -> DbResult<User> {
69        sqlx::query_as(
70            "select users.id, users.username, users.email, users.password, users.verified_at from users 
71            inner join sessions 
72            on users.id = sessions.user_id 
73            and sessions.token = $1",
74        )
75        .bind(token)
76        .fetch_one(&self.pool)
77        .await
78        .map_err(fix_error)
79        .map(|DbUser(user)| user)
80    }
81
82    #[instrument(skip_all)]
83    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
84        let token: &str = &session.token;
85
86        sqlx::query(
87            "insert into sessions
88                (user_id, token)
89            values($1, $2)",
90        )
91        .bind(session.user_id)
92        .bind(token)
93        .execute(&self.pool)
94        .await
95        .map_err(fix_error)?;
96
97        Ok(())
98    }
99
100    #[instrument(skip_all)]
101    async fn get_user(&self, username: &str) -> DbResult<User> {
102        sqlx::query_as(
103            "select id, username, email, password, verified_at from users where username = $1",
104        )
105        .bind(username)
106        .fetch_one(&self.pool)
107        .await
108        .map_err(fix_error)
109        .map(|DbUser(user)| user)
110    }
111
112    #[instrument(skip_all)]
113    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
114        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
115            .bind(u.id)
116            .fetch_one(&self.pool)
117            .await
118            .map_err(fix_error)
119            .map(|DbSession(session)| session)
120    }
121
122    #[instrument(skip_all)]
123    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
124        let email: &str = &user.email;
125        let username: &str = &user.username;
126        let password: &str = &user.password;
127
128        let res: (i64,) = sqlx::query_as(
129            "insert into users
130                (username, email, password)
131            values($1, $2, $3)
132            returning id",
133        )
134        .bind(username)
135        .bind(email)
136        .bind(password)
137        .fetch_one(&self.pool)
138        .await
139        .map_err(fix_error)?;
140
141        Ok(res.0)
142    }
143
144    #[instrument(skip_all)]
145    async fn user_verified(&self, id: i64) -> DbResult<bool> {
146        let res: (bool,) =
147            sqlx::query_as("select verified_at is not null from users where id = $1")
148                .bind(id)
149                .fetch_one(&self.pool)
150                .await
151                .map_err(fix_error)?;
152
153        Ok(res.0)
154    }
155
156    #[instrument(skip_all)]
157    async fn verify_user(&self, id: i64) -> DbResult<()> {
158        sqlx::query(
159            "update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
160        )
161        .bind(id)
162        .execute(&self.pool)
163        .await
164        .map_err(fix_error)?;
165
166        Ok(())
167    }
168
169    #[instrument(skip_all)]
170    async fn user_verification_token(&self, id: i64) -> DbResult<String> {
171        const TOKEN_VALID_MINUTES: i64 = 15;
172
173        // First we check if there is a verification token
174        let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
175            "select token, valid_until from user_verification_token where user_id = $1",
176        )
177        .bind(id)
178        .fetch_optional(&self.pool)
179        .await
180        .map_err(fix_error)?;
181
182        let token = if let Some((token, valid_until)) = token {
183            // We have a token, AND it's still valid
184            if valid_until > time::OffsetDateTime::now_utc() {
185                token
186            } else {
187                // token has expired. generate a new one, return it
188                let token = crypto_random_string::<24>();
189
190                sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
191                    .bind(id)
192                    .bind(&token)
193                    .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
194                    .execute(&self.pool)
195                    .await
196                    .map_err(fix_error)?;
197
198                token
199            }
200        } else {
201            // No token in the database! Generate one, insert it
202            let token = crypto_random_string::<24>();
203
204            sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
205                .bind(id)
206                .bind(&token)
207                .bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
208                .execute(&self.pool)
209                .await
210                .map_err(fix_error)?;
211
212            token
213        };
214
215        Ok(token)
216    }
217
218    #[instrument(skip_all)]
219    async fn update_user_password(&self, user: &User) -> DbResult<()> {
220        sqlx::query(
221            "update users
222            set password = $1
223            where id = $2",
224        )
225        .bind(&user.password)
226        .bind(user.id)
227        .execute(&self.pool)
228        .await
229        .map_err(fix_error)?;
230
231        Ok(())
232    }
233
234    #[instrument(skip_all)]
235    async fn total_history(&self) -> DbResult<i64> {
236        let res: (i64,) = sqlx::query_as("select count(1) from history")
237            .fetch_optional(&self.pool)
238            .await
239            .map_err(fix_error)?
240            .unwrap_or((0,));
241
242        Ok(res.0)
243    }
244
245    #[instrument(skip_all)]
246    async fn count_history(&self, user: &User) -> DbResult<i64> {
247        // The cache is new, and the user might not yet have a cache value.
248        // They will have one as soon as they post up some new history, but handle that
249        // edge case.
250
251        let res: (i64,) = sqlx::query_as(
252            "select count(1) from history
253            where user_id = $1",
254        )
255        .bind(user.id)
256        .fetch_one(&self.pool)
257        .await
258        .map_err(fix_error)?;
259
260        Ok(res.0)
261    }
262
263    #[instrument(skip_all)]
264    async fn count_history_cached(&self, _user: &User) -> DbResult<i64> {
265        Err(DbError::NotFound)
266    }
267
268    #[instrument(skip_all)]
269    async fn delete_user(&self, u: &User) -> DbResult<()> {
270        sqlx::query("delete from sessions where user_id = $1")
271            .bind(u.id)
272            .execute(&self.pool)
273            .await
274            .map_err(fix_error)?;
275
276        sqlx::query("delete from users where id = $1")
277            .bind(u.id)
278            .execute(&self.pool)
279            .await
280            .map_err(fix_error)?;
281
282        sqlx::query("delete from history where user_id = $1")
283            .bind(u.id)
284            .execute(&self.pool)
285            .await
286            .map_err(fix_error)?;
287
288        Ok(())
289    }
290
291    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
292        sqlx::query(
293            "update history
294            set deleted_at = $3
295            where user_id = $1
296            and client_id = $2
297            and deleted_at is null", // don't just keep setting it
298        )
299        .bind(user.id)
300        .bind(id)
301        .bind(time::OffsetDateTime::now_utc())
302        .fetch_all(&self.pool)
303        .await
304        .map_err(fix_error)?;
305
306        Ok(())
307    }
308
309    #[instrument(skip_all)]
310    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
311        // The cache is new, and the user might not yet have a cache value.
312        // They will have one as soon as they post up some new history, but handle that
313        // edge case.
314
315        let res = sqlx::query(
316            "select client_id from history 
317            where user_id = $1
318            and deleted_at is not null",
319        )
320        .bind(user.id)
321        .fetch_all(&self.pool)
322        .await
323        .map_err(fix_error)?;
324
325        let res = res.iter().map(|row| row.get("client_id")).collect();
326
327        Ok(res)
328    }
329
330    async fn delete_store(&self, user: &User) -> DbResult<()> {
331        sqlx::query(
332            "delete from store
333            where user_id = $1",
334        )
335        .bind(user.id)
336        .execute(&self.pool)
337        .await
338        .map_err(fix_error)?;
339
340        Ok(())
341    }
342
343    #[instrument(skip_all)]
344    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
345        let mut tx = self.pool.begin().await.map_err(fix_error)?;
346
347        for i in records {
348            let id = atuin_common::utils::uuid_v7();
349
350            sqlx::query(
351                "insert into store
352                    (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) 
353                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
354                on conflict do nothing
355                ",
356            )
357            .bind(id)
358            .bind(i.id)
359            .bind(i.host.id)
360            .bind(i.idx as i64)
361            .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
362            .bind(&i.version)
363            .bind(&i.tag)
364            .bind(&i.data.data)
365            .bind(&i.data.content_encryption_key)
366            .bind(user.id)
367            .execute(&mut *tx)
368            .await
369            .map_err(fix_error)?;
370        }
371
372        tx.commit().await.map_err(fix_error)?;
373
374        Ok(())
375    }
376
377    #[instrument(skip_all)]
378    async fn next_records(
379        &self,
380        user: &User,
381        host: HostId,
382        tag: String,
383        start: Option<RecordIdx>,
384        count: u64,
385    ) -> DbResult<Vec<Record<EncryptedData>>> {
386        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
387        let start = start.unwrap_or(0);
388
389        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
390            "select client_id, host, idx, timestamp, version, tag, data, cek from store
391                    where user_id = $1
392                    and tag = $2
393                    and host = $3
394                    and idx >= $4
395                    order by idx asc
396                    limit $5",
397        )
398        .bind(user.id)
399        .bind(tag.clone())
400        .bind(host)
401        .bind(start as i64)
402        .bind(count as i64)
403        .fetch_all(&self.pool)
404        .await
405        .map_err(fix_error);
406
407        let ret = match records {
408            Ok(records) => {
409                let records: Vec<Record<EncryptedData>> = records
410                    .into_iter()
411                    .map(|f| {
412                        let record: Record<EncryptedData> = f.into();
413                        record
414                    })
415                    .collect();
416
417                records
418            }
419            Err(DbError::NotFound) => {
420                tracing::debug!("no records found in store: {:?}/{}", host, tag);
421                return Ok(vec![]);
422            }
423            Err(e) => return Err(e),
424        };
425
426        Ok(ret)
427    }
428
429    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
430        const STATUS_SQL: &str =
431            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
432
433        let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
434            .bind(user.id)
435            .fetch_all(&self.pool)
436            .await
437            .map_err(fix_error)?;
438
439        let mut status = RecordStatus::new();
440
441        for i in res {
442            status.set_raw(HostId(i.0), i.1, i.2 as u64);
443        }
444
445        Ok(status)
446    }
447
448    #[instrument(skip_all)]
449    async fn count_history_range(
450        &self,
451        user: &User,
452        range: std::ops::Range<time::OffsetDateTime>,
453    ) -> DbResult<i64> {
454        let res: (i64,) = sqlx::query_as(
455            "select count(1) from history
456            where user_id = $1
457            and timestamp >= $2::date
458            and timestamp < $3::date",
459        )
460        .bind(user.id)
461        .bind(into_utc(range.start))
462        .bind(into_utc(range.end))
463        .fetch_one(&self.pool)
464        .await
465        .map_err(fix_error)?;
466
467        Ok(res.0)
468    }
469
470    #[instrument(skip_all)]
471    async fn list_history(
472        &self,
473        user: &User,
474        created_after: time::OffsetDateTime,
475        since: time::OffsetDateTime,
476        host: &str,
477        page_size: i64,
478    ) -> DbResult<Vec<History>> {
479        let res = sqlx::query_as(
480            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
481            where user_id = $1
482            and hostname != $2
483            and created_at >= $3
484            and timestamp >= $4
485            order by timestamp asc
486            limit $5",
487        )
488        .bind(user.id)
489        .bind(host)
490        .bind(into_utc(created_after))
491        .bind(into_utc(since))
492        .bind(page_size)
493        .fetch(&self.pool)
494        .map_ok(|DbHistory(h)| h)
495        .try_collect()
496        .await
497        .map_err(fix_error)?;
498
499        Ok(res)
500    }
501
502    #[instrument(skip_all)]
503    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
504        let mut tx = self.pool.begin().await.map_err(fix_error)?;
505
506        for i in history {
507            let client_id: &str = &i.client_id;
508            let hostname: &str = &i.hostname;
509            let data: &str = &i.data;
510
511            sqlx::query(
512                "insert into history
513                    (client_id, user_id, hostname, timestamp, data) 
514                values ($1, $2, $3, $4, $5)
515                on conflict do nothing
516                ",
517            )
518            .bind(client_id)
519            .bind(i.user_id)
520            .bind(hostname)
521            .bind(i.timestamp)
522            .bind(data)
523            .execute(&mut *tx)
524            .await
525            .map_err(fix_error)?;
526        }
527
528        tx.commit().await.map_err(fix_error)?;
529
530        Ok(())
531    }
532
533    #[instrument(skip_all)]
534    async fn oldest_history(&self, user: &User) -> DbResult<History> {
535        sqlx::query_as(
536            "select id, client_id, user_id, hostname, timestamp, data, created_at from history 
537            where user_id = $1
538            order by timestamp asc
539            limit 1",
540        )
541        .bind(user.id)
542        .fetch_one(&self.pool)
543        .await
544        .map_err(fix_error)
545        .map(|DbHistory(h)| h)
546    }
547}
548
549fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
550    let x = x.to_offset(UtcOffset::UTC);
551    PrimitiveDateTime::new(x.date(), x.time())
552}