Skip to main content

atuin_server_sqlite/
lib.rs

1use std::str::FromStr;
2
3use async_trait::async_trait;
4use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
5use atuin_server_database::{
6    Database, DbError, DbResult, DbSettings,
7    models::{History, NewHistory, NewSession, NewUser, Session, User},
8};
9use futures_util::TryStreamExt;
10use sqlx::{
11    Row,
12    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
13    types::Uuid,
14};
15use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
16use tracing::instrument;
17use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
18
19mod wrappers;
20
21#[derive(Clone)]
22pub struct Sqlite {
23    pool: sqlx::Pool<sqlx::sqlite::Sqlite>,
24}
25
26fn fix_error(error: sqlx::Error) -> DbError {
27    match error {
28        sqlx::Error::RowNotFound => DbError::NotFound,
29        error => DbError::Other(error.into()),
30    }
31}
32
33#[async_trait]
34impl Database for Sqlite {
35    async fn new(settings: &DbSettings) -> DbResult<Self> {
36        let opts = SqliteConnectOptions::from_str(&settings.db_uri)
37            .map_err(fix_error)?
38            .journal_mode(SqliteJournalMode::Wal)
39            .create_if_missing(true);
40
41        let pool = SqlitePoolOptions::new()
42            .connect_with(opts)
43            .await
44            .map_err(fix_error)?;
45
46        sqlx::migrate!("./migrations")
47            .run(&pool)
48            .await
49            .map_err(|error| DbError::Other(error.into()))?;
50
51        Ok(Self { pool })
52    }
53
54    #[instrument(skip_all)]
55    async fn get_session(&self, token: &str) -> DbResult<Session> {
56        sqlx::query_as("select id, user_id, token from sessions where token = $1")
57            .bind(token)
58            .fetch_one(&self.pool)
59            .await
60            .map_err(fix_error)
61            .map(|DbSession(session)| session)
62    }
63
64    #[instrument(skip_all)]
65    async fn get_session_user(&self, token: &str) -> DbResult<User> {
66        sqlx::query_as(
67            "select users.id, users.username, users.email, users.password from users
68            inner join sessions
69            on users.id = sessions.user_id
70            and sessions.token = $1",
71        )
72        .bind(token)
73        .fetch_one(&self.pool)
74        .await
75        .map_err(fix_error)
76        .map(|DbUser(user)| user)
77    }
78
79    #[instrument(skip_all)]
80    async fn add_session(&self, session: &NewSession) -> DbResult<()> {
81        let token: &str = &session.token;
82
83        sqlx::query(
84            "insert into sessions
85                (user_id, token)
86            values($1, $2)",
87        )
88        .bind(session.user_id)
89        .bind(token)
90        .execute(&self.pool)
91        .await
92        .map_err(fix_error)?;
93
94        Ok(())
95    }
96
97    #[instrument(skip_all)]
98    async fn get_user(&self, username: &str) -> DbResult<User> {
99        sqlx::query_as("select id, username, email, password from users where username = $1")
100            .bind(username)
101            .fetch_one(&self.pool)
102            .await
103            .map_err(fix_error)
104            .map(|DbUser(user)| user)
105    }
106
107    #[instrument(skip_all)]
108    async fn get_user_session(&self, u: &User) -> DbResult<Session> {
109        sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
110            .bind(u.id)
111            .fetch_one(&self.pool)
112            .await
113            .map_err(fix_error)
114            .map(|DbSession(session)| session)
115    }
116
117    #[instrument(skip_all)]
118    async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
119        let email: &str = &user.email;
120        let username: &str = &user.username;
121        let password: &str = &user.password;
122
123        let res: (i64,) = sqlx::query_as(
124            "insert into users
125                (username, email, password)
126            values($1, $2, $3)
127            returning id",
128        )
129        .bind(username)
130        .bind(email)
131        .bind(password)
132        .fetch_one(&self.pool)
133        .await
134        .map_err(fix_error)?;
135
136        Ok(res.0)
137    }
138
139    #[instrument(skip_all)]
140    async fn update_user_password(&self, user: &User) -> DbResult<()> {
141        sqlx::query(
142            "update users
143            set password = $1
144            where id = $2",
145        )
146        .bind(&user.password)
147        .bind(user.id)
148        .execute(&self.pool)
149        .await
150        .map_err(fix_error)?;
151
152        Ok(())
153    }
154
155    #[instrument(skip_all)]
156    async fn count_history(&self, user: &User) -> DbResult<i64> {
157        // The cache is new, and the user might not yet have a cache value.
158        // They will have one as soon as they post up some new history, but handle that
159        // edge case.
160
161        let res: (i64,) = sqlx::query_as(
162            "select count(1) from history
163            where user_id = $1",
164        )
165        .bind(user.id)
166        .fetch_one(&self.pool)
167        .await
168        .map_err(fix_error)?;
169
170        Ok(res.0)
171    }
172
173    #[instrument(skip_all)]
174    async fn count_history_cached(&self, _user: &User) -> DbResult<i64> {
175        Err(DbError::NotFound)
176    }
177
178    #[instrument(skip_all)]
179    async fn delete_user(&self, u: &User) -> DbResult<()> {
180        sqlx::query("delete from sessions where user_id = $1")
181            .bind(u.id)
182            .execute(&self.pool)
183            .await
184            .map_err(fix_error)?;
185
186        sqlx::query("delete from users where id = $1")
187            .bind(u.id)
188            .execute(&self.pool)
189            .await
190            .map_err(fix_error)?;
191
192        sqlx::query("delete from history where user_id = $1")
193            .bind(u.id)
194            .execute(&self.pool)
195            .await
196            .map_err(fix_error)?;
197
198        Ok(())
199    }
200
201    async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
202        sqlx::query(
203            "update history
204            set deleted_at = $3
205            where user_id = $1
206            and client_id = $2
207            and deleted_at is null", // don't just keep setting it
208        )
209        .bind(user.id)
210        .bind(id)
211        .bind(time::OffsetDateTime::now_utc())
212        .fetch_all(&self.pool)
213        .await
214        .map_err(fix_error)?;
215
216        Ok(())
217    }
218
219    #[instrument(skip_all)]
220    async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
221        // The cache is new, and the user might not yet have a cache value.
222        // They will have one as soon as they post up some new history, but handle that
223        // edge case.
224
225        let res = sqlx::query(
226            "select client_id from history 
227            where user_id = $1
228            and deleted_at is not null",
229        )
230        .bind(user.id)
231        .fetch_all(&self.pool)
232        .await
233        .map_err(fix_error)?;
234
235        let res = res.iter().map(|row| row.get("client_id")).collect();
236
237        Ok(res)
238    }
239
240    async fn delete_store(&self, user: &User) -> DbResult<()> {
241        sqlx::query(
242            "delete from store
243            where user_id = $1",
244        )
245        .bind(user.id)
246        .execute(&self.pool)
247        .await
248        .map_err(fix_error)?;
249
250        Ok(())
251    }
252
253    #[instrument(skip_all)]
254    async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
255        let mut tx = self.pool.begin().await.map_err(fix_error)?;
256
257        for i in records {
258            let id = atuin_common::utils::uuid_v7();
259
260            sqlx::query(
261                "insert into store
262                    (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) 
263                values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
264                on conflict do nothing
265                ",
266            )
267            .bind(id)
268            .bind(i.id)
269            .bind(i.host.id)
270            .bind(i.idx as i64)
271            .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
272            .bind(&i.version)
273            .bind(&i.tag)
274            .bind(&i.data.data)
275            .bind(&i.data.content_encryption_key)
276            .bind(user.id)
277            .execute(&mut *tx)
278            .await
279            .map_err(fix_error)?;
280        }
281
282        tx.commit().await.map_err(fix_error)?;
283
284        Ok(())
285    }
286
287    #[instrument(skip_all)]
288    async fn next_records(
289        &self,
290        user: &User,
291        host: HostId,
292        tag: String,
293        start: Option<RecordIdx>,
294        count: u64,
295    ) -> DbResult<Vec<Record<EncryptedData>>> {
296        tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
297        let start = start.unwrap_or(0);
298
299        let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
300            "select client_id, host, idx, timestamp, version, tag, data, cek from store
301                    where user_id = $1
302                    and tag = $2
303                    and host = $3
304                    and idx >= $4
305                    order by idx asc
306                    limit $5",
307        )
308        .bind(user.id)
309        .bind(tag.clone())
310        .bind(host)
311        .bind(start as i64)
312        .bind(count as i64)
313        .fetch_all(&self.pool)
314        .await
315        .map_err(fix_error);
316
317        let ret = match records {
318            Ok(records) => {
319                let records: Vec<Record<EncryptedData>> = records
320                    .into_iter()
321                    .map(|f| {
322                        let record: Record<EncryptedData> = f.into();
323                        record
324                    })
325                    .collect();
326
327                records
328            }
329            Err(DbError::NotFound) => {
330                tracing::debug!("no records found in store: {:?}/{}", host, tag);
331                return Ok(vec![]);
332            }
333            Err(e) => return Err(e),
334        };
335
336        Ok(ret)
337    }
338
339    async fn status(&self, user: &User) -> DbResult<RecordStatus> {
340        const STATUS_SQL: &str =
341            "select host, tag, max(idx) from store where user_id = $1 group by host, tag";
342
343        let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
344            .bind(user.id)
345            .fetch_all(&self.pool)
346            .await
347            .map_err(fix_error)?;
348
349        let mut status = RecordStatus::new();
350
351        for i in res {
352            status.set_raw(HostId(i.0), i.1, i.2 as u64);
353        }
354
355        Ok(status)
356    }
357
358    #[instrument(skip_all)]
359    async fn count_history_range(
360        &self,
361        user: &User,
362        range: std::ops::Range<time::OffsetDateTime>,
363    ) -> DbResult<i64> {
364        let res: (i64,) = sqlx::query_as(
365            "select count(1) from history
366            where user_id = $1
367            and timestamp >= $2::date
368            and timestamp < $3::date",
369        )
370        .bind(user.id)
371        .bind(into_utc(range.start))
372        .bind(into_utc(range.end))
373        .fetch_one(&self.pool)
374        .await
375        .map_err(fix_error)?;
376
377        Ok(res.0)
378    }
379
380    #[instrument(skip_all)]
381    async fn list_history(
382        &self,
383        user: &User,
384        created_after: time::OffsetDateTime,
385        since: time::OffsetDateTime,
386        host: &str,
387        page_size: i64,
388    ) -> DbResult<Vec<History>> {
389        let res = sqlx::query_as(
390            "select id, client_id, user_id, hostname, timestamp, data, created_at from history
391            where user_id = $1
392            and hostname != $2
393            and created_at >= $3
394            and timestamp >= $4
395            order by timestamp asc
396            limit $5",
397        )
398        .bind(user.id)
399        .bind(host)
400        .bind(into_utc(created_after))
401        .bind(into_utc(since))
402        .bind(page_size)
403        .fetch(&self.pool)
404        .map_ok(|DbHistory(h)| h)
405        .try_collect()
406        .await
407        .map_err(fix_error)?;
408
409        Ok(res)
410    }
411
412    #[instrument(skip_all)]
413    async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
414        let mut tx = self.pool.begin().await.map_err(fix_error)?;
415
416        for i in history {
417            let client_id: &str = &i.client_id;
418            let hostname: &str = &i.hostname;
419            let data: &str = &i.data;
420
421            sqlx::query(
422                "insert into history
423                    (client_id, user_id, hostname, timestamp, data) 
424                values ($1, $2, $3, $4, $5)
425                on conflict do nothing
426                ",
427            )
428            .bind(client_id)
429            .bind(i.user_id)
430            .bind(hostname)
431            .bind(i.timestamp)
432            .bind(data)
433            .execute(&mut *tx)
434            .await
435            .map_err(fix_error)?;
436        }
437
438        tx.commit().await.map_err(fix_error)?;
439
440        Ok(())
441    }
442
443    #[instrument(skip_all)]
444    async fn oldest_history(&self, user: &User) -> DbResult<History> {
445        sqlx::query_as(
446            "select id, client_id, user_id, hostname, timestamp, data, created_at from history 
447            where user_id = $1
448            order by timestamp asc
449            limit 1",
450        )
451        .bind(user.id)
452        .fetch_one(&self.pool)
453        .await
454        .map_err(fix_error)
455        .map(|DbHistory(h)| h)
456    }
457}
458
459fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
460    let x = x.to_offset(UtcOffset::UTC);
461    PrimitiveDateTime::new(x.date(), x.time())
462}