atuin_scripts/
database.rs1use std::{path::Path, str::FromStr, time::Duration};
2
3use atuin_common::utils;
4use sqlx::{
5 Result, Row,
6 sqlite::{
7 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
8 SqliteSynchronous,
9 },
10};
11use tokio::fs;
12use tracing::debug;
13use uuid::Uuid;
14
15use crate::store::script::Script;
16
17#[derive(Debug, Clone)]
18pub struct Database {
19 pub pool: SqlitePool,
20}
21
22impl Database {
23 pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
24 let path = path.as_ref();
25 debug!("opening script sqlite database at {:?}", path);
26
27 if utils::broken_symlink(path) {
28 eprintln!(
29 "Atuin: Script sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
30 );
31 std::process::exit(1);
32 }
33
34 if !path.exists()
35 && let Some(dir) = path.parent()
36 {
37 fs::create_dir_all(dir).await?;
38 }
39
40 let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
41 .journal_mode(SqliteJournalMode::Wal)
42 .optimize_on_close(true, None)
43 .synchronous(SqliteSynchronous::Normal)
44 .with_regexp()
45 .foreign_keys(true)
46 .create_if_missing(true);
47
48 let pool = SqlitePoolOptions::new()
49 .acquire_timeout(Duration::from_secs_f64(timeout))
50 .connect_with(opts)
51 .await?;
52
53 Self::setup_db(&pool).await?;
54 Ok(Self { pool })
55 }
56
57 pub async fn sqlite_version(&self) -> Result<String> {
58 sqlx::query_scalar("SELECT sqlite_version()")
59 .fetch_one(&self.pool)
60 .await
61 }
62
63 async fn setup_db(pool: &SqlitePool) -> Result<()> {
64 debug!("running sqlite database setup");
65
66 sqlx::migrate!("./migrations").run(pool).await?;
67
68 Ok(())
69 }
70
71 async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, s: &Script) -> Result<()> {
72 sqlx::query(
73 "insert or ignore into scripts(id, name, description, shebang, script)
74 values(?1, ?2, ?3, ?4, ?5)",
75 )
76 .bind(s.id.to_string())
77 .bind(s.name.as_str())
78 .bind(s.description.as_str())
79 .bind(s.shebang.as_str())
80 .bind(s.script.as_str())
81 .execute(&mut **tx)
82 .await?;
83
84 for tag in s.tags.iter() {
85 sqlx::query(
86 "insert or ignore into script_tags(script_id, tag)
87 values(?1, ?2)",
88 )
89 .bind(s.id.to_string())
90 .bind(tag)
91 .execute(&mut **tx)
92 .await?;
93 }
94
95 Ok(())
96 }
97
98 pub async fn save(&self, s: &Script) -> Result<()> {
99 debug!("saving script to sqlite");
100 let mut tx = self.pool.begin().await?;
101 Self::save_raw(&mut tx, s).await?;
102 tx.commit().await?;
103
104 Ok(())
105 }
106
107 pub async fn save_bulk(&self, s: &[Script]) -> Result<()> {
108 debug!("saving scripts to sqlite");
109
110 let mut tx = self.pool.begin().await?;
111
112 for i in s {
113 Self::save_raw(&mut tx, i).await?;
114 }
115
116 tx.commit().await?;
117
118 Ok(())
119 }
120
121 fn query_script(row: SqliteRow) -> Script {
122 let id = row.get("id");
123 let name = row.get("name");
124 let description = row.get("description");
125 let shebang = row.get("shebang");
126 let script = row.get("script");
127
128 let id = Uuid::parse_str(id).unwrap();
129
130 Script {
131 id,
132 name,
133 description,
134 shebang,
135 script,
136 tags: vec![],
137 }
138 }
139
140 fn query_script_tags(row: SqliteRow) -> String {
141 row.get("tag")
142 }
143
144 #[allow(dead_code)]
145 async fn load(&self, id: &str) -> Result<Option<Script>> {
146 debug!("loading script item {}", id);
147
148 let res = sqlx::query("select * from scripts where id = ?1")
149 .bind(id)
150 .map(Self::query_script)
151 .fetch_optional(&self.pool)
152 .await?;
153
154 if let Some(mut script) = res {
156 let tags = sqlx::query("select tag from script_tags where script_id = ?1")
157 .bind(id)
158 .map(Self::query_script_tags)
159 .fetch_all(&self.pool)
160 .await?;
161
162 script.tags = tags;
163 Ok(Some(script))
164 } else {
165 Ok(None)
166 }
167 }
168
169 pub async fn list(&self) -> Result<Vec<Script>> {
170 debug!("listing scripts");
171
172 let mut res = sqlx::query("select * from scripts")
173 .map(Self::query_script)
174 .fetch_all(&self.pool)
175 .await?;
176
177 for script in res.iter_mut() {
179 let tags = sqlx::query("select tag from script_tags where script_id = ?1")
180 .bind(script.id.to_string())
181 .map(Self::query_script_tags)
182 .fetch_all(&self.pool)
183 .await?;
184
185 script.tags = tags;
186 }
187
188 Ok(res)
189 }
190
191 pub async fn clear(&self) -> Result<()> {
192 debug!("clearing all scripts from sqlite");
193
194 sqlx::query("delete from script_tags")
195 .execute(&self.pool)
196 .await?;
197 sqlx::query("delete from scripts")
198 .execute(&self.pool)
199 .await?;
200
201 Ok(())
202 }
203
204 pub async fn delete(&self, id: &str) -> Result<()> {
205 debug!("deleting script {}", id);
206
207 sqlx::query("delete from scripts where id = ?1")
208 .bind(id)
209 .execute(&self.pool)
210 .await?;
211
212 sqlx::query("delete from script_tags where script_id = ?1")
214 .bind(id)
215 .execute(&self.pool)
216 .await?;
217
218 Ok(())
219 }
220
221 pub async fn update(&self, s: &Script) -> Result<()> {
222 debug!("updating script {:?}", s);
223
224 let mut tx = self.pool.begin().await?;
225
226 sqlx::query("update scripts set name = ?1, description = ?2, shebang = ?3, script = ?4 where id = ?5")
228 .bind(s.name.as_str())
229 .bind(s.description.as_str())
230 .bind(s.shebang.as_str())
231 .bind(s.script.as_str())
232 .bind(s.id.to_string())
233 .execute(&mut *tx)
234 .await?;
235
236 sqlx::query("delete from script_tags where script_id = ?1")
238 .bind(s.id.to_string())
239 .execute(&mut *tx)
240 .await?;
241
242 for tag in s.tags.iter() {
244 sqlx::query(
245 "insert or ignore into script_tags(script_id, tag)
246 values(?1, ?2)",
247 )
248 .bind(s.id.to_string())
249 .bind(tag)
250 .execute(&mut *tx)
251 .await?;
252 }
253
254 tx.commit().await?;
255
256 Ok(())
257 }
258
259 pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
260 let res = sqlx::query("select * from scripts where name = ?1")
261 .bind(name)
262 .map(Self::query_script)
263 .fetch_optional(&self.pool)
264 .await?;
265
266 let script = if let Some(mut script) = res {
267 let tags = sqlx::query("select tag from script_tags where script_id = ?1")
268 .bind(script.id.to_string())
269 .map(Self::query_script_tags)
270 .fetch_all(&self.pool)
271 .await?;
272
273 script.tags = tags;
274 Some(script)
275 } else {
276 None
277 };
278
279 Ok(script)
280 }
281}
282
#[cfg(test)]
mod test {
    use super::*;

    /// Opens a fresh in-memory database for one test.
    async fn memory_db() -> Database {
        Database::new("sqlite::memory:", 1.0).await.unwrap()
    }

    /// Builds a script whose text fields all start with "test" and end with `suffix`.
    fn sample_script(suffix: &str) -> Script {
        Script::builder()
            .name(format!("test name{suffix}"))
            .description(format!("test description{suffix}"))
            .shebang(format!("test shebang{suffix}"))
            .script(format!("test script{suffix}"))
            .build()
    }

    #[tokio::test]
    async fn test_list() {
        let db = memory_db().await;
        assert!(db.list().await.unwrap().is_empty());

        let script = sample_script("");
        db.save(&script).await.unwrap();

        let scripts = db.list().await.unwrap();
        assert_eq!(scripts.len(), 1);
        assert_eq!(scripts[0].name, "test name");
    }

    #[tokio::test]
    async fn test_save_load() {
        let db = memory_db().await;

        let script = sample_script("");
        db.save(&script).await.unwrap();

        let loaded = db.load(&script.id.to_string()).await.unwrap().unwrap();

        assert_eq!(loaded, script);
    }

    #[tokio::test]
    async fn test_save_load_with_tags() {
        let db = memory_db().await;

        let mut script = sample_script("");
        script.tags = vec!["alpha".to_string(), "beta".to_string()];
        db.save(&script).await.unwrap();

        let mut loaded = db.load(&script.id.to_string()).await.unwrap().unwrap();
        // The tag query has no ORDER BY; normalise before comparing.
        loaded.tags.sort();

        assert_eq!(loaded, script);
    }

    #[tokio::test]
    async fn test_save_bulk() {
        let db = memory_db().await;

        let scripts = vec![sample_script(""), sample_script(" 2")];
        db.save_bulk(&scripts).await.unwrap();

        // `list` does not specify an order, so compare the sorted names.
        let mut names: Vec<String> = db
            .list()
            .await
            .unwrap()
            .into_iter()
            .map(|s| s.name)
            .collect();
        names.sort();

        assert_eq!(names, ["test name", "test name 2"]);
    }

    #[tokio::test]
    async fn test_delete() {
        let db = memory_db().await;

        let script = sample_script("");
        db.save(&script).await.unwrap();

        assert_eq!(db.list().await.unwrap().len(), 1);
        db.delete(&script.id.to_string()).await.unwrap();

        assert!(db.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update() {
        let db = memory_db().await;

        let mut script = sample_script("");
        script.tags = vec!["old".to_string()];
        db.save(&script).await.unwrap();

        script.name = "renamed".to_string();
        script.tags = vec!["new".to_string()];
        db.update(&script).await.unwrap();

        let loaded = db.load(&script.id.to_string()).await.unwrap().unwrap();
        assert_eq!(loaded, script);
    }

    #[tokio::test]
    async fn test_get_by_name() {
        let db = memory_db().await;

        let script = sample_script("");
        db.save(&script).await.unwrap();

        let found = db.get_by_name("test name").await.unwrap();
        assert_eq!(found, Some(script));

        let missing = db.get_by_name("no such script").await.unwrap();
        assert_eq!(missing, None);
    }

    #[tokio::test]
    async fn test_clear() {
        let db = memory_db().await;

        let scripts = vec![sample_script(""), sample_script(" 2")];
        db.save_bulk(&scripts).await.unwrap();
        assert_eq!(db.list().await.unwrap().len(), 2);

        db.clear().await.unwrap();

        assert!(db.list().await.unwrap().is_empty());
    }
}