intelli_shell/storage/
mod.rs

1use std::{
2    path::Path,
3    sync::{Arc, atomic::AtomicBool},
4};
5
6use client::{SqliteClient, SqliteClientBuilder};
7use color_eyre::{Result, eyre::Context};
8use itertools::Itertools;
9use migrations::MIGRATIONS;
10use regex::Regex;
11use rusqlite::{OpenFlags, functions::FunctionFlags};
12
13use crate::utils::{COMMAND_VARIABLE_REGEX_QUOTES, SplitCaptures, SplitItem};
14
15mod client;
16mod migrations;
17mod queries;
18
19mod command;
20mod variable;
21mod version;
22
/// Thread-safe, type-erased error used by the fallible closures passed to SQLite user functions.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
24
25/// `SqliteStorage` provides an interface for interacting with a SQLite database to store and retrieve application data,
26/// primarily [`Command`] and [`VariableValue`] entities
27#[derive(Clone)]
28pub struct SqliteStorage {
29    /// Whether the workspace-level temp tables are created
30    workspace_tables_loaded: Arc<AtomicBool>,
31    /// The SQLite client used for database operations
32    client: Arc<SqliteClient>,
33}
34
35impl SqliteStorage {
36    /// Creates a new instance of [`SqliteStorage`] using a persistent database file.
37    ///
38    /// If INTELLI_STORAGE environment variable is set, it will use the specified path for the database file.
39    pub async fn new(data_dir: impl AsRef<Path>) -> Result<Self> {
40        let builder = if let Some(path) = std::env::var_os("INTELLI_STORAGE") {
41            // If INTELLI_STORAGE is set, use it as the database path
42            tracing::info!("Using INTELLI_STORAGE path: {}", path.to_string_lossy());
43            SqliteClientBuilder::new().path(path)
44        } else {
45            // Otherwise, use the provided data directory
46            tracing::info!("Using default storage path: {}", data_dir.as_ref().display());
47            SqliteClientBuilder::new().path(data_dir.as_ref().join("storage.db3"))
48        };
49        Ok(Self {
50            workspace_tables_loaded: Arc::new(AtomicBool::new(false)),
51            client: Arc::new(Self::open_client(builder).await?),
52        })
53    }
54
55    /// Creates a new in-memory instance of [`SqliteStorage`].
56    ///
57    /// This is primarily intended for testing purposes, where a persistent database is not required.
58    #[cfg(test)]
59    pub async fn new_in_memory() -> Result<Self> {
60        let client = Self::open_client(SqliteClientBuilder::new()).await?;
61        Ok(Self {
62            workspace_tables_loaded: Arc::new(AtomicBool::new(false)),
63            client: Arc::new(client),
64        })
65    }
66
67    /// Opens and initializes an SQLite client.
68    ///
69    /// This internal helper function configures the client with necessary PRAGMA settings for optimal performance and
70    /// data integrity (WAL mode, normal sync, foreign keys) and applies all pending database migrations.
71    async fn open_client(builder: SqliteClientBuilder) -> Result<SqliteClient> {
72        // Build the client
73        let client = builder
74            .flags(OpenFlags::default())
75            .open()
76            .await
77            .wrap_err("Error initializing SQLite client")?;
78
79        // Use Write-Ahead Logging (WAL) mode for better concurrency and performance.
80        client
81            .conn(|conn| {
82                conn.pragma_update(None, "journal_mode", "wal")
83                    .wrap_err("Error applying journal mode pragma")
84            })
85            .await?;
86
87        // Set synchronous mode to NORMAL. This means SQLite will still sync at critical moments, but less frequently
88        // than FULL, offering a good balance between safety and performance.
89        client
90            .conn(|conn| {
91                conn.pragma_update(None, "synchronous", "normal")
92                    .wrap_err("Error applying synchronous pragma")
93            })
94            .await?;
95
96        // Enforce foreign key constraints to maintain data integrity.
97        // This has a slight performance cost but is crucial for relational data.
98        client
99            .conn(|conn| {
100                conn.pragma_update(None, "foreign_keys", "on")
101                    .wrap_err("Error applying foreign keys pragma")
102            })
103            .await?;
104
105        // Store temp schema in memory
106        client
107            .conn(|conn| {
108                conn.pragma_update(None, "temp_store", "memory")
109                    .wrap_err("Error applying temp store pragma")
110            })
111            .await?;
112
113        // Apply all defined database migrations to bring the schema to the latest version.
114        // This is done atomically within a transaction.
115        client
116            .conn_mut(|conn| MIGRATIONS.to_latest(conn).wrap_err("Error applying migrations"))
117            .await?;
118
119        // Add a regexp function to the client
120        client
121            .conn(|conn| {
122                conn.create_scalar_function(
123                    "regexp",
124                    2,
125                    FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
126                    |ctx| {
127                        assert_eq!(ctx.len(), 2, "regexp() called with unexpected number of arguments");
128
129                        let text = ctx
130                            .get_raw(1)
131                            .as_str_or_null()
132                            .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
133
134                        let Some(text) = text else {
135                            return Ok(false);
136                        };
137
138                        let cached_re: Arc<Regex> =
139                            ctx.get_or_create_aux(0, |vr| Ok::<_, BoxError>(Regex::new(vr.as_str()?)?))?;
140
141                        Ok(cached_re.is_match(text))
142                    },
143                )
144                .wrap_err("Error adding regexp function")
145            })
146            .await?;
147
148        // Add a cmd-to-regex function
149        client
150            .conn(|conn| {
151                conn.create_scalar_function(
152                    "cmd_to_regex",
153                    1,
154                    FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
155                    |ctx| {
156                        assert_eq!(
157                            ctx.len(),
158                            1,
159                            "cmd_to_regex() called with unexpected number of arguments"
160                        );
161                        let cmd_template = ctx.get::<String>(0)?;
162
163                        // Use the SplitCaptures iterator to process both unmatched literals and captured variables
164                        let regex_body = SplitCaptures::new(&COMMAND_VARIABLE_REGEX_QUOTES, &cmd_template)
165                            .filter_map(|item| match item {
166                                // For unmatched parts, trim them and escape any special regex chars
167                                SplitItem::Unmatched(s) => {
168                                    let trimmed = s.trim();
169                                    if trimmed.is_empty() {
170                                        None
171                                    } else {
172                                        Some(regex::escape(trimmed))
173                                    }
174                                }
175                                // For captured parts (the variables), replace them with a capture group
176                                SplitItem::Captured(caps) => {
177                                    // Check which capture group matched to see if the placeholder was quoted
178                                    let placeholder_regex = if caps.get(1).is_some() {
179                                        // Group 1 matched '{{...}}', so expect a single-quoted argument
180                                        r"('[^']*')"
181                                    } else if caps.get(2).is_some() {
182                                        // Group 2 matched "{{...}}", so expect a double-quoted argument
183                                        r#"("[^"]*")"#
184                                    } else {
185                                        // Group 3 matched {{...}}, so expect a generic argument
186                                        r#"('[^']*'|"[^"]*"|\S+)"#
187                                    };
188                                    Some(String::from(placeholder_regex))
189                                },
190                            })
191                            // Join them by any number of whitespaces
192                            .join(r"\s+");
193
194                        // Build the final regex
195                        Ok(format!("^{regex_body}$"))
196                    },
197                )
198                .wrap_err("Error adding cmd-to-regex function")
199            })
200            .await?;
201
202        Ok(client)
203    }
204
205    #[cfg(debug_assertions)]
206    pub async fn query(&self, sql: String) -> Result<String> {
207        self.client
208            .conn(move |conn| {
209                use prettytable::{Cell, Row, Table};
210                use rusqlite::types::Value;
211
212                let mut stmt = conn.prepare(&sql)?;
213                let column_names = stmt
214                    .column_names()
215                    .into_iter()
216                    .map(String::from)
217                    .collect::<Vec<String>>();
218                let columns_len = column_names.len();
219                let mut table = Table::new();
220                table.add_row(Row::from(column_names));
221                let rows = stmt.query_map([], |row| {
222                    let mut cells = Vec::new();
223                    for i in 0..columns_len {
224                        let value: Value = row.get(i)?;
225                        let cell_value = match value {
226                            Value::Null => "NULL".to_string(),
227                            Value::Integer(i) => i.to_string(),
228                            Value::Real(f) => f.to_string(),
229                            Value::Text(t) => t,
230                            Value::Blob(_) => "[BLOB]".to_string(),
231                        };
232                        cells.push(Cell::new(&cell_value));
233                    }
234                    Ok(Row::from(cells))
235                })?;
236                for row in rows {
237                    table.add_row(row?);
238                }
239                Ok(table.to_string())
240            })
241            .await
242    }
243}