intelli_shell/storage/
mod.rs

1use std::{
2    path::Path,
3    sync::{Arc, atomic::AtomicBool},
4};
5
6use client::{SqliteClient, SqliteClientBuilder};
7use color_eyre::eyre::Context;
8use itertools::Itertools;
9use migrations::MIGRATIONS;
10use regex::Regex;
11use rusqlite::{OpenFlags, functions::FunctionFlags};
12
13use crate::{
14    errors::Result,
15    utils::{COMMAND_VARIABLE_REGEX_QUOTES, SplitCaptures, SplitItem},
16};
17
18mod client;
19mod migrations;
20mod queries;
21
22mod command;
23mod completion;
24mod import_export;
25mod variable;
26mod version;
27
/// Boxed, thread-safe error type (used, for instance, when compiling the cached regex of the `regexp` SQL function).
///
/// The trait object has an implicit `'static` bound, as is the default for `Box<dyn Trait>`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
29
30/// `SqliteStorage` provides an interface for interacting with a SQLite database to store and retrieve application data,
31/// primarily [`Command`] and [`VariableValue`] entities
32#[derive(Clone)]
33pub struct SqliteStorage {
34    /// Whether the workspace-level temp tables are created
35    workspace_tables_loaded: Arc<AtomicBool>,
36    /// The SQLite client used for database operations
37    client: Arc<SqliteClient>,
38}
39
40impl SqliteStorage {
41    /// Creates a new instance of [`SqliteStorage`] using a persistent database file.
42    ///
43    /// If INTELLI_STORAGE environment variable is set, it will use the specified path for the database file.
44    pub async fn new(data_dir: impl AsRef<Path>) -> Result<Self> {
45        let builder = if let Some(path) = std::env::var_os("INTELLI_STORAGE") {
46            // If INTELLI_STORAGE is set, use it as the database path
47            tracing::info!("Using INTELLI_STORAGE path: {}", path.to_string_lossy());
48            SqliteClientBuilder::new().path(path)
49        } else {
50            // Otherwise, use the provided data directory
51            tracing::info!("Using default storage path: {}", data_dir.as_ref().display());
52            SqliteClientBuilder::new().path(data_dir.as_ref().join("storage.db3"))
53        };
54        Ok(Self {
55            workspace_tables_loaded: Arc::new(AtomicBool::new(false)),
56            client: Arc::new(Self::open_client(builder).await?),
57        })
58    }
59
60    /// Creates a new in-memory instance of [`SqliteStorage`].
61    ///
62    /// This is primarily intended for testing purposes, where a persistent database is not required.
63    #[cfg(test)]
64    pub async fn new_in_memory() -> Result<Self> {
65        let client = Self::open_client(SqliteClientBuilder::new()).await?;
66        Ok(Self {
67            workspace_tables_loaded: Arc::new(AtomicBool::new(false)),
68            client: Arc::new(client),
69        })
70    }
71
72    /// Opens and initializes an SQLite client.
73    ///
74    /// This internal helper function configures the client with necessary PRAGMA settings for optimal performance and
75    /// data integrity (WAL mode, normal sync, foreign keys) and applies all pending database migrations.
76    async fn open_client(builder: SqliteClientBuilder) -> Result<SqliteClient> {
77        // Build the client
78        let client = builder
79            .flags(OpenFlags::default())
80            .open()
81            .await
82            .wrap_err("Error initializing SQLite client")?;
83
84        // Use Write-Ahead Logging (WAL) mode for better concurrency and performance.
85        client
86            .conn(|conn| {
87                Ok(conn
88                    .pragma_update(None, "journal_mode", "wal")
89                    .wrap_err("Error applying journal mode pragma")?)
90            })
91            .await?;
92
93        // Set synchronous mode to NORMAL. This means SQLite will still sync at critical moments, but less frequently
94        // than FULL, offering a good balance between safety and performance.
95        client
96            .conn(|conn| {
97                Ok(conn
98                    .pragma_update(None, "synchronous", "normal")
99                    .wrap_err("Error applying synchronous pragma")?)
100            })
101            .await?;
102
103        // Enforce foreign key constraints to maintain data integrity.
104        // This has a slight performance cost but is crucial for relational data.
105        client
106            .conn(|conn| {
107                Ok(conn
108                    .pragma_update(None, "foreign_keys", "on")
109                    .wrap_err("Error applying foreign keys pragma")?)
110            })
111            .await?;
112
113        // Store temp schema in memory
114        client
115            .conn(|conn| {
116                Ok(conn
117                    .pragma_update(None, "temp_store", "memory")
118                    .wrap_err("Error applying temp store pragma")?)
119            })
120            .await?;
121
122        // Apply all defined database migrations to bring the schema to the latest version.
123        // This is done atomically within a transaction.
124        client
125            .conn_mut(|conn| Ok(MIGRATIONS.to_latest(conn).wrap_err("Error applying migrations")?))
126            .await?;
127
128        // Add a regexp function to the client
129        client
130            .conn(|conn| {
131                Ok(conn
132                    .create_scalar_function(
133                        "regexp",
134                        2,
135                        FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
136                        |ctx| {
137                            assert_eq!(ctx.len(), 2, "regexp() called with unexpected number of arguments");
138
139                            let text = ctx
140                                .get_raw(1)
141                                .as_str_or_null()
142                                .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
143
144                            let Some(text) = text else {
145                                return Ok(false);
146                            };
147
148                            let cached_re: Arc<Regex> =
149                                ctx.get_or_create_aux(0, |vr| Ok::<_, BoxError>(Regex::new(vr.as_str()?)?))?;
150
151                            Ok(cached_re.is_match(text))
152                        },
153                    )
154                    .wrap_err("Error adding regexp function")?)
155            })
156            .await?;
157
158        // Add a cmd-to-regex function
159        client
160            .conn(|conn| {
161                Ok(conn
162                    .create_scalar_function(
163                        "cmd_to_regex",
164                        1,
165                        FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
166                        |ctx| {
167                            assert_eq!(
168                                ctx.len(),
169                                1,
170                                "cmd_to_regex() called with unexpected number of arguments"
171                            );
172                            let cmd_template = ctx.get::<String>(0)?;
173
174                            // Use the SplitCaptures iterator to process both unmatched literals and captured variables
175                            let regex_body = SplitCaptures::new(&COMMAND_VARIABLE_REGEX_QUOTES, &cmd_template)
176                                .filter_map(|item| match item {
177                                    // For unmatched parts, trim them and escape any special regex chars
178                                    SplitItem::Unmatched(s) => {
179                                        let trimmed = s.trim();
180                                        if trimmed.is_empty() {
181                                            None
182                                        } else {
183                                            Some(regex::escape(trimmed))
184                                        }
185                                    }
186                                    // For captured parts (the variables), replace them with a capture group
187                                    SplitItem::Captured(caps) => {
188                                        // Check which capture group matched to see if the placeholder was quoted
189                                        let placeholder_regex = if caps.get(1).is_some() {
190                                            // Group 1 matched '{{...}}', so expect a single-quoted argument
191                                            r"('[^']*')"
192                                        } else if caps.get(2).is_some() {
193                                            // Group 2 matched "{{...}}", so expect a double-quoted argument
194                                            r#"("[^"]*")"#
195                                        } else {
196                                            // Group 3 matched {{...}}, so expect a generic argument
197                                            r#"('[^']*'|"[^"]*"|\S+)"#
198                                        };
199                                        Some(String::from(placeholder_regex))
200                                    },
201                                })
202                                // Join them by any number of whitespaces
203                                .join(r"\s+");
204
205                            // Build the final regex
206                            Ok(format!("^{regex_body}$"))
207                        },
208                    )
209                    .wrap_err("Error adding cmd-to-regex function")?)
210            })
211            .await?;
212
213        Ok(client)
214    }
215
216    #[cfg(debug_assertions)]
217    pub async fn query(&self, sql: String) -> Result<String> {
218        self.client
219            .conn(move |conn| {
220                use prettytable::{Cell, Row, Table};
221                use rusqlite::types::Value;
222
223                let mut stmt = conn.prepare(&sql)?;
224                let column_names = stmt
225                    .column_names()
226                    .into_iter()
227                    .map(String::from)
228                    .collect::<Vec<String>>();
229                let columns_len = column_names.len();
230                let mut table = Table::new();
231                table.add_row(Row::from(column_names));
232                let rows = stmt.query_map([], |row| {
233                    let mut cells = Vec::new();
234                    for i in 0..columns_len {
235                        let value: Value = row.get(i)?;
236                        let cell_value = match value {
237                            Value::Null => "NULL".to_string(),
238                            Value::Integer(i) => i.to_string(),
239                            Value::Real(f) => f.to_string(),
240                            Value::Text(t) => t,
241                            Value::Blob(_) => "[BLOB]".to_string(),
242                        };
243                        cells.push(Cell::new(&cell_value));
244                    }
245                    Ok(Row::from(cells))
246                })?;
247                for row in rows {
248                    table.add_row(row?);
249                }
250                Ok(table.to_string())
251            })
252            .await
253    }
254}