intelli_shell/storage/mod.rs

1use std::{
2    path::Path,
3    sync::{Arc, atomic::AtomicBool},
4};
5
6use client::{SqliteClient, SqliteClientBuilder};
7use color_eyre::eyre::Context;
8use itertools::Itertools;
9use migrations::MIGRATIONS;
10use regex::Regex;
11use rusqlite::{OpenFlags, functions::FunctionFlags};
12
13use crate::{
14    errors::Result,
15    utils::{COMMAND_VARIABLE_REGEX_QUOTES, SplitCaptures, SplitItem},
16};
17
18mod client;
19mod migrations;
20mod queries;
21
22mod command;
23mod variable;
24mod version;
25
/// Boxed, thread-safe error type produced while building the cached matcher of the `regexp` SQL function.
///
/// The `'static` bound is implied: it is the default object lifetime of `Box<dyn Trait>`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
27
28/// `SqliteStorage` provides an interface for interacting with a SQLite database to store and retrieve application data,
29/// primarily [`Command`] and [`VariableValue`] entities
30#[derive(Clone)]
31pub struct SqliteStorage {
32    /// Whether the workspace-level temp tables are created
33    workspace_tables_loaded: Arc<AtomicBool>,
34    /// The SQLite client used for database operations
35    client: Arc<SqliteClient>,
36}
37
38impl SqliteStorage {
39    /// Creates a new instance of [`SqliteStorage`] using a persistent database file.
40    ///
41    /// If INTELLI_STORAGE environment variable is set, it will use the specified path for the database file.
42    pub async fn new(data_dir: impl AsRef<Path>) -> Result<Self> {
43        let builder = if let Some(path) = std::env::var_os("INTELLI_STORAGE") {
44            // If INTELLI_STORAGE is set, use it as the database path
45            tracing::info!("Using INTELLI_STORAGE path: {}", path.to_string_lossy());
46            SqliteClientBuilder::new().path(path)
47        } else {
48            // Otherwise, use the provided data directory
49            tracing::info!("Using default storage path: {}", data_dir.as_ref().display());
50            SqliteClientBuilder::new().path(data_dir.as_ref().join("storage.db3"))
51        };
52        Ok(Self {
53            workspace_tables_loaded: Arc::new(AtomicBool::new(false)),
54            client: Arc::new(Self::open_client(builder).await?),
55        })
56    }
57
58    /// Creates a new in-memory instance of [`SqliteStorage`].
59    ///
60    /// This is primarily intended for testing purposes, where a persistent database is not required.
61    #[cfg(test)]
62    pub async fn new_in_memory() -> Result<Self> {
63        let client = Self::open_client(SqliteClientBuilder::new()).await?;
64        Ok(Self {
65            workspace_tables_loaded: Arc::new(AtomicBool::new(false)),
66            client: Arc::new(client),
67        })
68    }
69
70    /// Opens and initializes an SQLite client.
71    ///
72    /// This internal helper function configures the client with necessary PRAGMA settings for optimal performance and
73    /// data integrity (WAL mode, normal sync, foreign keys) and applies all pending database migrations.
74    async fn open_client(builder: SqliteClientBuilder) -> Result<SqliteClient> {
75        // Build the client
76        let client = builder
77            .flags(OpenFlags::default())
78            .open()
79            .await
80            .wrap_err("Error initializing SQLite client")?;
81
82        // Use Write-Ahead Logging (WAL) mode for better concurrency and performance.
83        client
84            .conn(|conn| {
85                Ok(conn
86                    .pragma_update(None, "journal_mode", "wal")
87                    .wrap_err("Error applying journal mode pragma")?)
88            })
89            .await?;
90
91        // Set synchronous mode to NORMAL. This means SQLite will still sync at critical moments, but less frequently
92        // than FULL, offering a good balance between safety and performance.
93        client
94            .conn(|conn| {
95                Ok(conn
96                    .pragma_update(None, "synchronous", "normal")
97                    .wrap_err("Error applying synchronous pragma")?)
98            })
99            .await?;
100
101        // Enforce foreign key constraints to maintain data integrity.
102        // This has a slight performance cost but is crucial for relational data.
103        client
104            .conn(|conn| {
105                Ok(conn
106                    .pragma_update(None, "foreign_keys", "on")
107                    .wrap_err("Error applying foreign keys pragma")?)
108            })
109            .await?;
110
111        // Store temp schema in memory
112        client
113            .conn(|conn| {
114                Ok(conn
115                    .pragma_update(None, "temp_store", "memory")
116                    .wrap_err("Error applying temp store pragma")?)
117            })
118            .await?;
119
120        // Apply all defined database migrations to bring the schema to the latest version.
121        // This is done atomically within a transaction.
122        client
123            .conn_mut(|conn| Ok(MIGRATIONS.to_latest(conn).wrap_err("Error applying migrations")?))
124            .await?;
125
126        // Add a regexp function to the client
127        client
128            .conn(|conn| {
129                Ok(conn
130                    .create_scalar_function(
131                        "regexp",
132                        2,
133                        FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
134                        |ctx| {
135                            assert_eq!(ctx.len(), 2, "regexp() called with unexpected number of arguments");
136
137                            let text = ctx
138                                .get_raw(1)
139                                .as_str_or_null()
140                                .map_err(|e| rusqlite::Error::UserFunctionError(e.into()))?;
141
142                            let Some(text) = text else {
143                                return Ok(false);
144                            };
145
146                            let cached_re: Arc<Regex> =
147                                ctx.get_or_create_aux(0, |vr| Ok::<_, BoxError>(Regex::new(vr.as_str()?)?))?;
148
149                            Ok(cached_re.is_match(text))
150                        },
151                    )
152                    .wrap_err("Error adding regexp function")?)
153            })
154            .await?;
155
156        // Add a cmd-to-regex function
157        client
158            .conn(|conn| {
159                Ok(conn
160                    .create_scalar_function(
161                        "cmd_to_regex",
162                        1,
163                        FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
164                        |ctx| {
165                            assert_eq!(
166                                ctx.len(),
167                                1,
168                                "cmd_to_regex() called with unexpected number of arguments"
169                            );
170                            let cmd_template = ctx.get::<String>(0)?;
171
172                            // Use the SplitCaptures iterator to process both unmatched literals and captured variables
173                            let regex_body = SplitCaptures::new(&COMMAND_VARIABLE_REGEX_QUOTES, &cmd_template)
174                                .filter_map(|item| match item {
175                                    // For unmatched parts, trim them and escape any special regex chars
176                                    SplitItem::Unmatched(s) => {
177                                        let trimmed = s.trim();
178                                        if trimmed.is_empty() {
179                                            None
180                                        } else {
181                                            Some(regex::escape(trimmed))
182                                        }
183                                    }
184                                    // For captured parts (the variables), replace them with a capture group
185                                    SplitItem::Captured(caps) => {
186                                        // Check which capture group matched to see if the placeholder was quoted
187                                        let placeholder_regex = if caps.get(1).is_some() {
188                                            // Group 1 matched '{{...}}', so expect a single-quoted argument
189                                            r"('[^']*')"
190                                        } else if caps.get(2).is_some() {
191                                            // Group 2 matched "{{...}}", so expect a double-quoted argument
192                                            r#"("[^"]*")"#
193                                        } else {
194                                            // Group 3 matched {{...}}, so expect a generic argument
195                                            r#"('[^']*'|"[^"]*"|\S+)"#
196                                        };
197                                        Some(String::from(placeholder_regex))
198                                    },
199                                })
200                                // Join them by any number of whitespaces
201                                .join(r"\s+");
202
203                            // Build the final regex
204                            Ok(format!("^{regex_body}$"))
205                        },
206                    )
207                    .wrap_err("Error adding cmd-to-regex function")?)
208            })
209            .await?;
210
211        Ok(client)
212    }
213
214    #[cfg(debug_assertions)]
215    pub async fn query(&self, sql: String) -> Result<String> {
216        self.client
217            .conn(move |conn| {
218                use prettytable::{Cell, Row, Table};
219                use rusqlite::types::Value;
220
221                let mut stmt = conn.prepare(&sql)?;
222                let column_names = stmt
223                    .column_names()
224                    .into_iter()
225                    .map(String::from)
226                    .collect::<Vec<String>>();
227                let columns_len = column_names.len();
228                let mut table = Table::new();
229                table.add_row(Row::from(column_names));
230                let rows = stmt.query_map([], |row| {
231                    let mut cells = Vec::new();
232                    for i in 0..columns_len {
233                        let value: Value = row.get(i)?;
234                        let cell_value = match value {
235                            Value::Null => "NULL".to_string(),
236                            Value::Integer(i) => i.to_string(),
237                            Value::Real(f) => f.to_string(),
238                            Value::Text(t) => t,
239                            Value::Blob(_) => "[BLOB]".to_string(),
240                        };
241                        cells.push(Cell::new(&cell_value));
242                    }
243                    Ok(Row::from(cells))
244                })?;
245                for row in rows {
246                    table.add_row(row?);
247                }
248                Ok(table.to_string())
249            })
250            .await
251    }
252}