Skip to main content

atuin_kv/
database.rs

1use std::{path::Path, str::FromStr, time::Duration};
2
3use atuin_common::utils;
4use sqlx::{
5    Result, Row,
6    sqlite::{
7        SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
8        SqliteSynchronous,
9    },
10};
11use tokio::fs;
12use tracing::debug;
13
14use crate::store::entry::KvEntry;
15
16#[derive(Debug, Clone)]
17pub struct Database {
18    pub pool: SqlitePool,
19}
20
21impl Database {
22    pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
23        let path = path.as_ref();
24        debug!("opening KV sqlite database at {:?}", path);
25
26        if utils::broken_symlink(path) {
27            eprintln!(
28                "Atuin: KV sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
29            );
30            std::process::exit(1);
31        }
32
33        if !path.exists()
34            && let Some(dir) = path.parent()
35        {
36            fs::create_dir_all(dir).await?;
37        }
38
39        let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
40            .journal_mode(SqliteJournalMode::Wal)
41            .optimize_on_close(true, None)
42            .synchronous(SqliteSynchronous::Normal)
43            .with_regexp()
44            .foreign_keys(true)
45            .create_if_missing(true);
46
47        let pool = SqlitePoolOptions::new()
48            .acquire_timeout(Duration::from_secs_f64(timeout))
49            .connect_with(opts)
50            .await?;
51
52        Self::setup_db(&pool).await?;
53        Ok(Self { pool })
54    }
55
56    pub async fn sqlite_version(&self) -> Result<String> {
57        sqlx::query_scalar("SELECT sqlite_version()")
58            .fetch_one(&self.pool)
59            .await
60    }
61
62    async fn setup_db(pool: &SqlitePool) -> Result<()> {
63        debug!("running sqlite database setup");
64
65        sqlx::migrate!("./migrations").run(pool).await?;
66
67        Ok(())
68    }
69
70    async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, e: &KvEntry) -> Result<()> {
71        sqlx::query(
72            "insert into kv(namespace, key, value)
73                values(?1, ?2, ?3)
74                on conflict(namespace, key) do update set
75                    namespace = excluded.namespace,
76                    key = excluded.key,
77                    value = excluded.value",
78        )
79        .bind(e.namespace.as_str())
80        .bind(e.key.as_str())
81        .bind(e.value.as_str())
82        .execute(&mut **tx)
83        .await?;
84
85        Ok(())
86    }
87
88    async fn delete_raw(
89        tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
90        namespace: &str,
91        key: &str,
92    ) -> Result<()> {
93        sqlx::query("delete from kv where namespace = ?1 and key = ?2")
94            .bind(namespace)
95            .bind(key)
96            .execute(&mut **tx)
97            .await?;
98        Ok(())
99    }
100
101    pub async fn save(&self, e: &KvEntry) -> Result<()> {
102        debug!("saving kv entry to sqlite");
103        let mut tx = self.pool.begin().await?;
104        Self::save_raw(&mut tx, e).await?;
105        tx.commit().await?;
106
107        Ok(())
108    }
109
110    pub async fn delete(&self, namespace: &str, key: &str) -> Result<()> {
111        debug!("deleting kv entry {namespace}/{key}");
112
113        let mut tx = self.pool.begin().await?;
114        Self::delete_raw(&mut tx, namespace, key).await?;
115        tx.commit().await?;
116
117        Ok(())
118    }
119
120    fn query_kv_entry(row: SqliteRow) -> KvEntry {
121        let namespace = row.get("namespace");
122        let key = row.get("key");
123        let value = row.get("value");
124
125        KvEntry::builder()
126            .namespace(namespace)
127            .key(key)
128            .value(value)
129            .build()
130    }
131
132    pub async fn load(&self, namespace: &str, key: &str) -> Result<Option<KvEntry>> {
133        debug!("loading kv entry {namespace}.{key}");
134
135        let res = sqlx::query("select * from kv where namespace = ?1 and key = ?2")
136            .bind(namespace)
137            .bind(key)
138            .map(Self::query_kv_entry)
139            .fetch_optional(&self.pool)
140            .await?;
141
142        Ok(res)
143    }
144
145    pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> {
146        debug!("listing kv entries");
147
148        let res = if let Some(namespace) = namespace {
149            sqlx::query("select * from kv where namespace = ?1 order by key asc")
150                .bind(namespace)
151                .map(Self::query_kv_entry)
152                .fetch_all(&self.pool)
153                .await?
154        } else {
155            sqlx::query("select * from kv order by namespace, key asc")
156                .map(Self::query_kv_entry)
157                .fetch_all(&self.pool)
158                .await?
159        };
160
161        Ok(res)
162    }
163}
164
#[cfg(test)]
mod test {
    use super::*;

    /// Opens a fresh in-memory database for a single test.
    async fn memory_db() -> Database {
        Database::new("sqlite::memory:", 1.0).await.unwrap()
    }

    /// Builds a [`KvEntry`] from borrowed parts, keeping the tests free of `to_string` noise.
    fn entry(namespace: &str, key: &str, value: &str) -> KvEntry {
        KvEntry::builder()
            .namespace(namespace.to_string())
            .key(key.to_string())
            .value(value.to_string())
            .build()
    }

    #[tokio::test]
    async fn test_list() {
        let db = memory_db().await;
        let entries = db.list(None).await.unwrap();
        assert_eq!(entries.len(), 0);

        db.save(&entry("test", "test", "test")).await.unwrap();

        let entries = db.list(None).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].namespace, "test");
        assert_eq!(entries[0].key, "test");
        assert_eq!(entries[0].value, "test");
    }

    #[tokio::test]
    async fn test_list_filter_and_order() {
        let db = memory_db().await;

        db.save(&entry("b", "k2", "v")).await.unwrap();
        db.save(&entry("a", "k1", "v")).await.unwrap();
        db.save(&entry("b", "k1", "v")).await.unwrap();

        let all: Vec<(String, String)> = db
            .list(None)
            .await
            .unwrap()
            .into_iter()
            .map(|e| (e.namespace, e.key))
            .collect();
        assert_eq!(
            all,
            vec![
                ("a".to_string(), "k1".to_string()),
                ("b".to_string(), "k1".to_string()),
                ("b".to_string(), "k2".to_string()),
            ]
        );

        let only_b: Vec<String> = db
            .list(Some("b"))
            .await
            .unwrap()
            .into_iter()
            .map(|e| e.key)
            .collect();
        assert_eq!(only_b, vec!["k1".to_string(), "k2".to_string()]);
    }

    #[tokio::test]
    async fn test_save_load() {
        let db = memory_db().await;
        let entry = entry("test", "test", "test");

        db.save(&entry).await.unwrap();

        let loaded = db
            .load(&entry.namespace, &entry.key)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(loaded, entry);
    }

    #[tokio::test]
    async fn test_load_missing() {
        let db = memory_db().await;
        assert!(db.load("missing", "missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_save_overwrites() {
        let db = memory_db().await;

        db.save(&entry("ns", "key", "one")).await.unwrap();
        db.save(&entry("ns", "key", "two")).await.unwrap();

        assert_eq!(db.list(None).await.unwrap().len(), 1);
        let loaded = db.load("ns", "key").await.unwrap().unwrap();
        assert_eq!(loaded.value, "two");
    }

    #[tokio::test]
    async fn test_delete() {
        let db = memory_db().await;
        let entry = entry("test", "test", "test");

        db.save(&entry).await.unwrap();

        assert_eq!(db.list(None).await.unwrap().len(), 1);
        db.delete(&entry.namespace, &entry.key).await.unwrap();

        let loaded = db.list(None).await.unwrap();
        assert_eq!(loaded.len(), 0);
    }

    #[tokio::test]
    async fn test_delete_missing() {
        let db = memory_db().await;
        db.delete("missing", "missing").await.unwrap();
        assert_eq!(db.list(None).await.unwrap().len(), 0);
    }
}