1use std::str::FromStr;
6use std::{path::Path, time::Duration};
7
8use async_trait::async_trait;
9use eyre::{Result, eyre};
10use fs_err as fs;
11
12use sqlx::{
13 Row,
14 sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
15};
16
17use atuin_common::record::{
18 EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
19};
20use atuin_common::utils;
21use uuid::Uuid;
22
23use super::encryption::PASETO_V4;
24use super::store::Store;
25
26#[derive(Debug, Clone)]
27pub struct SqliteStore {
28 pool: SqlitePool,
29}
30
31impl SqliteStore {
32 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
33 let path = path.as_ref();
34
35 debug!("opening sqlite database at {path:?}");
36
37 if utils::broken_symlink(path) {
38 eprintln!(
39 "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
40 );
41 std::process::exit(1);
42 }
43
44 if !path.exists()
45 && let Some(dir) = path.parent()
46 {
47 fs::create_dir_all(dir)?;
48 }
49
50 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
51 .journal_mode(SqliteJournalMode::Wal)
52 .foreign_keys(true)
53 .create_if_missing(true);
54
55 let pool = SqlitePoolOptions::new()
56 .acquire_timeout(Duration::from_secs_f64(timeout))
57 .connect_with(opts)
58 .await?;
59
60 Self::setup_db(&pool).await?;
61
62 Ok(Self { pool })
63 }
64
65 async fn setup_db(pool: &SqlitePool) -> Result<()> {
66 debug!("running sqlite database setup");
67
68 sqlx::migrate!("./record-migrations").run(pool).await?;
69
70 Ok(())
71 }
72
73 async fn save_raw(
74 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
75 r: &Record<EncryptedData>,
76 ) -> Result<()> {
77 sqlx::query(
79 "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
80 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
81 )
82 .bind(r.id.0.as_hyphenated().to_string())
83 .bind(r.idx as i64)
84 .bind(r.host.id.0.as_hyphenated().to_string())
85 .bind(r.tag.as_str())
86 .bind(r.timestamp as i64)
87 .bind(r.version.as_str())
88 .bind(r.data.data.as_str())
89 .bind(r.data.content_encryption_key.as_str())
90 .execute(&mut **tx)
91 .await?;
92
93 Ok(())
94 }
95
96 fn query_row(row: SqliteRow) -> Record<EncryptedData> {
97 let idx: i64 = row.get("idx");
98 let timestamp: i64 = row.get("timestamp");
99
100 let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
102 let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
103
104 Record {
105 id: RecordId(id),
106 idx: idx as u64,
107 host: Host::new(HostId(host)),
108 timestamp: timestamp as u64,
109 tag: row.get("tag"),
110 version: row.get("version"),
111 data: EncryptedData {
112 data: row.get("data"),
113 content_encryption_key: row.get("cek"),
114 },
115 }
116 }
117
118 async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> {
119 let res = sqlx::query("select * from store ")
120 .map(Self::query_row)
121 .fetch_all(&self.pool)
122 .await?;
123
124 Ok(res)
125 }
126}
127
128#[async_trait]
129impl Store for SqliteStore {
130 async fn push_batch(
131 &self,
132 records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
133 ) -> Result<()> {
134 let mut tx = self.pool.begin().await?;
135
136 for record in records {
137 Self::save_raw(&mut tx, record).await?;
138 }
139
140 tx.commit().await?;
141
142 Ok(())
143 }
144
145 async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
146 let res = sqlx::query("select * from store where store.id = ?1")
147 .bind(id.0.as_hyphenated().to_string())
148 .map(Self::query_row)
149 .fetch_one(&self.pool)
150 .await?;
151
152 Ok(res)
153 }
154
155 async fn delete(&self, id: RecordId) -> Result<()> {
156 sqlx::query("delete from store where id = ?1")
157 .bind(id.0.as_hyphenated().to_string())
158 .execute(&self.pool)
159 .await?;
160
161 Ok(())
162 }
163
164 async fn delete_all(&self) -> Result<()> {
165 sqlx::query("delete from store").execute(&self.pool).await?;
166
167 Ok(())
168 }
169
170 async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
171 let res =
172 sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
173 .bind(host.0.as_hyphenated().to_string())
174 .bind(tag)
175 .map(Self::query_row)
176 .fetch_one(&self.pool)
177 .await;
178
179 match res {
180 Err(sqlx::Error::RowNotFound) => Ok(None),
181 Err(e) => Err(eyre!("an error occurred: {}", e)),
182 Ok(record) => Ok(Some(record)),
183 }
184 }
185
186 async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
187 self.idx(host, tag, 0).await
188 }
189
190 async fn len_all(&self) -> Result<u64> {
191 let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store")
192 .fetch_one(&self.pool)
193 .await;
194 match res {
195 Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
196 Ok(v) => Ok(v.0 as u64),
197 }
198 }
199
200 async fn len_tag(&self, tag: &str) -> Result<u64> {
201 let res: Result<(i64,), sqlx::Error> =
202 sqlx::query_as("select count(*) from store where tag=?1")
203 .bind(tag)
204 .fetch_one(&self.pool)
205 .await;
206 match res {
207 Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
208 Ok(v) => Ok(v.0 as u64),
209 }
210 }
211
212 async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
213 let last = self.last(host, tag).await?;
214
215 if let Some(last) = last {
216 return Ok(last.idx + 1);
217 }
218
219 return Ok(0);
220 }
221
222 async fn next(
223 &self,
224 host: HostId,
225 tag: &str,
226 idx: RecordIdx,
227 limit: u64,
228 ) -> Result<Vec<Record<EncryptedData>>> {
229 let res = sqlx::query(
230 "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4",
231 )
232 .bind(idx as i64)
233 .bind(host.0.as_hyphenated().to_string())
234 .bind(tag)
235 .bind(limit as i64)
236 .map(Self::query_row)
237 .fetch_all(&self.pool)
238 .await?;
239
240 Ok(res)
241 }
242
243 async fn idx(
244 &self,
245 host: HostId,
246 tag: &str,
247 idx: RecordIdx,
248 ) -> Result<Option<Record<EncryptedData>>> {
249 let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
250 .bind(idx as i64)
251 .bind(host.0.as_hyphenated().to_string())
252 .bind(tag)
253 .map(Self::query_row)
254 .fetch_one(&self.pool)
255 .await;
256
257 match res {
258 Err(sqlx::Error::RowNotFound) => Ok(None),
259 Err(e) => Err(eyre!("an error occurred: {}", e)),
260 Ok(v) => Ok(Some(v)),
261 }
262 }
263
264 async fn status(&self) -> Result<RecordStatus> {
265 let mut status = RecordStatus::new();
266
267 let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
268 sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
269 .fetch_all(&self.pool)
270 .await;
271
272 let res = match res {
273 Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
274 Ok(v) => v,
275 };
276
277 for i in res {
278 let host = HostId(
279 Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
280 );
281
282 status.set_raw(host, i.1, i.2 as u64);
283 }
284
285 Ok(status)
286 }
287
288 async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
289 let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
290 .bind(tag)
291 .map(Self::query_row)
292 .fetch_all(&self.pool)
293 .await?;
294
295 Ok(res)
296 }
297
298 async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> {
301 let all = self.load_all().await?;
309
310 let re_encrypted = all
311 .into_iter()
312 .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key))
313 .collect::<Result<Vec<_>>>()?;
314
315 let mut tx = self.pool.begin().await?;
319
320 let res = sqlx::query("delete from store").execute(&mut *tx).await?;
321
322 let rows = res.rows_affected();
323 debug!("deleted {rows} rows");
324
325 for record in re_encrypted {
329 Self::save_raw(&mut tx, &record).await?;
330 }
331
332 tx.commit().await?;
333
334 Ok(())
335 }
336
337 async fn verify(&self, key: &[u8; 32]) -> Result<()> {
340 let all = self.load_all().await?;
341
342 all.into_iter()
343 .map(|record| record.decrypt::<PASETO_V4>(key))
344 .collect::<Result<Vec<_>>>()?;
345
346 Ok(())
347 }
348
349 async fn purge(&self, key: &[u8; 32]) -> Result<()> {
352 let all = self.load_all().await?;
353
354 for record in all.iter() {
355 match record.clone().decrypt::<PASETO_V4>(key) {
356 Ok(_) => continue,
357 Err(_) => {
358 println!(
359 "Failed to decrypt {}, deleting",
360 record.id.0.as_hyphenated()
361 );
362
363 self.delete(record.id).await?;
364 }
365 }
366 }
367
368 Ok(())
369 }
370}
371
#[cfg(test)]
mod tests {
    use atuin_common::{
        record::{DecryptedData, EncryptedData, Host, HostId, Record},
        utils::uuid_v7,
    };

    use crate::{
        encryption::generate_encoded_key,
        record::{encryption::PASETO_V4, store::Store},
        settings::test_local_timeout,
    };

    use super::SqliteStore;

    /// Build a dummy record at index 0 with a fresh host and a fresh random tag.
    fn test_record() -> Record<EncryptedData> {
        Record::builder()
            .host(Host::new(HostId(atuin_common::utils::uuid_v7())))
            .version("v1".into())
            .tag(atuin_common::utils::uuid_v7().simple().to_string())
            .data(EncryptedData {
                data: "1234".into(),
                content_encryption_key: "1234".into(),
            })
            .idx(0)
            .build()
    }

    #[tokio::test]
    async fn create_db() {
        let db = SqliteStore::new(":memory:", test_local_timeout()).await;

        assert!(
            db.is_ok(),
            "db could not be created, {:?}",
            db.err().unwrap()
        );
    }

    #[tokio::test]
    async fn push_record() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();

        db.push(&record).await.expect("failed to insert record");
    }

    #[tokio::test]
    async fn get_record() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let new_record = db.get(record.id).await.expect("failed to fetch record");

        assert_eq!(record, new_record, "records are not equal");
    }

    #[tokio::test]
    async fn last() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let last = db
            .last(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get last record");

        assert_eq!(
            last.unwrap().id,
            record.id,
            "expected to get back the same record that was inserted"
        );
    }

    #[tokio::test]
    async fn first() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let first = db
            .first(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get first record");

        assert_eq!(
            first.unwrap().id,
            record.id,
            "expected to get back the same record that was inserted"
        );
    }

    #[tokio::test]
    async fn len() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let len = db
            .len(record.host.id, record.tag.as_str())
            .await
            .expect("failed to get store len");

        assert_eq!(len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn len_tag() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let record = test_record();
        db.push(&record).await.unwrap();

        let len = db
            .len_tag(record.tag.as_str())
            .await
            .expect("failed to get store len");

        assert_eq!(len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn len_different_tags() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();

        // `test_record` generates a fresh host and tag each call, so these are independent logs.
        let first = test_record();
        let second = test_record();

        db.push(&first).await.unwrap();
        db.push(&second).await.unwrap();

        let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
        let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();

        assert_eq!(first_len, 1, "expected length of 1 after insert");
        assert_eq!(second_len, 1, "expected length of 1 after insert");
    }

    #[tokio::test]
    async fn append_a_bunch() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();

        let mut tail = test_record();
        db.push(&tail).await.expect("failed to push record");

        for _ in 1..100 {
            tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
            db.push(&tail).await.unwrap();
        }

        assert_eq!(
            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
            100,
            "failed to insert 100 records"
        );

        assert_eq!(
            db.len_tag(tail.tag.as_str()).await.unwrap(),
            100,
            "failed to insert 100 records"
        );
    }

    #[tokio::test]
    async fn append_a_big_bunch() {
        let db = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();

        let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000);

        let mut tail = test_record();
        records.push(tail.clone());

        for _ in 1..10000 {
            tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
            records.push(tail.clone());
        }

        db.push_batch(records.iter()).await.unwrap();

        assert_eq!(
            db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
            10000,
            "failed to insert 10k records"
        );
    }

    #[tokio::test]
    async fn re_encrypt() {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let (key, _) = generate_encoded_key().unwrap();
        let data = vec![0u8, 1u8, 2u8, 3u8];
        let host_id = HostId(uuid_v7());

        for i in 0..10 {
            let record = Record::builder()
                .host(Host::new(host_id))
                .version(String::from("test"))
                .tag(String::from("test"))
                .idx(i)
                .data(DecryptedData(data.clone()))
                .build();

            let record = record.encrypt::<PASETO_V4>(&key.into());
            store
                .push(&record)
                .await
                .expect("failed to push encrypted record");
        }

        let all = store.all_tagged("test").await.unwrap();

        assert_eq!(all.len(), 10, "failed to fetch all records");

        for record in all {
            let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap();
            assert_eq!(decrypted.data.0, data);
        }

        let (new_key, _) = generate_encoded_key().unwrap();
        store
            .re_encrypt(&key.into(), &new_key.into())
            .await
            .expect("failed to re-encrypt store");

        let all = store.all_tagged("test").await.unwrap();

        for record in all.iter() {
            let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into());
            assert!(
                decrypted.is_err(),
                "did not get error decrypting with old key after re-encrypt"
            )
        }

        for record in all {
            let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap();
            assert_eq!(decrypted.data.0, data);
        }

        assert_eq!(store.len(host_id, "test").await.unwrap(), 10);
    }

    /// `verify` must accept the right key and reject a wrong one. `purge` must keep
    /// records that decrypt with the given key and delete those that do not.
    #[tokio::test]
    async fn verify_and_purge() {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let (key, _) = generate_encoded_key().unwrap();
        let (other_key, _) = generate_encoded_key().unwrap();
        let host_id = HostId(uuid_v7());

        for i in 0..3 {
            let record = Record::builder()
                .host(Host::new(host_id))
                .version(String::from("test"))
                .tag(String::from("test"))
                .idx(i)
                .data(DecryptedData(vec![0u8, 1u8, 2u8]))
                .build();

            let record = record.encrypt::<PASETO_V4>(&key.into());
            store
                .push(&record)
                .await
                .expect("failed to push encrypted record");
        }

        store
            .verify(&key.into())
            .await
            .expect("records should verify with the key they were encrypted with");
        assert!(
            store.verify(&other_key.into()).await.is_err(),
            "verify should fail with a key the records were not encrypted with"
        );

        store.purge(&key.into()).await.unwrap();
        assert_eq!(
            store.len_all().await.unwrap(),
            3,
            "purge with the correct key must not delete anything"
        );

        store.purge(&other_key.into()).await.unwrap();
        assert_eq!(
            store.len_all().await.unwrap(),
            0,
            "purge with a wrong key must delete every undecryptable record"
        );
    }
}