Skip to main content

atuin_scripts/
database.rs

1use std::{path::Path, str::FromStr, time::Duration};
2
3use atuin_common::utils;
4use sqlx::{
5    Result, Row,
6    sqlite::{
7        SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
8        SqliteSynchronous,
9    },
10};
11use tokio::fs;
12use tracing::debug;
13use uuid::Uuid;
14
15use crate::store::script::Script;
16
17#[derive(Debug, Clone)]
18pub struct Database {
19    pub pool: SqlitePool,
20}
21
22impl Database {
23    pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> {
24        let path = path.as_ref();
25        debug!("opening script sqlite database at {:?}", path);
26
27        if utils::broken_symlink(path) {
28            eprintln!(
29                "Atuin: Script sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement."
30            );
31            std::process::exit(1);
32        }
33
34        if !path.exists()
35            && let Some(dir) = path.parent()
36        {
37            fs::create_dir_all(dir).await?;
38        }
39
40        let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
41            .journal_mode(SqliteJournalMode::Wal)
42            .optimize_on_close(true, None)
43            .synchronous(SqliteSynchronous::Normal)
44            .with_regexp()
45            .foreign_keys(true)
46            .create_if_missing(true);
47
48        let pool = SqlitePoolOptions::new()
49            .acquire_timeout(Duration::from_secs_f64(timeout))
50            .connect_with(opts)
51            .await?;
52
53        Self::setup_db(&pool).await?;
54        Ok(Self { pool })
55    }
56
57    pub async fn sqlite_version(&self) -> Result<String> {
58        sqlx::query_scalar("SELECT sqlite_version()")
59            .fetch_one(&self.pool)
60            .await
61    }
62
63    async fn setup_db(pool: &SqlitePool) -> Result<()> {
64        debug!("running sqlite database setup");
65
66        sqlx::migrate!("./migrations").run(pool).await?;
67
68        Ok(())
69    }
70
71    async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, s: &Script) -> Result<()> {
72        sqlx::query(
73            "insert or ignore into scripts(id, name, description, shebang, script)
74                values(?1, ?2, ?3, ?4, ?5)",
75        )
76        .bind(s.id.to_string())
77        .bind(s.name.as_str())
78        .bind(s.description.as_str())
79        .bind(s.shebang.as_str())
80        .bind(s.script.as_str())
81        .execute(&mut **tx)
82        .await?;
83
84        for tag in s.tags.iter() {
85            sqlx::query(
86                "insert or ignore into script_tags(script_id, tag)
87                values(?1, ?2)",
88            )
89            .bind(s.id.to_string())
90            .bind(tag)
91            .execute(&mut **tx)
92            .await?;
93        }
94
95        Ok(())
96    }
97
98    pub async fn save(&self, s: &Script) -> Result<()> {
99        debug!("saving script to sqlite");
100        let mut tx = self.pool.begin().await?;
101        Self::save_raw(&mut tx, s).await?;
102        tx.commit().await?;
103
104        Ok(())
105    }
106
107    pub async fn save_bulk(&self, s: &[Script]) -> Result<()> {
108        debug!("saving scripts to sqlite");
109
110        let mut tx = self.pool.begin().await?;
111
112        for i in s {
113            Self::save_raw(&mut tx, i).await?;
114        }
115
116        tx.commit().await?;
117
118        Ok(())
119    }
120
121    fn query_script(row: SqliteRow) -> Script {
122        let id = row.get("id");
123        let name = row.get("name");
124        let description = row.get("description");
125        let shebang = row.get("shebang");
126        let script = row.get("script");
127
128        let id = Uuid::parse_str(id).unwrap();
129
130        Script {
131            id,
132            name,
133            description,
134            shebang,
135            script,
136            tags: vec![],
137        }
138    }
139
140    fn query_script_tags(row: SqliteRow) -> String {
141        row.get("tag")
142    }
143
144    #[allow(dead_code)]
145    async fn load(&self, id: &str) -> Result<Option<Script>> {
146        debug!("loading script item {}", id);
147
148        let res = sqlx::query("select * from scripts where id = ?1")
149            .bind(id)
150            .map(Self::query_script)
151            .fetch_optional(&self.pool)
152            .await?;
153
154        // intentionally not joining, don't want to duplicate the script data in memory a whole bunch.
155        if let Some(mut script) = res {
156            let tags = sqlx::query("select tag from script_tags where script_id = ?1")
157                .bind(id)
158                .map(Self::query_script_tags)
159                .fetch_all(&self.pool)
160                .await?;
161
162            script.tags = tags;
163            Ok(Some(script))
164        } else {
165            Ok(None)
166        }
167    }
168
169    pub async fn list(&self) -> Result<Vec<Script>> {
170        debug!("listing scripts");
171
172        let mut res = sqlx::query("select * from scripts")
173            .map(Self::query_script)
174            .fetch_all(&self.pool)
175            .await?;
176
177        // Fetch all the tags for each script
178        for script in res.iter_mut() {
179            let tags = sqlx::query("select tag from script_tags where script_id = ?1")
180                .bind(script.id.to_string())
181                .map(Self::query_script_tags)
182                .fetch_all(&self.pool)
183                .await?;
184
185            script.tags = tags;
186        }
187
188        Ok(res)
189    }
190
191    pub async fn clear(&self) -> Result<()> {
192        debug!("clearing all scripts from sqlite");
193
194        sqlx::query("delete from script_tags")
195            .execute(&self.pool)
196            .await?;
197        sqlx::query("delete from scripts")
198            .execute(&self.pool)
199            .await?;
200
201        Ok(())
202    }
203
204    pub async fn delete(&self, id: &str) -> Result<()> {
205        debug!("deleting script {}", id);
206
207        sqlx::query("delete from scripts where id = ?1")
208            .bind(id)
209            .execute(&self.pool)
210            .await?;
211
212        // delete all the tags for the script
213        sqlx::query("delete from script_tags where script_id = ?1")
214            .bind(id)
215            .execute(&self.pool)
216            .await?;
217
218        Ok(())
219    }
220
221    pub async fn update(&self, s: &Script) -> Result<()> {
222        debug!("updating script {:?}", s);
223
224        let mut tx = self.pool.begin().await?;
225
226        // Update the script's base fields
227        sqlx::query("update scripts set name = ?1, description = ?2, shebang = ?3, script = ?4 where id = ?5")
228            .bind(s.name.as_str())
229            .bind(s.description.as_str())
230            .bind(s.shebang.as_str())
231            .bind(s.script.as_str())
232            .bind(s.id.to_string())
233            .execute(&mut *tx)
234            .await?;
235
236        // Delete all existing tags for this script
237        sqlx::query("delete from script_tags where script_id = ?1")
238            .bind(s.id.to_string())
239            .execute(&mut *tx)
240            .await?;
241
242        // Insert new tags
243        for tag in s.tags.iter() {
244            sqlx::query(
245                "insert or ignore into script_tags(script_id, tag)
246                values(?1, ?2)",
247            )
248            .bind(s.id.to_string())
249            .bind(tag)
250            .execute(&mut *tx)
251            .await?;
252        }
253
254        tx.commit().await?;
255
256        Ok(())
257    }
258
259    pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
260        let res = sqlx::query("select * from scripts where name = ?1")
261            .bind(name)
262            .map(Self::query_script)
263            .fetch_optional(&self.pool)
264            .await?;
265
266        let script = if let Some(mut script) = res {
267            let tags = sqlx::query("select tag from script_tags where script_id = ?1")
268                .bind(script.id.to_string())
269                .map(Self::query_script_tags)
270                .fetch_all(&self.pool)
271                .await?;
272
273            script.tags = tags;
274            Some(script)
275        } else {
276            None
277        };
278
279        Ok(script)
280    }
281}
282
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a script with the given `name` and fixed placeholder contents and no tags.
    fn make_script(name: &str) -> Script {
        Script::builder()
            .name(name.to_string())
            .description("test description".to_string())
            .shebang("test shebang".to_string())
            .script("test script".to_string())
            .build()
    }

    #[tokio::test]
    async fn test_list() {
        let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
        let scripts = db.list().await.unwrap();
        assert_eq!(scripts.len(), 0);

        let script = Script::builder()
            .name("test".to_string())
            .description("test".to_string())
            .shebang("test".to_string())
            .script("test".to_string())
            .build();

        db.save(&script).await.unwrap();

        let scripts = db.list().await.unwrap();
        assert_eq!(scripts.len(), 1);
        assert_eq!(scripts[0].name, "test");
    }

    #[tokio::test]
    async fn test_save_load() {
        let db = Database::new("sqlite::memory:", 1.0).await.unwrap();

        let script = Script::builder()
            .name("test name".to_string())
            .description("test description".to_string())
            .shebang("test shebang".to_string())
            .script("test script".to_string())
            .build();

        db.save(&script).await.unwrap();

        let loaded = db.load(&script.id.to_string()).await.unwrap().unwrap();

        assert_eq!(loaded, script);
    }

    #[tokio::test]
    async fn test_save_load_with_tags() {
        let db = Database::new("sqlite::memory:", 1.0).await.unwrap();

        let mut script = make_script("tagged");
        script.tags = vec!["alpha".to_string(), "beta".to_string()];
        db.save(&script).await.unwrap();

        let mut loaded = db.load(&script.id.to_string()).await.unwrap().unwrap();
        // Tag order is not specified by the query, so compare as a sorted set.
        loaded.tags.sort();
        assert_eq!(loaded.tags, vec!["alpha".to_string(), "beta".to_string()]);
        assert_eq!(loaded.name, "tagged");
    }

    #[tokio::test]
    async fn test_save_bulk() {
        let db = Database::new("sqlite::memory:", 1.0).await.unwrap();

        let scripts = vec![
            Script::builder()
                .name("test name".to_string())
                .description("test description".to_string())
                .shebang("test shebang".to_string())
                .script("test script".to_string())
                .build(),
            Script::builder()
                .name("test name 2".to_string())
                .description("test description 2".to_string())
                .shebang("test shebang 2".to_string())
                .script("test script 2".to_string())
                .build(),
        ];

        db.save_bulk(&scripts).await.unwrap();

        let loaded = db.list().await.unwrap();
        assert_eq!(loaded.len(), 2);
        // NOTE: relies on SQLite returning rows in insertion order for an unordered `select`,
        // which SQL itself does not guarantee.
        assert_eq!(loaded[0].name, "test name");
        assert_eq!(loaded[1].name, "test name 2");
    }

    #[tokio::test]
    async fn test_delete() {
        let db = Database::new("sqlite::memory:", 1.0).await.unwrap();

        let script = Script::builder()
            .name("test name".to_string())
            .description("test description".to_string())
            .shebang("test shebang".to_string())
            .script("test script".to_string())
            .build();

        db.save(&script).await.unwrap();

        assert_eq!(db.list().await.unwrap().len(), 1);
        db.delete(&script.id.to_string()).await.unwrap();

        let loaded = db.list().await.unwrap();
        assert_eq!(loaded.len(), 0);
    }

    #[tokio::test]
    async fn test_update_replaces_fields_and_tags() {
        let db = Database::new("sqlite::memory:", 1.0).await.unwrap();

        let mut script = make_script("updatable");
        script.tags = vec!["old-a".to_string(), "old-b".to_string()];
        db.save(&script).await.unwrap();

        script.description = "updated description".to_string();
        script.tags = vec!["new".to_string()];
        db.update(&script).await.unwrap();

        let loaded = db.get_by_name("updatable").await.unwrap().unwrap();
        assert_eq!(loaded.description, "updated description");
        assert_eq!(loaded.tags, vec!["new".to_string()]);
    }

    #[tokio::test]
    async fn test_get_by_name() {
        let db = Database::new("sqlite::memory:", 1.0).await.unwrap();

        assert!(db.get_by_name("missing").await.unwrap().is_none());

        let script = make_script("present");
        db.save(&script).await.unwrap();

        let loaded = db.get_by_name("present").await.unwrap().unwrap();
        assert_eq!(loaded, script);
    }
}