1use core::slice;
2use std::{
3 env, fs,
4 io::{BufRead, BufReader, BufWriter, Write},
5 sync::Mutex,
6};
7
8use anyhow::{anyhow, Context, Result};
9use directories::ProjectDirs;
10use iter_flow::Iterflow;
11use itertools::Itertools;
12use once_cell::sync::Lazy;
13use regex::Regex;
14use rusqlite::{params_from_iter, Connection, Error, ErrorCode, OptionalExtension, Row};
15use rusqlite_migration::{Migrations, M};
16
17use crate::{
18 common::flatten_str,
19 model::{Command, LabelSuggestion},
20};
21
22static MIGRATIONS: Lazy<Migrations> = Lazy::new(|| {
24 Migrations::new(vec![
25 M::up(
26 r#"CREATE TABLE command (
27 category TEXT NOT NULL,
28 alias TEXT NULL,
29 cmd TEXT NOT NULL UNIQUE,
30 description TEXT NOT NULL,
31 usage INTEGER DEFAULT 0
32 );"#,
33 ),
34 M::up(r#"CREATE VIRTUAL TABLE command_fts USING fts5(flat_cmd, flat_description);"#),
35 M::up(
36 r#"CREATE TABLE label_suggestion (
37 flat_root_cmd TEXT NOT NULL,
38 flat_label TEXT NOT NULL,
39 suggestion TEXT NOT NULL,
40 usage INTEGER DEFAULT 0,
41 PRIMARY KEY (flat_root_cmd, flat_label, suggestion)
42 );"#,
43 ),
44 ])
45});
46
/// Category assigned to commands created by the user.
///
/// Blank or symbol-only searches list the commands of this category, and search results rank it
/// ahead of other categories on ties.
pub const USER_CATEGORY: &str = "user";
49
50static ALLOWED_FTS_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r#"[^a-zA-Z0-9 ]"#).unwrap());
52
53pub struct SqliteStorage {
55 conn: Mutex<Connection>,
56}
57
58impl SqliteStorage {
59 pub fn new() -> Result<Self> {
61 let path = env::var_os("INTELLI_HOME")
62 .map(Into::into)
63 .map(anyhow::Ok)
64 .unwrap_or_else(|| {
65 Ok(ProjectDirs::from("org", "IntelliShell", "Intelli-Shell")
66 .context("Error initializing project dir")?
67 .data_dir()
68 .to_path_buf())
69 })?;
70
71 fs::create_dir_all(&path).context("Could't create data dir")?;
72
73 Ok(Self {
74 conn: Mutex::new(
75 Self::initialize_connection(
76 Connection::open(path.join("storage.db3")).context("Error opening SQLite connection")?,
77 )
78 .context("Error initializing SQLite connection")?,
79 ),
80 })
81 }
82
83 pub fn new_in_memory() -> Result<Self> {
85 Ok(Self {
86 conn: Mutex::new(
87 Self::initialize_connection(Connection::open_in_memory()?)
88 .context("Error initializing SQLite connection")?,
89 ),
90 })
91 }
92
93 fn initialize_connection(mut conn: Connection) -> Result<Connection> {
95 conn.pragma_update(None, "journal_mode", "WAL")
97 .context("Error applying journal mode pragma")?;
98 conn.pragma_update(None, "synchronous", "normal")
100 .context("Error applying synchronous pragma")?;
101 conn.pragma_update(None, "foreign_keys", "on")
103 .context("Error applying foreign keys pragma")?;
104
105 MIGRATIONS.to_latest(&mut conn).context("Error applying migrations")?;
107
108 Ok(conn)
109 }
110
111 pub fn insert_command(&self, command: &mut Command) -> Result<bool> {
117 Ok(self.insert_commands(slice::from_mut(command))? == 1)
118 }
119
120 pub fn insert_commands(&self, commands: &mut [Command]) -> Result<u64> {
126 let mut res = 0;
127
128 let mut conn = self.conn.lock().expect("poisoned lock");
129 let tx = conn.transaction()?;
130
131 {
132 let mut stmt_cmd = tx.prepare(
133 r#"INSERT INTO command (category, alias, cmd, description) VALUES (?, ?, ?, ?)
134 ON CONFLICT(cmd) DO UPDATE SET description=excluded.description
135 RETURNING rowid"#,
136 )?;
137 let mut stmt_fts_check = tx.prepare("SELECT rowid FROM command_fts WHERE rowid = ?")?;
138 let mut stmt_fts_update = tx.prepare("UPDATE command_fts SET flat_description = ? WHERE rowid = ?")?;
139 let mut stmt_fts_insert =
140 tx.prepare("INSERT INTO command_fts (rowid, flat_cmd, flat_description) VALUES (?, ?, ?)")?;
141
142 for command in commands {
143 let row_id = stmt_cmd
144 .query_row(
145 (
146 &command.category,
147 command.alias.as_deref(),
148 &command.cmd,
149 &command.description,
150 ),
151 |r| r.get(0),
152 )
153 .context("Error inserting command")?;
154
155 command.id = row_id;
156
157 let current_row: Option<i32> = stmt_fts_check
158 .query_row([row_id], |r| r.get(0))
159 .optional()
160 .context("Error checking fts")?;
161
162 match current_row {
163 Some(_) => {
164 stmt_fts_update
165 .execute((flatten_str(&command.description), row_id))
166 .context("Error updating command fts")?;
167 }
168 None => {
169 res += 1;
170 stmt_fts_insert
171 .execute((row_id, flatten_str(&command.cmd), flatten_str(&command.description)))
172 .context("Error inserting command fts")?;
173 }
174 }
175 }
176 }
177
178 tx.commit()?;
179
180 Ok(res)
181 }
182
183 pub fn update_command(&self, command: &Command) -> Result<bool> {
187 let mut conn = self.conn.lock().expect("poisoned lock");
188 let tx = conn.transaction()?;
189
190 let updated = tx
191 .execute(
192 r#"UPDATE command SET alias = ?, cmd = ?, description = ?, usage = ? WHERE rowid = ?"#,
193 (
194 command.alias.as_deref(),
195 &command.cmd,
196 &command.description,
197 command.usage,
198 command.id,
199 ),
200 )
201 .context("Error updating command")?;
202
203 if updated == 1 {
204 let updated = tx
205 .execute(
206 r#"UPDATE command_fts SET flat_cmd = ?, flat_description = ? WHERE rowid = ?"#,
207 (flatten_str(&command.cmd), flatten_str(&command.description), command.id),
208 )
209 .context("Error updating command fts")?;
210 if updated == 1 {
211 tx.commit()?;
212 Ok(true)
213 } else {
214 Ok(false)
215 }
216 } else {
217 Ok(false)
218 }
219 }
220
221 pub fn increment_command_usage(&self, command_id: i64) -> Result<bool> {
225 let conn = self.conn.lock().expect("poisoned lock");
226 let updated = conn
227 .execute(r#"UPDATE command SET usage = usage + 1 WHERE rowid = ?"#, [command_id])
228 .context("Error updating command usage")?;
229
230 Ok(updated == 1)
231 }
232
233 pub fn delete_command(&self, command_id: i64) -> Result<bool> {
237 let mut conn = self.conn.lock().expect("poisoned lock");
238 let tx = conn.transaction()?;
239
240 let deleted = tx
241 .execute(r#"DELETE FROM command WHERE rowid = ?"#, [command_id])
242 .context("Error deleting command")?;
243
244 if deleted == 1 {
245 let deleted = tx
246 .execute(r#"DELETE FROM command_fts WHERE rowid = ?"#, [command_id])
247 .context("Error deleting command fts")?;
248 if deleted == 1 {
249 tx.commit()?;
250 Ok(true)
251 } else {
252 Ok(false)
253 }
254 } else {
255 Ok(false)
256 }
257 }
258
259 pub fn get_commands(&self, category: impl AsRef<str>) -> Result<Vec<Command>> {
261 let category = category.as_ref();
262
263 let conn = self.conn.lock().expect("poisoned lock");
264 let mut stmt = conn.prepare(
265 r#"SELECT rowid, category, alias, cmd, description, usage
266 FROM command
267 WHERE category = ?
268 ORDER BY usage DESC"#,
269 )?;
270
271 let commands = stmt
272 .query([category])?
273 .mapped(command_from_row)
274 .finish_vec()
275 .context("Error querying commands")?;
276
277 Ok(commands)
278 }
279
280 pub fn find_commands(&self, search: impl AsRef<str>) -> Result<Vec<Command>> {
282 let search = search.as_ref().trim();
283 if search.is_empty() {
284 return self.get_commands(USER_CATEGORY);
285 }
286 let flat_search = flatten_str(search);
287
288 let conn = self.conn.lock().expect("poisoned lock");
289 let alias_cmd = conn
290 .query_row(
291 r#"SELECT rowid, category, alias, cmd, description, usage
292 FROM command
293 WHERE alias = :flat_search OR alias = :search"#,
294 &[(":flat_search", flat_search.as_str()), (":search", search)],
295 command_from_row,
296 )
297 .optional()
298 .context("Error querying command by alias")?;
299 if let Some(cmd) = alias_cmd {
300 return Ok(vec![cmd]);
301 }
302
303 let hashtags = flat_search
304 .split_whitespace()
305 .filter(|t| t.starts_with('#'))
306 .collect_vec();
307
308 let flat_fts_search = ALLOWED_FTS_REGEX.replace_all(&flat_search, "");
309 let flat_fts_search = flat_fts_search.trim();
310 if flat_fts_search.is_empty() || flat_fts_search == " " {
311 drop(conn);
312 return self.get_commands(USER_CATEGORY);
313 }
314
315 let mut stmt = conn.prepare(
316 r#"
317 SELECT DISTINCT rowid, category, alias, cmd, description, usage
318 FROM (
319 SELECT c.rowid, c.category, c.alias, c.cmd, c.description, c.usage, 2 as ord
320 FROM command_fts s
321 JOIN command c ON s.rowid = c.rowid
322 WHERE command_fts MATCH :match_cmd_ordered
323
324 UNION ALL
325
326 SELECT c.rowid, c.category, c.alias, c.cmd, c.description, c.usage, 1 as ord
327 FROM command_fts s
328 JOIN command c ON s.rowid = c.rowid
329 WHERE command_fts MATCH :match_simple
330
331 UNION ALL
332
333 SELECT c.rowid, c.category, c.alias, c.cmd, c.description, c.usage, 0 as ord
334 FROM command_fts s
335 JOIN command c ON s.rowid = c.rowid
336 WHERE s.flat_cmd GLOB :glob OR s.flat_description GLOB :glob
337 )
338 ORDER BY ord DESC, usage DESC, (CASE WHEN category = 'user' THEN 1 ELSE 0 END) DESC
339 "#,
340 )?;
341
342 let match_cmd_ordered = format!(
343 "\"flat_cmd\" : ^{}",
344 flat_fts_search
345 .split_whitespace()
346 .map(|token| format!("{token}*"))
347 .join(" + ")
348 );
349 let match_simple = flat_fts_search
350 .split_whitespace()
351 .map(|token| format!("{token}*"))
352 .join(" ");
353 let glob = flat_search
354 .split_whitespace()
355 .map(|token| format!("*{token}*"))
356 .join(" ");
357
358 let commands = stmt
359 .query(&[
360 (":match_cmd_ordered", &match_cmd_ordered),
361 (":match_simple", &match_simple),
362 (":glob", &glob),
363 ])?
364 .mapped(command_from_row)
365 .filter(|r| {
366 if !hashtags.is_empty() {
367 if let Ok(command) = r {
368 for tag in &hashtags {
369 if !command.description.contains(tag) {
370 return false;
371 }
372 }
373 }
374 }
375 true
376 })
377 .finish_vec()
378 .context("Error querying fts command")?;
379
380 Ok(commands)
381 }
382
383 pub fn export(&self, category: impl AsRef<str>, file_path: impl Into<String>) -> Result<usize> {
389 let category = category.as_ref();
390 let file_path = file_path.into();
391 let commands = self.get_commands(category)?;
392 let size = commands.len();
393 let file = fs::File::create(&file_path).context("Error creating output file")?;
394 let mut w = BufWriter::new(file);
395 for command in commands {
396 writeln!(w, "{} ## {}", command.cmd, command.description).context("Error writing file")?;
397 }
398 w.flush().context("Error writing file")?;
399 Ok(size)
400 }
401
402 pub fn import(&self, category: impl AsRef<str>, file_path: String) -> Result<u64> {
408 let category = category.as_ref();
409 let file = fs::File::open(file_path).context("Error opening file")?;
410 let r = BufReader::new(file);
411 let mut commands = r
412 .lines()
413 .map_err(anyhow::Error::from)
414 .filter_ok(|line| !line.is_empty() && !line.starts_with('#'))
415 .and_then(|line| {
416 let (cmd, description) = line
417 .split_once(" ## ")
418 .ok_or_else(|| anyhow!("Unexpected file format"))?;
419 Ok::<_, anyhow::Error>(Command::new(category, cmd, description))
420 })
421 .finish_vec()?;
422
423 let new = self.insert_commands(&mut commands)?;
424
425 Ok(new)
426 }
427
428 pub fn is_empty(&self) -> Result<bool> {
430 Ok(self.len()? == 0)
431 }
432
433 pub fn len(&self) -> Result<u64> {
435 let conn = self.conn.lock().expect("poisoned lock");
436 let mut stmt = conn.prepare(r#"SELECT COUNT(*) FROM command"#)?;
437 Ok(stmt.query_row([], |r| r.get(0))?)
438 }
439
440 pub fn insert_label_suggestion(&self, suggestion: &LabelSuggestion) -> Result<bool> {
444 if suggestion.flat_label == suggestion.suggestion {
445 return Ok(false);
446 }
447
448 let conn = self.conn.lock().expect("poisoned lock");
449 let inserted = match conn.execute(
450 r#"INSERT INTO label_suggestion (flat_root_cmd, flat_label, suggestion, usage) VALUES (?, ?, ?, ?)"#,
451 (
452 &suggestion.flat_root_cmd,
453 &suggestion.flat_label,
454 &suggestion.suggestion,
455 suggestion.usage,
456 ),
457 ) {
458 Ok(i) => i,
459 Err(Error::SqliteFailure(err, msg)) => match err.code {
460 ErrorCode::ConstraintViolation => return Ok(false),
461 _ => {
462 return Err(
463 anyhow::Error::new(Error::SqliteFailure(err, msg)).context("Error inserting label suggestion")
464 );
465 }
466 },
467 Err(err) => {
468 return Err(anyhow::Error::new(err).context("Error inserting label suggestion"));
469 }
470 };
471
472 Ok(inserted == 1)
473 }
474
475 pub fn update_label_suggestion(
479 &self,
480 suggestion: &mut LabelSuggestion,
481 new_suggestion: impl Into<String>,
482 ) -> Result<bool> {
483 let conn = self.conn.lock().expect("poisoned lock");
484 let new_suggestion = new_suggestion.into();
485 let updated = conn
486 .execute(
487 r#"UPDATE label_suggestion SET suggestion = ? WHERE flat_root_cmd = ? AND flat_label = ? AND suggestion = ?"#,
488 (
489 &new_suggestion,
490 &suggestion.flat_root_cmd,
491 &suggestion.flat_label,
492 &suggestion.suggestion,
493 ),
494 )
495 .context("Error updating label suggestion")?;
496
497 let updated = updated == 1;
498
499 if updated {
500 suggestion.suggestion = new_suggestion;
501 }
502
503 Ok(updated)
504 }
505
506 pub fn update_label_suggestion_usage(&self, suggestion: &LabelSuggestion) -> Result<bool> {
510 let conn = self.conn.lock().expect("poisoned lock");
511 let updated = conn
512 .execute(
513 r#"UPDATE label_suggestion SET usage = ? WHERE flat_root_cmd = ? AND flat_label = ? AND suggestion = ?"#,
514 (
515 suggestion.usage,
516 &suggestion.flat_root_cmd,
517 &suggestion.flat_label,
518 &suggestion.suggestion,
519 ),
520 )
521 .context("Error updating label suggestion usage")?;
522
523 Ok(updated == 1)
524 }
525
526 pub fn delete_label_suggestion(&self, suggestion: &LabelSuggestion) -> Result<bool> {
530 let conn = self.conn.lock().expect("poisoned lock");
531 let deleted = conn
532 .execute(
533 r#"DELETE FROM label_suggestion WHERE flat_root_cmd = ? AND flat_label = ? AND suggestion = ?"#,
534 (
535 &suggestion.flat_root_cmd,
536 &suggestion.flat_label,
537 &suggestion.suggestion,
538 ),
539 )
540 .context("Error deleting label suggestion")?;
541
542 Ok(deleted == 1)
543 }
544
545 pub fn find_suggestions_for(
547 &self,
548 root_cmd: impl AsRef<str>,
549 label: impl AsRef<str>,
550 ) -> Result<Vec<LabelSuggestion>> {
551 let flat_root_cmd = flatten_str(root_cmd.as_ref());
552 let label = label.as_ref();
553 let mut parameters = label.split('|').map(flatten_str).collect_vec();
554 parameters.insert(0, flatten_str(label));
555
556 const QUERY: &str = r#"
557 SELECT * FROM (
558 SELECT
559 s.flat_root_cmd,
560 s.flat_label,
561 s.suggestion,
562 s.usage,
563 q.sum_usage,
564 RANK () OVER (
565 PARTITION BY s.suggestion
566 ORDER BY LENGTH(s.flat_label) DESC
567 ) rank
568 FROM label_suggestion s
569 JOIN (
570 SELECT flat_root_cmd, suggestion, SUM(usage) as sum_usage
571 FROM label_suggestion
572 WHERE flat_root_cmd = ?1 AND flat_label IN (#LABELS#)
573 GROUP BY flat_root_cmd, suggestion
574 ) q ON s.flat_root_cmd = q.flat_root_cmd AND s.suggestion = q.suggestion
575 )
576 WHERE rank = 1
577 ORDER BY
578 sum_usage DESC,
579 (CASE WHEN flat_label = ?2 THEN 1 ELSE 0 END) DESC
580 "#;
581
582 let conn = self.conn.lock().expect("poisoned lock");
583 let mut stmt = conn.prepare(
584 &QUERY.replace(
585 "#LABELS#",
586 ¶meters
587 .iter()
588 .enumerate()
589 .map(|(i, _)| format!("?{}", i + 2))
590 .join(","),
591 ),
592 )?;
593
594 parameters.insert(0, flat_root_cmd);
595
596 let suggestions = stmt
597 .query(params_from_iter(parameters.iter()))?
598 .mapped(label_suggestion_from_row)
599 .finish_vec()
600 .context("Error querying label suggestions")?;
601
602 Ok(suggestions)
603 }
604}
605
606fn command_from_row(row: &Row<'_>) -> rusqlite::Result<Command> {
608 Ok(Command {
609 id: row.get(0)?,
610 category: row.get(1)?,
611 alias: row.get(2)?,
612 cmd: row.get(3)?,
613 description: row.get(4)?,
614 usage: row.get(5)?,
615 })
616}
617
618fn label_suggestion_from_row(row: &Row<'_>) -> rusqlite::Result<LabelSuggestion> {
620 Ok(LabelSuggestion {
621 flat_root_cmd: row.get(0)?,
622 flat_label: row.get(1)?,
623 suggestion: row.get(2)?,
624 usage: row.get(3)?,
625 })
626}
627
628impl Drop for SqliteStorage {
629 fn drop(&mut self) {
630 let conn = self.conn.lock().expect("poisoned lock");
631 conn.pragma_update(None, "analysis_limit", "400")
633 .expect("Failed analysis_limit PRAGMA");
634 conn.execute_batch("PRAGMA optimize;").expect("Failed optimize PRAGMA");
636 }
637}
638
#[cfg(test)]
mod tests {
    use super::MIGRATIONS;

    /// The migration list must pass `rusqlite_migration`'s own validation.
    #[test]
    fn migrations_test() {
        let outcome = MIGRATIONS.validate();
        assert!(outcome.is_ok(), "invalid migrations: {outcome:?}");
    }
}