Skip to main content

atuin_client/record/
sqlite_store.rs

// Here we use sqlite as a pretty dumb store, and do not run any complex queries.
// Multiple stores of multiple types all live in one chonky table (for now), and we just index
// by tag/host.
4
5use std::str::FromStr;
6use std::{path::Path, time::Duration};
7
8use async_trait::async_trait;
9use eyre::{Result, eyre};
10use fs_err as fs;
11
12use sqlx::{
13    Row,
14    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
15};
16
17use atuin_common::record::{
18    EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
19};
20use atuin_common::utils;
21use uuid::Uuid;
22
23use super::encryption::PASETO_V4;
24use super::store::Store;
25
26#[derive(Debug, Clone)]
27pub struct SqliteStore {
28    pool: SqlitePool,
29}
30
31impl SqliteStore {
32    pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
33        let path = path.as_ref();
34
35        debug!("opening sqlite database at {path:?}");
36
37        if utils::broken_symlink(path) {
38            eprintln!(
39                "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
40            );
41            std::process::exit(1);
42        }
43
44        if !path.exists()
45            && let Some(dir) = path.parent()
46        {
47            fs::create_dir_all(dir)?;
48        }
49
50        let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
51            .journal_mode(SqliteJournalMode::Wal)
52            .foreign_keys(true)
53            .create_if_missing(true);
54
55        let pool = SqlitePoolOptions::new()
56            .acquire_timeout(Duration::from_secs_f64(timeout))
57            .connect_with(opts)
58            .await?;
59
60        Self::setup_db(&pool).await?;
61
62        Ok(Self { pool })
63    }
64
65    async fn setup_db(pool: &SqlitePool) -> Result<()> {
66        debug!("running sqlite database setup");
67
68        sqlx::migrate!("./record-migrations").run(pool).await?;
69
70        Ok(())
71    }
72
73    async fn save_raw(
74        tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
75        r: &Record<EncryptedData>,
76    ) -> Result<()> {
77        // In sqlite, we are "limited" to i64. But that is still fine, until 2262.
78        sqlx::query(
79            "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
80                values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
81        )
82        .bind(r.id.0.as_hyphenated().to_string())
83        .bind(r.idx as i64)
84        .bind(r.host.id.0.as_hyphenated().to_string())
85        .bind(r.tag.as_str())
86        .bind(r.timestamp as i64)
87        .bind(r.version.as_str())
88        .bind(r.data.data.as_str())
89        .bind(r.data.content_encryption_key.as_str())
90        .execute(&mut **tx)
91        .await?;
92
93        Ok(())
94    }
95
96    fn query_row(row: SqliteRow) -> Record<EncryptedData> {
97        let idx: i64 = row.get("idx");
98        let timestamp: i64 = row.get("timestamp");
99
100        // tbh at this point things are pretty fucked so just panic
101        let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
102        let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
103
104        Record {
105            id: RecordId(id),
106            idx: idx as u64,
107            host: Host::new(HostId(host)),
108            timestamp: timestamp as u64,
109            tag: row.get("tag"),
110            version: row.get("version"),
111            data: EncryptedData {
112                data: row.get("data"),
113                content_encryption_key: row.get("cek"),
114            },
115        }
116    }
117
118    async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> {
119        let res = sqlx::query("select * from store ")
120            .map(Self::query_row)
121            .fetch_all(&self.pool)
122            .await?;
123
124        Ok(res)
125    }
126}
127
128#[async_trait]
129impl Store for SqliteStore {
130    async fn push_batch(
131        &self,
132        records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
133    ) -> Result<()> {
134        let mut tx = self.pool.begin().await?;
135
136        for record in records {
137            Self::save_raw(&mut tx, record).await?;
138        }
139
140        tx.commit().await?;
141
142        Ok(())
143    }
144
145    async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
146        let res = sqlx::query("select * from store where store.id = ?1")
147            .bind(id.0.as_hyphenated().to_string())
148            .map(Self::query_row)
149            .fetch_one(&self.pool)
150            .await?;
151
152        Ok(res)
153    }
154
155    async fn delete(&self, id: RecordId) -> Result<()> {
156        sqlx::query("delete from store where id = ?1")
157            .bind(id.0.as_hyphenated().to_string())
158            .execute(&self.pool)
159            .await?;
160
161        Ok(())
162    }
163
164    async fn delete_all(&self) -> Result<()> {
165        sqlx::query("delete from store").execute(&self.pool).await?;
166
167        Ok(())
168    }
169
170    async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
171        let res =
172            sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
173                .bind(host.0.as_hyphenated().to_string())
174                .bind(tag)
175                .map(Self::query_row)
176                .fetch_one(&self.pool)
177                .await;
178
179        match res {
180            Err(sqlx::Error::RowNotFound) => Ok(None),
181            Err(e) => Err(eyre!("an error occurred: {}", e)),
182            Ok(record) => Ok(Some(record)),
183        }
184    }
185
186    async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
187        self.idx(host, tag, 0).await
188    }
189
190    async fn len_all(&self) -> Result<u64> {
191        let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store")
192            .fetch_one(&self.pool)
193            .await;
194        match res {
195            Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
196            Ok(v) => Ok(v.0 as u64),
197        }
198    }
199
200    async fn len_tag(&self, tag: &str) -> Result<u64> {
201        let res: Result<(i64,), sqlx::Error> =
202            sqlx::query_as("select count(*) from store where tag=?1")
203                .bind(tag)
204                .fetch_one(&self.pool)
205                .await;
206        match res {
207            Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
208            Ok(v) => Ok(v.0 as u64),
209        }
210    }
211
212    async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
213        let last = self.last(host, tag).await?;
214
215        if let Some(last) = last {
216            return Ok(last.idx + 1);
217        }
218
219        return Ok(0);
220    }
221
222    async fn next(
223        &self,
224        host: HostId,
225        tag: &str,
226        idx: RecordIdx,
227        limit: u64,
228    ) -> Result<Vec<Record<EncryptedData>>> {
229        let res = sqlx::query(
230            "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4",
231        )
232        .bind(idx as i64)
233        .bind(host.0.as_hyphenated().to_string())
234        .bind(tag)
235        .bind(limit as i64)
236        .map(Self::query_row)
237        .fetch_all(&self.pool)
238        .await?;
239
240        Ok(res)
241    }
242
243    async fn idx(
244        &self,
245        host: HostId,
246        tag: &str,
247        idx: RecordIdx,
248    ) -> Result<Option<Record<EncryptedData>>> {
249        let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
250            .bind(idx as i64)
251            .bind(host.0.as_hyphenated().to_string())
252            .bind(tag)
253            .map(Self::query_row)
254            .fetch_one(&self.pool)
255            .await;
256
257        match res {
258            Err(sqlx::Error::RowNotFound) => Ok(None),
259            Err(e) => Err(eyre!("an error occurred: {}", e)),
260            Ok(v) => Ok(Some(v)),
261        }
262    }
263
264    async fn status(&self) -> Result<RecordStatus> {
265        let mut status = RecordStatus::new();
266
267        let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
268            sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
269                .fetch_all(&self.pool)
270                .await;
271
272        let res = match res {
273            Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
274            Ok(v) => v,
275        };
276
277        for i in res {
278            let host = HostId(
279                Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
280            );
281
282            status.set_raw(host, i.1, i.2 as u64);
283        }
284
285        Ok(status)
286    }
287
288    async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
289        let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
290            .bind(tag)
291            .map(Self::query_row)
292            .fetch_all(&self.pool)
293            .await?;
294
295        Ok(res)
296    }
297
298    /// Reencrypt every single item in this store with a new key
299    /// Be careful - this may mess with sync.
300    async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> {
301        // Load all the records
302        // In memory like some of the other code here
303        // This will never be called in a hot loop, and only under the following circumstances
304        // 1. The user has logged into a new account, with a new key. They are unlikely to have a
305        //    lot of data
306        // 2. The user has encountered some sort of issue, and runs a maintenance command that
307        //    invokes this
308        let all = self.load_all().await?;
309
310        let re_encrypted = all
311            .into_iter()
312            .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key))
313            .collect::<Result<Vec<_>>>()?;
314
315        // next up, we delete all the old data and reinsert the new stuff
316        // do it in one transaction, so if anything fails we rollback OK
317
318        let mut tx = self.pool.begin().await?;
319
320        let res = sqlx::query("delete from store").execute(&mut *tx).await?;
321
322        let rows = res.rows_affected();
323        debug!("deleted {rows} rows");
324
325        // don't call push_batch, as it will start its own transaction
326        // call the underlying save_raw
327
328        for record in re_encrypted {
329            Self::save_raw(&mut tx, &record).await?;
330        }
331
332        tx.commit().await?;
333
334        Ok(())
335    }
336
337    /// Verify that every record in this store can be decrypted with the current key
338    /// Someday maybe also check each tag/record can be deserialized, but not for now.
339    async fn verify(&self, key: &[u8; 32]) -> Result<()> {
340        let all = self.load_all().await?;
341
342        all.into_iter()
343            .map(|record| record.decrypt::<PASETO_V4>(key))
344            .collect::<Result<Vec<_>>>()?;
345
346        Ok(())
347    }
348
349    /// Verify that every record in this store can be decrypted with the current key
350    /// Someday maybe also check each tag/record can be deserialized, but not for now.
351    async fn purge(&self, key: &[u8; 32]) -> Result<()> {
352        let all = self.load_all().await?;
353
354        for record in all.iter() {
355            match record.clone().decrypt::<PASETO_V4>(key) {
356                Ok(_) => continue,
357                Err(_) => {
358                    println!(
359                        "Failed to decrypt {}, deleting",
360                        record.id.0.as_hyphenated()
361                    );
362
363                    self.delete(record.id).await?;
364                }
365            }
366        }
367
368        Ok(())
369    }
370}
371
#[cfg(test)]
mod tests {
    use atuin_common::{
        record::{DecryptedData, EncryptedData, Host, HostId, Record},
        utils::uuid_v7,
    };

    use crate::{
        encryption::generate_encoded_key,
        record::{encryption::PASETO_V4, store::Store},
        settings::test_local_timeout,
    };

    use super::SqliteStore;

    /// Build a record at index 0 with a fresh random host and a fresh random tag.
    ///
    /// The payload is a placeholder, not real ciphertext.
    fn test_record() -> Record<EncryptedData> {
        Record::builder()
            .host(Host::new(HostId(atuin_common::utils::uuid_v7())))
            .version("v1".into())
            .tag(atuin_common::utils::uuid_v7().simple().to_string())
            .data(EncryptedData {
                data: "1234".into(),
                content_encryption_key: "1234".into(),
            })
            .idx(0)
            .build()
    }

    /// Open a fresh in-memory store, panicking if it cannot be created.
    async fn memory_store() -> SqliteStore {
        SqliteStore::new(":memory:", test_local_timeout())
            .await
            .expect("failed to open in-memory sqlite store")
    }

    #[tokio::test]
    async fn create_db() {
        let db = SqliteStore::new(":memory:", test_local_timeout()).await;

        assert!(
            db.is_ok(),
            "db could not be created, {:?}",
            db.err().unwrap()
        );
    }

    #[tokio::test]
    async fn push_record() {
        let db = memory_store().await;
        let record = test_record();

        db.push(&record).await.expect("failed to insert record");
    }

    #[tokio::test]
    async fn get_record() {
        let db = memory_store().await;
        let record = test_record();
        db.push(&record).await.unwrap();

        let new_record = db.get(record.id).await.expect("failed to fetch record");

        assert_eq!(record, new_record, "records are not equal");
    }

    #[tokio::test]
    async fn last() {
        let db = memory_store().await;
        let record = test_record();
        db.push(&record).await.unwrap();

        let last = db
            .last(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get last record");

        assert_eq!(
            last.unwrap().id,
            record.id,
            "expected to get back the same record that was inserted"
        );
    }

    #[tokio::test]
    async fn first() {
        let db = memory_store().await;
        let record = test_record();
        db.push(&record).await.unwrap();

        let first = db
            .first(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get first record");

        assert_eq!(
            first.unwrap().id,
            record.id,
            "expected to get back the same record that was inserted"
        );
    }

    #[tokio::test]
    async fn len() {
        let db = memory_store().await;
        let record = test_record();
        db.push(&record).await.unwrap();

        let len = db
            .len(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get store len");

        assert_eq!(len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn len_tag() {
        let db = memory_store().await;
        let record = test_record();
        db.push(&record).await.unwrap();

        let len = db
            .len_tag(record.tag.as_str())
            .await
            .expect("failed to get tag len");

        assert_eq!(len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn len_different_tags() {
        let db = memory_store().await;

        // these have different tags, so each should report its own length of 1
        // we model multiple stores within one database
        // new store = new tag = independent length
        let first = test_record();
        let second = test_record();

        db.push(&first).await.unwrap();
        db.push(&second).await.unwrap();

        let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
        let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();

        assert_eq!(first_len, 1, "expected length of 1 after insert");
        assert_eq!(second_len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn append_a_bunch() {
        let db = memory_store().await;

        let mut tail = test_record();
        db.push(&tail).await.expect("failed to push record");

        // 1 initial record + 99 appended ones = 100
        for _ in 1..100 {
            tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
            db.push(&tail).await.unwrap();
        }

        assert_eq!(
            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
            100,
            "failed to insert 100 records"
        );

        assert_eq!(
            db.len_tag(tail.tag.as_str()).await.unwrap(),
            100,
            "failed to insert 100 records"
        );
    }

    #[tokio::test]
    async fn append_a_big_bunch() {
        let db = memory_store().await;

        let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000);

        let mut tail = test_record();
        records.push(tail.clone());

        for _ in 1..10000 {
            tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
            records.push(tail.clone());
        }

        // Everything goes in through one batch, i.e. one transaction.
        db.push_batch(records.iter()).await.unwrap();

        assert_eq!(
            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
            10000,
            "failed to insert 10k records"
        );
    }

    #[tokio::test]
    async fn re_encrypt() {
        let store = memory_store().await;
        let (key, _) = generate_encoded_key().unwrap();
        let data = vec![0u8, 1u8, 2u8, 3u8];
        let host_id = HostId(uuid_v7());

        for i in 0..10 {
            let record = Record::builder()
                .host(Host::new(host_id))
                .version(String::from("test"))
                .tag(String::from("test"))
                .idx(i)
                .data(DecryptedData(data.clone()))
                .build();

            let record = record.encrypt::<PASETO_V4>(&key.into());
            store
                .push(&record)
                .await
                .expect("failed to push encrypted record");
        }

        // first, check that we can decrypt the data with the current key
        let all = store.all_tagged("test").await.unwrap();

        assert_eq!(all.len(), 10, "failed to fetch all records");

        for record in all {
            let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap();
            assert_eq!(decrypted.data.0, data);
        }

        // reencrypt the store, then check if
        // 1) it cannot be decrypted with the old key
        // 2) it can be decrypted with the new key

        let (new_key, _) = generate_encoded_key().unwrap();
        store
            .re_encrypt(&key.into(), &new_key.into())
            .await
            .expect("failed to re-encrypt store");

        let all = store.all_tagged("test").await.unwrap();

        for record in all.iter() {
            let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into());
            assert!(
                decrypted.is_err(),
                "did not get error decrypting with old key after re-encrypt"
            )
        }

        for record in all {
            let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap();
            assert_eq!(decrypted.data.0, data);
        }

        assert_eq!(store.len(host_id, "test").await.unwrap(), 10);
    }
}