1use std::{path::Path, str::FromStr, time::Duration};
2
3use atuin_common::utils;
4use sqlx::{
5 Result, Row,
6 sqlite::{
7 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
8 SqliteSynchronous,
9 },
10};
11use tokio::fs;
12use tracing::debug;
13
14use crate::store::entry::KvEntry;
15
16#[derive(Debug, Clone)]
17pub struct Database {
18 pub pool: SqlitePool,
19}
20
21impl Database {
22 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
23 let path = path.as_ref();
24 debug!("opening KV sqlite database at {:?}", path);
25
26 if utils::broken_symlink(path) {
27 eprintln!(
28 "Atuin: KV sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
29 );
30 std::process::exit(1);
31 }
32
33 if !path.exists()
34 && let Some(dir) = path.parent()
35 {
36 fs::create_dir_all(dir).await?;
37 }
38
39 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
40 .journal_mode(SqliteJournalMode::Wal)
41 .optimize_on_close(true, None)
42 .synchronous(SqliteSynchronous::Normal)
43 .with_regexp()
44 .foreign_keys(true)
45 .create_if_missing(true);
46
47 let pool = SqlitePoolOptions::new()
48 .acquire_timeout(Duration::from_secs_f64(timeout))
49 .connect_with(opts)
50 .await?;
51
52 Self::setup_db(&pool).await?;
53 Ok(Self { pool })
54 }
55
56 pub async fn sqlite_version(&self) -> Result<String> {
57 sqlx::query_scalar("SELECT sqlite_version()")
58 .fetch_one(&self.pool)
59 .await
60 }
61
62 async fn setup_db(pool: &SqlitePool) -> Result<()> {
63 debug!("running sqlite database setup");
64
65 sqlx::migrate!("./migrations").run(pool).await?;
66
67 Ok(())
68 }
69
70 async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, e: &KvEntry) -> Result<()> {
71 sqlx::query(
72 "insert into kv(namespace, key, value)
73 values(?1, ?2, ?3)
74 on conflict(namespace, key) do update set
75 namespace = excluded.namespace,
76 key = excluded.key,
77 value = excluded.value",
78 )
79 .bind(e.namespace.as_str())
80 .bind(e.key.as_str())
81 .bind(e.value.as_str())
82 .execute(&mut **tx)
83 .await?;
84
85 Ok(())
86 }
87
88 async fn delete_raw(
89 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
90 namespace: &str,
91 key: &str,
92 ) -> Result<()> {
93 sqlx::query("delete from kv where namespace = ?1 and key = ?2")
94 .bind(namespace)
95 .bind(key)
96 .execute(&mut **tx)
97 .await?;
98 Ok(())
99 }
100
101 pub async fn save(&self, e: &KvEntry) -> Result<()> {
102 debug!("saving kv entry to sqlite");
103 let mut tx = self.pool.begin().await?;
104 Self::save_raw(&mut tx, e).await?;
105 tx.commit().await?;
106
107 Ok(())
108 }
109
110 pub async fn delete(&self, namespace: &str, key: &str) -> Result<()> {
111 debug!("deleting kv entry {namespace}/{key}");
112
113 let mut tx = self.pool.begin().await?;
114 Self::delete_raw(&mut tx, namespace, key).await?;
115 tx.commit().await?;
116
117 Ok(())
118 }
119
120 fn query_kv_entry(row: SqliteRow) -> KvEntry {
121 let namespace = row.get("namespace");
122 let key = row.get("key");
123 let value = row.get("value");
124
125 KvEntry::builder()
126 .namespace(namespace)
127 .key(key)
128 .value(value)
129 .build()
130 }
131
132 pub async fn load(&self, namespace: &str, key: &str) -> Result<Option<KvEntry>> {
133 debug!("loading kv entry {namespace}.{key}");
134
135 let res = sqlx::query("select * from kv where namespace = ?1 and key = ?2")
136 .bind(namespace)
137 .bind(key)
138 .map(Self::query_kv_entry)
139 .fetch_optional(&self.pool)
140 .await?;
141
142 Ok(res)
143 }
144
145 pub async fn list(&self, namespace: Option<&str>) -> Result<Vec<KvEntry>> {
146 debug!("listing kv entries");
147
148 let res = if let Some(namespace) = namespace {
149 sqlx::query("select * from kv where namespace = ?1 order by key asc")
150 .bind(namespace)
151 .map(Self::query_kv_entry)
152 .fetch_all(&self.pool)
153 .await?
154 } else {
155 sqlx::query("select * from kv order by namespace, key asc")
156 .map(Self::query_kv_entry)
157 .fetch_all(&self.pool)
158 .await?
159 };
160
161 Ok(res)
162 }
163}
164
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a [`KvEntry`] from borrowed parts.
    fn make_entry(namespace: &str, key: &str, value: &str) -> KvEntry {
        KvEntry::builder()
            .namespace(namespace.to_string())
            .key(key.to_string())
            .value(value.to_string())
            .build()
    }

    /// Opens a fresh in-memory database for a single test.
    async fn memory_db() -> Database {
        Database::new("sqlite::memory:", 1.0).await.unwrap()
    }

    #[tokio::test]
    async fn test_list() {
        let db = memory_db().await;
        let entries = db.list(None).await.unwrap();
        assert_eq!(entries.len(), 0);

        let entry = make_entry("test", "test", "test");
        db.save(&entry).await.unwrap();

        let entries = db.list(None).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].namespace, "test");
        assert_eq!(entries[0].key, "test");
        assert_eq!(entries[0].value, "test");
    }

    #[tokio::test]
    async fn test_list_ordering_and_namespace_filter() {
        let db = memory_db().await;
        for (namespace, key) in [("b", "y"), ("a", "x"), ("a", "w")] {
            db.save(&make_entry(namespace, key, "v")).await.unwrap();
        }

        // Unfiltered listing is ordered by namespace, then key.
        let all = db.list(None).await.unwrap();
        let pairs: Vec<_> = all
            .iter()
            .map(|e| (e.namespace.as_str(), e.key.as_str()))
            .collect();
        assert_eq!(pairs, [("a", "w"), ("a", "x"), ("b", "y")]);

        // Filtered listing only returns the requested namespace, ordered by key.
        let in_a = db.list(Some("a")).await.unwrap();
        let keys: Vec<_> = in_a.iter().map(|e| e.key.as_str()).collect();
        assert_eq!(keys, ["w", "x"]);
    }

    #[tokio::test]
    async fn test_save_load() {
        let db = memory_db().await;

        let entry = make_entry("test", "test", "test");
        db.save(&entry).await.unwrap();

        let loaded = db
            .load(&entry.namespace, &entry.key)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(loaded, entry);
    }

    #[tokio::test]
    async fn test_save_overwrites_existing_key() {
        let db = memory_db().await;
        db.save(&make_entry("ns", "key", "old")).await.unwrap();
        db.save(&make_entry("ns", "key", "new")).await.unwrap();

        assert_eq!(db.list(None).await.unwrap().len(), 1);
        let loaded = db.load("ns", "key").await.unwrap().unwrap();
        assert_eq!(loaded.value, "new");
    }

    #[tokio::test]
    async fn test_load_missing() {
        let db = memory_db().await;
        assert!(db.load("ns", "missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete() {
        let db = memory_db().await;

        let entry = make_entry("test", "test", "test");
        db.save(&entry).await.unwrap();

        assert_eq!(db.list(None).await.unwrap().len(), 1);
        db.delete(&entry.namespace, &entry.key).await.unwrap();

        let remaining = db.list(None).await.unwrap();
        assert_eq!(remaining.len(), 0);
        assert!(
            db.load(&entry.namespace, &entry.key)
                .await
                .unwrap()
                .is_none()
        );
    }
}