intelli_shell/
storage.rs

1use core::slice;
2use std::{
3    env, fs,
4    io::{BufRead, BufReader, BufWriter, Write},
5    sync::Mutex,
6};
7
8use anyhow::{anyhow, Context, Result};
9use directories::ProjectDirs;
10use iter_flow::Iterflow;
11use itertools::Itertools;
12use once_cell::sync::Lazy;
13use regex::Regex;
14use rusqlite::{params_from_iter, Connection, Error, ErrorCode, OptionalExtension, Row};
15use rusqlite_migration::{Migrations, M};
16
17use crate::{
18    common::flatten_str,
19    model::{Command, LabelSuggestion},
20};
21
22/// Database migrations
23static MIGRATIONS: Lazy<Migrations> = Lazy::new(|| {
24    Migrations::new(vec![
25        M::up(
26            r#"CREATE TABLE command (
27                category TEXT NOT NULL,
28                alias TEXT NULL,
29                cmd TEXT NOT NULL UNIQUE,
30                description TEXT NOT NULL,
31                usage INTEGER DEFAULT 0
32            );"#,
33        ),
34        M::up(r#"CREATE VIRTUAL TABLE command_fts USING fts5(flat_cmd, flat_description);"#),
35        M::up(
36            r#"CREATE TABLE label_suggestion (
37                flat_root_cmd TEXT NOT NULL,
38                flat_label TEXT NOT NULL,
39                suggestion TEXT NOT NULL,
40                usage INTEGER DEFAULT 0,
41                PRIMARY KEY (flat_root_cmd, flat_label, suggestion)
42            );"#,
43        ),
44    ])
45});
46
/// Category assigned to the commands defined by the user.
///
/// It's also the category listed when a search has no usable terms.
pub const USER_CATEGORY: &str = "user";
49
50/// Regex to match not allowed FTS characters
51static ALLOWED_FTS_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r#"[^a-zA-Z0-9 ]"#).unwrap());
52
53/// SQLite-based storage
54pub struct SqliteStorage {
55    conn: Mutex<Connection>,
56}
57
58impl SqliteStorage {
59    /// Builds a new SQLite storage on the default path
60    pub fn new() -> Result<Self> {
61        let path = env::var_os("INTELLI_HOME")
62            .map(Into::into)
63            .map(anyhow::Ok)
64            .unwrap_or_else(|| {
65                Ok(ProjectDirs::from("org", "IntelliShell", "Intelli-Shell")
66                    .context("Error initializing project dir")?
67                    .data_dir()
68                    .to_path_buf())
69            })?;
70
71        fs::create_dir_all(&path).context("Could't create data dir")?;
72
73        Ok(Self {
74            conn: Mutex::new(
75                Self::initialize_connection(
76                    Connection::open(path.join("storage.db3")).context("Error opening SQLite connection")?,
77                )
78                .context("Error initializing SQLite connection")?,
79            ),
80        })
81    }
82
83    /// Builds a new in-memory SQLite storage for testing purposes
84    pub fn new_in_memory() -> Result<Self> {
85        Ok(Self {
86            conn: Mutex::new(
87                Self::initialize_connection(Connection::open_in_memory()?)
88                    .context("Error initializing SQLite connection")?,
89            ),
90        })
91    }
92
93    /// Initializes an SQLite connection applying migrations and common pragmas
94    fn initialize_connection(mut conn: Connection) -> Result<Connection> {
95        // Different implementation of the atomicity properties
96        conn.pragma_update(None, "journal_mode", "WAL")
97            .context("Error applying journal mode pragma")?;
98        // Synchronize less often to the filesystem
99        conn.pragma_update(None, "synchronous", "normal")
100            .context("Error applying synchronous pragma")?;
101        // Check foreign key reference, slightly worst performance
102        conn.pragma_update(None, "foreign_keys", "on")
103            .context("Error applying foreign keys pragma")?;
104
105        // Update the database schema, atomically
106        MIGRATIONS.to_latest(&mut conn).context("Error applying migrations")?;
107
108        Ok(conn)
109    }
110
111    /// Inserts a command and updates its `id` with the inserted value.
112    ///
113    /// If the command already exist on the database, its description will be updated.
114    ///
115    /// Returns wether the command was inserted or not (updated)
116    pub fn insert_command(&self, command: &mut Command) -> Result<bool> {
117        Ok(self.insert_commands(slice::from_mut(command))? == 1)
118    }
119
120    /// Inserts a bunch of commands and updates its `id` with the inserted value.
121    ///
122    /// If any command already exist on the database, its description will be updated.
123    ///
124    /// Returns the number of commands inserted (the rest are updated)
125    pub fn insert_commands(&self, commands: &mut [Command]) -> Result<u64> {
126        let mut res = 0;
127
128        let mut conn = self.conn.lock().expect("poisoned lock");
129        let tx = conn.transaction()?;
130
131        {
132            let mut stmt_cmd = tx.prepare(
133                r#"INSERT INTO command (category, alias, cmd, description) VALUES (?, ?, ?, ?)
134                ON CONFLICT(cmd) DO UPDATE SET description=excluded.description
135                RETURNING rowid"#,
136            )?;
137            let mut stmt_fts_check = tx.prepare("SELECT rowid FROM command_fts WHERE rowid = ?")?;
138            let mut stmt_fts_update = tx.prepare("UPDATE command_fts SET flat_description = ? WHERE rowid = ?")?;
139            let mut stmt_fts_insert =
140                tx.prepare("INSERT INTO command_fts (rowid, flat_cmd, flat_description) VALUES (?, ?, ?)")?;
141
142            for command in commands {
143                let row_id = stmt_cmd
144                    .query_row(
145                        (
146                            &command.category,
147                            command.alias.as_deref(),
148                            &command.cmd,
149                            &command.description,
150                        ),
151                        |r| r.get(0),
152                    )
153                    .context("Error inserting command")?;
154
155                command.id = row_id;
156
157                let current_row: Option<i32> = stmt_fts_check
158                    .query_row([row_id], |r| r.get(0))
159                    .optional()
160                    .context("Error checking fts")?;
161
162                match current_row {
163                    Some(_) => {
164                        stmt_fts_update
165                            .execute((flatten_str(&command.description), row_id))
166                            .context("Error updating command fts")?;
167                    }
168                    None => {
169                        res += 1;
170                        stmt_fts_insert
171                            .execute((row_id, flatten_str(&command.cmd), flatten_str(&command.description)))
172                            .context("Error inserting command fts")?;
173                    }
174                }
175            }
176        }
177
178        tx.commit()?;
179
180        Ok(res)
181    }
182
183    /// Updates an existing command
184    ///
185    /// Returns wether the command exists and was updated or not.
186    pub fn update_command(&self, command: &Command) -> Result<bool> {
187        let mut conn = self.conn.lock().expect("poisoned lock");
188        let tx = conn.transaction()?;
189
190        let updated = tx
191            .execute(
192                r#"UPDATE command SET alias = ?, cmd = ?, description = ?, usage = ? WHERE rowid = ?"#,
193                (
194                    command.alias.as_deref(),
195                    &command.cmd,
196                    &command.description,
197                    command.usage,
198                    command.id,
199                ),
200            )
201            .context("Error updating command")?;
202
203        if updated == 1 {
204            let updated = tx
205                .execute(
206                    r#"UPDATE command_fts SET flat_cmd = ?, flat_description = ? WHERE rowid = ?"#,
207                    (flatten_str(&command.cmd), flatten_str(&command.description), command.id),
208                )
209                .context("Error updating command fts")?;
210            if updated == 1 {
211                tx.commit()?;
212                Ok(true)
213            } else {
214                Ok(false)
215            }
216        } else {
217            Ok(false)
218        }
219    }
220
221    /// Updates an existing command by incrementing its usage by one
222    ///
223    /// Returns wether the command exists and was updated or not.
224    pub fn increment_command_usage(&self, command_id: i64) -> Result<bool> {
225        let conn = self.conn.lock().expect("poisoned lock");
226        let updated = conn
227            .execute(r#"UPDATE command SET usage = usage + 1 WHERE rowid = ?"#, [command_id])
228            .context("Error updating command usage")?;
229
230        Ok(updated == 1)
231    }
232
233    /// Deletes an existing command
234    ///
235    /// Returns wether the command exists and was deleted or not.
236    pub fn delete_command(&self, command_id: i64) -> Result<bool> {
237        let mut conn = self.conn.lock().expect("poisoned lock");
238        let tx = conn.transaction()?;
239
240        let deleted = tx
241            .execute(r#"DELETE FROM command WHERE rowid = ?"#, [command_id])
242            .context("Error deleting command")?;
243
244        if deleted == 1 {
245            let deleted = tx
246                .execute(r#"DELETE FROM command_fts WHERE rowid = ?"#, [command_id])
247                .context("Error deleting command fts")?;
248            if deleted == 1 {
249                tx.commit()?;
250                Ok(true)
251            } else {
252                Ok(false)
253            }
254        } else {
255            Ok(false)
256        }
257    }
258
259    /// Get commands matching a category
260    pub fn get_commands(&self, category: impl AsRef<str>) -> Result<Vec<Command>> {
261        let category = category.as_ref();
262
263        let conn = self.conn.lock().expect("poisoned lock");
264        let mut stmt = conn.prepare(
265            r#"SELECT rowid, category, alias, cmd, description, usage 
266            FROM command
267            WHERE category = ?
268            ORDER BY usage DESC"#,
269        )?;
270
271        let commands = stmt
272            .query([category])?
273            .mapped(command_from_row)
274            .finish_vec()
275            .context("Error querying commands")?;
276
277        Ok(commands)
278    }
279
280    /// Finds commands matching the given search criteria
281    pub fn find_commands(&self, search: impl AsRef<str>) -> Result<Vec<Command>> {
282        let search = search.as_ref().trim();
283        if search.is_empty() {
284            return self.get_commands(USER_CATEGORY);
285        }
286        let flat_search = flatten_str(search);
287
288        let conn = self.conn.lock().expect("poisoned lock");
289        let alias_cmd = conn
290            .query_row(
291                r#"SELECT rowid, category, alias, cmd, description, usage 
292                FROM command
293                WHERE alias = :flat_search OR alias = :search"#,
294                &[(":flat_search", flat_search.as_str()), (":search", search)],
295                command_from_row,
296            )
297            .optional()
298            .context("Error querying command by alias")?;
299        if let Some(cmd) = alias_cmd {
300            return Ok(vec![cmd]);
301        }
302
303        let hashtags = flat_search
304            .split_whitespace()
305            .filter(|t| t.starts_with('#'))
306            .collect_vec();
307
308        let flat_fts_search = ALLOWED_FTS_REGEX.replace_all(&flat_search, "");
309        let flat_fts_search = flat_fts_search.trim();
310        if flat_fts_search.is_empty() || flat_fts_search == " " {
311            drop(conn);
312            return self.get_commands(USER_CATEGORY);
313        }
314
315        let mut stmt = conn.prepare(
316            r#"
317                    SELECT DISTINCT rowid, category, alias, cmd, description, usage 
318                    FROM (
319                        SELECT c.rowid, c.category, c.alias, c.cmd, c.description, c.usage, 2 as ord
320                        FROM command_fts s
321                        JOIN command c ON s.rowid = c.rowid
322                        WHERE command_fts MATCH :match_cmd_ordered
323                    
324                        UNION ALL
325                        
326                        SELECT c.rowid, c.category, c.alias, c.cmd, c.description, c.usage, 1 as ord
327                        FROM command_fts s
328                        JOIN command c ON s.rowid = c.rowid
329                        WHERE command_fts MATCH :match_simple
330
331                        UNION ALL
332                        
333                        SELECT c.rowid, c.category, c.alias, c.cmd, c.description, c.usage, 0 as ord
334                        FROM command_fts s
335                        JOIN command c ON s.rowid = c.rowid
336                        WHERE s.flat_cmd GLOB :glob OR s.flat_description GLOB :glob
337                    )
338                    ORDER BY ord DESC, usage DESC, (CASE WHEN category = 'user' THEN 1 ELSE 0 END) DESC
339                "#,
340        )?;
341
342        let match_cmd_ordered = format!(
343            "\"flat_cmd\" : ^{}",
344            flat_fts_search
345                .split_whitespace()
346                .map(|token| format!("{token}*"))
347                .join(" + ")
348        );
349        let match_simple = flat_fts_search
350            .split_whitespace()
351            .map(|token| format!("{token}*"))
352            .join(" ");
353        let glob = flat_search
354            .split_whitespace()
355            .map(|token| format!("*{token}*"))
356            .join(" ");
357
358        let commands = stmt
359            .query(&[
360                (":match_cmd_ordered", &match_cmd_ordered),
361                (":match_simple", &match_simple),
362                (":glob", &glob),
363            ])?
364            .mapped(command_from_row)
365            .filter(|r| {
366                if !hashtags.is_empty() {
367                    if let Ok(command) = r {
368                        for tag in &hashtags {
369                            if !command.description.contains(tag) {
370                                return false;
371                            }
372                        }
373                    }
374                }
375                true
376            })
377            .finish_vec()
378            .context("Error querying fts command")?;
379
380        Ok(commands)
381    }
382
383    /// Exports the commands from a given category into the given file path
384    ///
385    /// ## Returns
386    ///
387    /// The number of exported commands
388    pub fn export(&self, category: impl AsRef<str>, file_path: impl Into<String>) -> Result<usize> {
389        let category = category.as_ref();
390        let file_path = file_path.into();
391        let commands = self.get_commands(category)?;
392        let size = commands.len();
393        let file = fs::File::create(&file_path).context("Error creating output file")?;
394        let mut w = BufWriter::new(file);
395        for command in commands {
396            writeln!(w, "{} ## {}", command.cmd, command.description).context("Error writing file")?;
397        }
398        w.flush().context("Error writing file")?;
399        Ok(size)
400    }
401
402    /// Imports commands from the given file into a category.
403    ///
404    /// ## Returns
405    ///
406    /// The number of newly inserted commands
407    pub fn import(&self, category: impl AsRef<str>, file_path: String) -> Result<u64> {
408        let category = category.as_ref();
409        let file = fs::File::open(file_path).context("Error opening file")?;
410        let r = BufReader::new(file);
411        let mut commands = r
412            .lines()
413            .map_err(anyhow::Error::from)
414            .filter_ok(|line| !line.is_empty() && !line.starts_with('#'))
415            .and_then(|line| {
416                let (cmd, description) = line
417                    .split_once(" ## ")
418                    .ok_or_else(|| anyhow!("Unexpected file format"))?;
419                Ok::<_, anyhow::Error>(Command::new(category, cmd, description))
420            })
421            .finish_vec()?;
422
423        let new = self.insert_commands(&mut commands)?;
424
425        Ok(new)
426    }
427
428    /// Determines if the store is empty (no commands stored)
429    pub fn is_empty(&self) -> Result<bool> {
430        Ok(self.len()? == 0)
431    }
432
433    /// Returns the number of stored commands
434    pub fn len(&self) -> Result<u64> {
435        let conn = self.conn.lock().expect("poisoned lock");
436        let mut stmt = conn.prepare(r#"SELECT COUNT(*) FROM command"#)?;
437        Ok(stmt.query_row([], |r| r.get(0))?)
438    }
439
440    /// Inserts a label suggestion if it doesn't exists.
441    ///
442    /// Returns wether the suggestion was inserted or not (already existed)
443    pub fn insert_label_suggestion(&self, suggestion: &LabelSuggestion) -> Result<bool> {
444        if suggestion.flat_label == suggestion.suggestion {
445            return Ok(false);
446        }
447
448        let conn = self.conn.lock().expect("poisoned lock");
449        let inserted = match conn.execute(
450            r#"INSERT INTO label_suggestion (flat_root_cmd, flat_label, suggestion, usage) VALUES (?, ?, ?, ?)"#,
451            (
452                &suggestion.flat_root_cmd,
453                &suggestion.flat_label,
454                &suggestion.suggestion,
455                suggestion.usage,
456            ),
457        ) {
458            Ok(i) => i,
459            Err(Error::SqliteFailure(err, msg)) => match err.code {
460                ErrorCode::ConstraintViolation => return Ok(false),
461                _ => {
462                    return Err(
463                        anyhow::Error::new(Error::SqliteFailure(err, msg)).context("Error inserting label suggestion")
464                    );
465                }
466            },
467            Err(err) => {
468                return Err(anyhow::Error::new(err).context("Error inserting label suggestion"));
469            }
470        };
471
472        Ok(inserted == 1)
473    }
474
475    /// Updates an existing label suggestion
476    ///
477    /// Returns wether the suggestion exists and was updated or not.
478    pub fn update_label_suggestion(
479        &self,
480        suggestion: &mut LabelSuggestion,
481        new_suggestion: impl Into<String>,
482    ) -> Result<bool> {
483        let conn = self.conn.lock().expect("poisoned lock");
484        let new_suggestion = new_suggestion.into();
485        let updated = conn
486            .execute(
487                r#"UPDATE label_suggestion SET suggestion = ? WHERE flat_root_cmd = ? AND flat_label = ? AND suggestion = ?"#,
488                (
489                    &new_suggestion,
490                    &suggestion.flat_root_cmd,
491                    &suggestion.flat_label,
492                    &suggestion.suggestion,
493                ),
494            )
495            .context("Error updating label suggestion")?;
496
497        let updated = updated == 1;
498
499        if updated {
500            suggestion.suggestion = new_suggestion;
501        }
502
503        Ok(updated)
504    }
505
506    /// Updates the usage of an existing label suggestion
507    ///
508    /// Returns wether the suggestion exists and was updated or not.
509    pub fn update_label_suggestion_usage(&self, suggestion: &LabelSuggestion) -> Result<bool> {
510        let conn = self.conn.lock().expect("poisoned lock");
511        let updated = conn
512            .execute(
513                r#"UPDATE label_suggestion SET usage = ? WHERE flat_root_cmd = ? AND flat_label = ? AND suggestion = ?"#,
514                (
515                    suggestion.usage,
516                    &suggestion.flat_root_cmd,
517                    &suggestion.flat_label,
518                    &suggestion.suggestion,
519                ),
520            )
521            .context("Error updating label suggestion usage")?;
522
523        Ok(updated == 1)
524    }
525
526    /// Deletes an existing label suggestion
527    ///
528    /// Returns wether the suggestion exists and was deleted or not.
529    pub fn delete_label_suggestion(&self, suggestion: &LabelSuggestion) -> Result<bool> {
530        let conn = self.conn.lock().expect("poisoned lock");
531        let deleted = conn
532            .execute(
533                r#"DELETE FROM label_suggestion WHERE flat_root_cmd = ? AND flat_label = ? AND suggestion = ?"#,
534                (
535                    &suggestion.flat_root_cmd,
536                    &suggestion.flat_label,
537                    &suggestion.suggestion,
538                ),
539            )
540            .context("Error deleting label suggestion")?;
541
542        Ok(deleted == 1)
543    }
544
545    /// Finds label suggestions for the given root command and label
546    pub fn find_suggestions_for(
547        &self,
548        root_cmd: impl AsRef<str>,
549        label: impl AsRef<str>,
550    ) -> Result<Vec<LabelSuggestion>> {
551        let flat_root_cmd = flatten_str(root_cmd.as_ref());
552        let label = label.as_ref();
553        let mut parameters = label.split('|').map(flatten_str).collect_vec();
554        parameters.insert(0, flatten_str(label));
555
556        const QUERY: &str = r#"
557            SELECT * FROM (
558                SELECT 
559                    s.flat_root_cmd, 
560                    s.flat_label, 
561                    s.suggestion, 
562                    s.usage, 
563                    q.sum_usage,
564                    RANK () OVER ( 
565                        PARTITION BY s.suggestion
566                        ORDER BY LENGTH(s.flat_label) DESC
567                    ) rank 
568                FROM label_suggestion s
569                JOIN (
570                    SELECT flat_root_cmd, suggestion, SUM(usage) as sum_usage
571                    FROM label_suggestion
572                    WHERE flat_root_cmd = ?1 AND flat_label IN (#LABELS#)
573                    GROUP BY flat_root_cmd, suggestion
574                ) q ON s.flat_root_cmd = q.flat_root_cmd AND s.suggestion = q.suggestion
575            )
576            WHERE rank = 1
577            ORDER BY 
578                sum_usage DESC, 
579                (CASE WHEN flat_label = ?2 THEN 1 ELSE 0 END) DESC
580        "#;
581
582        let conn = self.conn.lock().expect("poisoned lock");
583        let mut stmt = conn.prepare(
584            &QUERY.replace(
585                "#LABELS#",
586                &parameters
587                    .iter()
588                    .enumerate()
589                    .map(|(i, _)| format!("?{}", i + 2))
590                    .join(","),
591            ),
592        )?;
593
594        parameters.insert(0, flat_root_cmd);
595
596        let suggestions = stmt
597            .query(params_from_iter(parameters.iter()))?
598            .mapped(label_suggestion_from_row)
599            .finish_vec()
600            .context("Error querying label suggestions")?;
601
602        Ok(suggestions)
603    }
604}
605
606/// Maps a [Command] from a [Row]
607fn command_from_row(row: &Row<'_>) -> rusqlite::Result<Command> {
608    Ok(Command {
609        id: row.get(0)?,
610        category: row.get(1)?,
611        alias: row.get(2)?,
612        cmd: row.get(3)?,
613        description: row.get(4)?,
614        usage: row.get(5)?,
615    })
616}
617
618/// Maps a [LabelSuggestion] from a [Row]
619fn label_suggestion_from_row(row: &Row<'_>) -> rusqlite::Result<LabelSuggestion> {
620    Ok(LabelSuggestion {
621        flat_root_cmd: row.get(0)?,
622        flat_label: row.get(1)?,
623        suggestion: row.get(2)?,
624        usage: row.get(3)?,
625    })
626}
627
628impl Drop for SqliteStorage {
629    fn drop(&mut self) {
630        let conn = self.conn.lock().expect("poisoned lock");
631        // Make sure pragma optimize does not take too long
632        conn.pragma_update(None, "analysis_limit", "400")
633            .expect("Failed analysis_limit PRAGMA");
634        // Gather statistics to improve query optimization
635        conn.execute_batch("PRAGMA optimize;").expect("Failed optimize PRAGMA");
636    }
637}
638
#[cfg(test)]
mod tests {
    use super::MIGRATIONS;

    /// Guards against invalid migration definitions
    #[test]
    fn migrations_test() {
        let validation = MIGRATIONS.validate();
        assert!(validation.is_ok());
    }
}