1use std::{cmp::Ordering, sync::atomic::Ordering as AtomicOrdering};
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{Row, fallible_iterator::FallibleIterator, ffi, types::Type};
5use sea_query::SqliteQueryBuilder;
6use sea_query_rusqlite::RusqliteBinder;
7use tracing::instrument;
8use uuid::Uuid;
9
10use super::{SqliteStorage, queries::*};
11use crate::{
12 config::SearchCommandTuning,
13 errors::{Result, UserFacingError},
14 model::{Command, SOURCE_TLDR, SearchCommandsFilter},
15};
16
17impl SqliteStorage {
18 #[instrument(skip_all)]
21 pub async fn setup_workspace_storage(&self) -> Result<()> {
22 tracing::trace!("Creating workspace-specific tables");
23 self.client
24 .conn_mut(|conn| {
25 let schemas: Vec<String> = conn
27 .prepare(
28 r"SELECT sql
29 FROM sqlite_master
30 WHERE (type = 'table' AND name = 'variable_completion')
31 OR (type = 'table' AND name = 'command')
32 OR (type = 'table' AND name LIKE 'command_%fts')
33 OR (type = 'trigger' AND name LIKE 'command_%_fts' AND tbl_name = 'command')",
34 )?
35 .query_map([], |row| row.get(0))?
36 .collect::<Result<Vec<String>, _>>()?;
37
38 let tx = conn.transaction()?;
39
40 for schema in schemas {
42 let temp_schema = schema
43 .replace("variable_completion", "workspace_variable_completion")
44 .replace("command", "workspace_command")
45 .replace("CREATE TABLE ", "CREATE TEMP TABLE ")
46 .replace("CREATE VIRTUAL TABLE ", "CREATE VIRTUAL TABLE temp.")
47 .replace("CREATE TRIGGER ", "CREATE TEMP TRIGGER ");
48 tracing::trace!("Executing:\n{temp_schema}");
49 tx.execute(&temp_schema, [])?;
50 }
51
52 tx.commit()?;
53 Ok(())
54 })
55 .await?;
56
57 self.workspace_tables_loaded.store(true, AtomicOrdering::SeqCst);
58
59 Ok(())
60 }
61
62 #[instrument(skip_all)]
64 pub async fn is_empty(&self) -> Result<bool> {
65 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
66 self.client
67 .conn(move |conn| {
68 let query = if workspace_tables_loaded {
69 "SELECT NOT EXISTS (SELECT 1 FROM command UNION ALL SELECT 1 FROM workspace_command)"
70 } else {
71 "SELECT NOT EXISTS(SELECT 1 FROM command)"
72 };
73 tracing::trace!("Checking if storage is empty:\n{query}");
74 Ok(conn.query_row(query, [], |r| r.get(0))?)
75 })
76 .await
77 }
78
79 #[instrument(skip_all)]
81 pub async fn find_tags(
82 &self,
83 filter: SearchCommandsFilter,
84 tag_prefix: Option<String>,
85 tuning: &SearchCommandTuning,
86 ) -> Result<Vec<(String, u64, bool)>> {
87 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
88 let query = query_find_tags(filter, tag_prefix, tuning, workspace_tables_loaded)?;
89 if tracing::enabled!(tracing::Level::TRACE) {
90 tracing::trace!("Querying tags:\n{}", query.to_string(SqliteQueryBuilder));
91 }
92 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
93 self.client
94 .conn(move |conn| {
95 conn.prepare(&stmt)?
96 .query(&*values.as_params())?
97 .and_then(|r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))
98 .collect()
99 })
100 .await
101 }
102
103 #[instrument(skip_all)]
108 pub async fn find_commands(
109 &self,
110 filter: SearchCommandsFilter,
111 working_path: impl Into<String>,
112 tuning: &SearchCommandTuning,
113 ) -> Result<(Vec<Command>, bool)> {
114 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
115 let cleaned_filter = filter.cleaned();
116
117 let mut query_alias = None;
119 if let Some(ref term) = cleaned_filter.search_term {
120 if workspace_tables_loaded {
122 query_alias = Some((
123 format!(
124 r#"SELECT *
125 FROM (
126 SELECT rowid, * FROM workspace_command
127 UNION ALL
128 SELECT rowid, * FROM command
129 ) c
130 WHERE c.alias IS NOT NULL AND c.alias = ?1
131 LIMIT {QUERY_LIMIT}"#
132 ),
133 (term.clone(),),
134 ));
135 } else {
136 query_alias = Some((
137 format!(
138 r#"SELECT c.rowid, c.*
139 FROM command c
140 WHERE c.alias IS NOT NULL AND c.alias = ?1
141 LIMIT {QUERY_LIMIT}"#
142 ),
143 (term.clone(),),
144 ));
145 }
146 }
147
148 let query = query_find_commands(cleaned_filter, working_path, tuning, workspace_tables_loaded)?;
150 let query_trace = if tracing::enabled!(tracing::Level::TRACE) {
151 query.to_string(SqliteQueryBuilder)
152 } else {
153 String::default()
154 };
155 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
156
157 let tuning = *tuning;
159 self.client
160 .conn(move |conn| {
161 if let Some((query_alias, a_params)) = query_alias {
163 tracing::trace!("Querying aliased commands:\n{query_alias}");
164 let rows = conn
166 .prepare(&query_alias)?
167 .query(a_params)?
168 .map(|r| Command::try_from(r))
169 .collect::<Vec<_>>()?;
170 if !rows.is_empty() {
172 return Ok((rows, true));
173 }
174 }
175 if tracing::enabled!(tracing::Level::TRACE) {
177 tracing::trace!("Querying commands:\n{query_trace}");
178 }
179 Ok((
180 rerank_query_results(
181 conn.prepare(&stmt)?
182 .query(&*values.as_params())?
183 .and_then(|r| QueryResultItem::try_from(r))
184 .collect::<Result<Vec<_>, _>>()?,
185 &tuning,
186 ),
187 false,
188 ))
189 })
190 .await
191 }
192
193 #[instrument(skip_all)]
195 pub async fn delete_tldr_commands(&self, category: Option<String>) -> Result<u64> {
196 self.client
197 .conn_mut(move |conn| {
198 let mut query = String::from("DELETE FROM command WHERE source = ?1");
199 let mut params: Vec<String> = vec![SOURCE_TLDR.to_owned()];
200 if let Some(cat) = category {
201 query.push_str(" AND category = ?2");
202 params.push(cat);
203 }
204 tracing::trace!("Deleting tldr commands:\n{query}");
205 let affected = conn.execute(&query, rusqlite::params_from_iter(params))?;
206 Ok(affected as u64)
207 })
208 .await
209 }
210
211 #[instrument(skip_all)]
215 pub async fn insert_command(&self, command: Command) -> Result<Command> {
216 self.client
217 .conn(move |conn| {
218 let query = r#"INSERT INTO command (
219 id,
220 category,
221 source,
222 alias,
223 cmd,
224 flat_cmd,
225 description,
226 flat_description,
227 tags,
228 created_at,
229 updated_at
230 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)"#;
231 tracing::trace!("Inserting a command:\n{query}");
232 let res = conn.execute(
233 query,
234 (
235 &command.id,
236 &command.category,
237 &command.source,
238 &command.alias,
239 &command.cmd,
240 &command.flat_cmd,
241 &command.description,
242 &command.flat_description,
243 serde_json::to_value(&command.tags)?,
244 &command.created_at,
245 &command.updated_at,
246 ),
247 );
248 match res {
249 Ok(_) => Ok(command),
250 Err(err) => {
251 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
252 if code == ffi::SQLITE_CONSTRAINT_UNIQUE || code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY {
253 Err(UserFacingError::CommandAlreadyExists.into())
254 } else {
255 Err(Report::from(err).into())
256 }
257 }
258 }
259 })
260 .await
261 }
262
263 #[instrument(skip_all)]
267 pub async fn update_command(&self, command: Command) -> Result<Command> {
268 self.client
269 .conn(move |conn| {
270 let query = r#"UPDATE command SET
271 category = ?2,
272 source = ?3,
273 alias = ?4,
274 cmd = ?5,
275 flat_cmd = ?6,
276 description = ?7,
277 flat_description = ?8,
278 tags = ?9,
279 created_at = ?10,
280 updated_at = ?11
281 WHERE id = ?1"#;
282 tracing::trace!("Updating a command:\n{query}");
283 let res = conn.execute(
284 query,
285 (
286 &command.id,
287 &command.category,
288 &command.source,
289 &command.alias,
290 &command.cmd,
291 &command.flat_cmd,
292 &command.description,
293 &command.flat_description,
294 serde_json::to_value(&command.tags)?,
295 &command.created_at,
296 &command.updated_at,
297 ),
298 );
299 match res {
300 Ok(0) => Err(eyre!("Command not found: {}", command.id).into()),
301 Ok(_) => Ok(command),
302 Err(err) => {
303 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
304 if code == ffi::SQLITE_CONSTRAINT_UNIQUE {
305 Err(UserFacingError::CommandAlreadyExists.into())
306 } else {
307 Err(Report::from(err).into())
308 }
309 }
310 }
311 })
312 .await
313 }
314
315 #[instrument(skip_all)]
317 pub async fn increment_command_usage(
318 &self,
319 command_id: Uuid,
320 path: impl AsRef<str> + Send + 'static,
321 ) -> Result<i32> {
322 self.client
323 .conn_mut(move |conn| {
324 let query = r#"
325 INSERT INTO command_usage (command_id, path, usage_count)
326 VALUES (?1, ?2, 1)
327 ON CONFLICT(command_id, path) DO UPDATE SET
328 usage_count = usage_count + 1
329 RETURNING usage_count;"#;
330 tracing::trace!("Incrementing command usage:\n{query}");
331 Ok(conn.query_row(query, (&command_id, &path.as_ref()), |r| r.get(0))?)
332 })
333 .await
334 }
335
336 #[instrument(skip_all)]
340 pub async fn delete_command(&self, command_id: Uuid) -> Result<()> {
341 self.client
342 .conn(move |conn| {
343 let query = "DELETE FROM command WHERE id = ?1";
344 tracing::trace!("Deleting command:\n{query}");
345 let res = conn.execute(query, (&command_id,));
346 match res {
347 Ok(0) => Err(eyre!("Command not found: {command_id}").into()),
348 Ok(_) => Ok(()),
349 Err(err) => Err(Report::from(err).into()),
350 }
351 })
352 .await
353 }
354}
355
356fn rerank_query_results(items: Vec<QueryResultItem>, tuning: &SearchCommandTuning) -> Vec<Command> {
365 if items.is_empty() {
367 return Vec::new();
368 }
369 if items.len() == 1 {
370 return items.into_iter().map(|item| item.command).collect();
371 }
372
373 let (template_matches, mut other_items): (Vec<_>, Vec<_>) = items
376 .into_iter()
377 .partition(|item| item.text_score >= TEMPLATE_MATCH_RANK);
378 if !template_matches.is_empty() {
379 tracing::trace!("Found {} template matches", template_matches.len());
380 }
381
382 let mut final_commands: Vec<Command> = template_matches.into_iter().map(|item| item.command).collect();
384
385 if other_items.len() <= 1 {
387 final_commands.extend(other_items.into_iter().map(|item| item.command));
388 return final_commands;
389 }
390
391 let mut min_text = 0f64;
394 let mut min_path = 0f64;
395 let mut min_usage = 0f64;
396 let mut max_text = f64::NEG_INFINITY;
397 let mut max_path = f64::NEG_INFINITY;
398 let mut max_usage = f64::NEG_INFINITY;
399 for item in &other_items {
400 min_text = min_text.min(item.text_score);
401 max_text = max_text.max(item.text_score);
402 min_path = min_path.min(item.path_score);
403 max_path = max_path.max(item.path_score);
404 min_usage = min_usage.min(item.usage_score);
405 max_usage = max_usage.max(item.usage_score);
406 }
407
408 let range_text = (max_text > min_text).then_some(max_text - min_text);
410 let range_path = (max_path > min_path).then_some(max_path - min_path);
411 let range_usage = (max_usage > min_usage).then_some(max_usage - min_usage);
412
413 other_items.sort_by(|a, b| {
415 match b.is_workspace_command.cmp(&a.is_workspace_command) {
417 Ordering::Equal => {
418 let calculate_score = |item: &QueryResultItem| -> f64 {
420 let norm_text = range_text.map_or(0.5, |range| (item.text_score - min_text) / range);
423 let norm_path = range_path.map_or(0.5, |range| (item.path_score - min_path) / range);
424 let norm_usage = range_usage.map_or(0.5, |range| (item.usage_score - min_usage) / range);
425
426 (norm_text * tuning.text.points as f64)
428 + (norm_path * tuning.path.points as f64)
429 + (norm_usage * tuning.usage.points as f64)
430 };
431
432 let final_score_a = calculate_score(a);
433 let final_score_b = calculate_score(b);
434
435 final_score_b.partial_cmp(&final_score_a).unwrap_or(Ordering::Equal)
437 }
438 other => other,
440 }
441 });
442
443 final_commands.extend(other_items.into_iter().map(|item| item.command));
445 final_commands
446}
447
448impl<'a> TryFrom<&'a Row<'a>> for Command {
449 type Error = rusqlite::Error;
450
451 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
452 Ok(Self {
453 id: row.get(1)?,
455 category: row.get(2)?,
456 source: row.get(3)?,
457 alias: row.get(4)?,
458 cmd: row.get(5)?,
459 flat_cmd: row.get(6)?,
460 description: row.get(7)?,
461 flat_description: row.get(8)?,
462 tags: serde_json::from_value(row.get::<_, serde_json::Value>(9)?)
463 .map_err(|e| rusqlite::Error::FromSqlConversionFailure(9, Type::Text, Box::new(e)))?,
464 created_at: row.get(10)?,
465 updated_at: row.get(11)?,
466 })
467 }
468}
469
470struct QueryResultItem {
472 command: Command,
474 is_workspace_command: bool,
476 usage_score: f64,
478 path_score: f64,
480 text_score: f64,
482}
483
484impl<'a> TryFrom<&'a Row<'a>> for QueryResultItem {
485 type Error = rusqlite::Error;
486
487 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
488 Ok(Self {
489 command: Command::try_from(row)?,
490 is_workspace_command: row.get(12)?,
491 usage_score: row.get(13)?,
492 path_score: row.get(14)?,
493 text_score: row.get(15)?,
494 })
495 }
496}
497
498#[cfg(test)]
499mod tests {
500 use pretty_assertions::assert_eq;
501 use strum::IntoEnumIterator;
502 use tokio_stream::iter;
503 use uuid::Uuid;
504
505 use super::*;
506 use crate::{
507 errors::AppError,
508 model::{CATEGORY_USER, ImportExportItem, SOURCE_IMPORT, SOURCE_USER, SearchMode},
509 };
510
511 const PROJ_A_PATH: &str = "/home/user/project-a";
512 const PROJ_A_API_PATH: &str = "/home/user/project-a/api";
513 const PROJ_B_PATH: &str = "/home/user/project-b";
514 const UNRELATED_PATH: &str = "/var/log";
515
516 #[tokio::test]
517 async fn test_setup_workspace_storage() {
518 let storage = SqliteStorage::new_in_memory().await.unwrap();
519 storage.check_sqlite_version().await;
520 let res = storage.setup_workspace_storage().await;
521 assert!(res.is_ok(), "Expected workspace storage setup to succeed: {res:?}");
522 }
523
524 #[tokio::test]
525 async fn test_is_empty() {
526 let storage = SqliteStorage::new_in_memory().await.unwrap();
527 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
528
529 let cmd = Command {
530 id: Uuid::now_v7(),
531 cmd: "test_cmd".to_string(),
532 ..Default::default()
533 };
534 storage.insert_command(cmd).await.unwrap();
535
536 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
537 }
538
539 #[tokio::test]
540 async fn test_is_empty_with_workspace() {
541 let storage = SqliteStorage::new_in_memory().await.unwrap();
542 storage.setup_workspace_storage().await.unwrap();
543 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
544
545 let cmd = Command {
546 id: Uuid::now_v7(),
547 cmd: "test_cmd".to_string(),
548 ..Default::default()
549 };
550 storage.insert_command(cmd).await.unwrap();
551
552 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
553 }
554
555 #[tokio::test]
556 async fn test_find_tags_no_filters() -> Result<()> {
557 let storage = setup_ranking_storage().await;
558
559 let result = storage
560 .find_tags(SearchCommandsFilter::default(), None, &SearchCommandTuning::default())
561 .await?;
562
563 let expected = vec![
564 ("#git".to_string(), 5, false),
565 ("#build".to_string(), 2, false),
566 ("#commit".to_string(), 2, false),
567 ("#docker".to_string(), 2, false),
568 ("#list".to_string(), 2, false),
569 ("#k8s".to_string(), 1, false),
570 ("#npm".to_string(), 1, false),
571 ("#pod".to_string(), 1, false),
572 ("#push".to_string(), 1, false),
573 ("#unix".to_string(), 1, false),
574 ];
575
576 assert_eq!(result.len(), 10, "Expected 10 unique tags");
577 assert_eq!(result, expected, "Tags list or order mismatch");
578
579 Ok(())
580 }
581
582 #[tokio::test]
583 async fn test_find_tags_filter_by_tags_only() -> Result<()> {
584 let storage = setup_ranking_storage().await;
585
586 let filter1 = SearchCommandsFilter {
587 tags: Some(vec!["#git".to_string()]),
588 ..Default::default()
589 };
590 let result1 = storage
591 .find_tags(filter1, None, &SearchCommandTuning::default())
592 .await?;
593 let expected1 = vec![("#commit".to_string(), 2, false), ("#push".to_string(), 1, false)];
594 assert_eq!(result1.len(), 2,);
595 assert_eq!(result1, expected1);
596
597 let filter2 = SearchCommandsFilter {
598 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
599 ..Default::default()
600 };
601 let result2 = storage
602 .find_tags(filter2, None, &SearchCommandTuning::default())
603 .await?;
604 assert!(result2.is_empty());
605
606 let filter3 = SearchCommandsFilter {
607 tags: Some(vec!["#list".to_string()]),
608 ..Default::default()
609 };
610 let result3 = storage
611 .find_tags(filter3, None, &SearchCommandTuning::default())
612 .await?;
613 let expected3 = vec![("#docker".to_string(), 1, false), ("#unix".to_string(), 1, false)];
614 assert_eq!(result3.len(), 2);
615 assert_eq!(result3, expected3);
616
617 Ok(())
618 }
619
620 #[tokio::test]
621 async fn test_find_tags_filter_by_prefix_only() -> Result<()> {
622 let storage = setup_ranking_storage().await;
623
624 let result = storage
625 .find_tags(
626 SearchCommandsFilter::default(),
627 Some("#comm".to_string()),
628 &SearchCommandTuning::default(),
629 )
630 .await?;
631 let expected = vec![("#commit".to_string(), 2, false)];
632 assert_eq!(result.len(), 1);
633 assert_eq!(result, expected);
634
635 Ok(())
636 }
637
638 #[tokio::test]
639 async fn test_find_tags_filter_by_tags_and_prefix() -> Result<()> {
640 let storage = setup_ranking_storage().await;
641
642 let filter1 = SearchCommandsFilter {
643 tags: Some(vec!["#git".to_string()]),
644 ..Default::default()
645 };
646 let result1 = storage
647 .find_tags(filter1, Some("#comm".to_string()), &SearchCommandTuning::default())
648 .await?;
649 let expected1 = vec![("#commit".to_string(), 2, false)];
650 assert_eq!(result1.len(), 1);
651 assert_eq!(result1, expected1);
652
653 let filter2 = SearchCommandsFilter {
654 tags: Some(vec!["#git".to_string()]),
655 ..Default::default()
656 };
657 let result2 = storage
658 .find_tags(filter2, Some("#push".to_string()), &SearchCommandTuning::default())
659 .await?;
660 let expected2 = vec![("#push".to_string(), 1, true)];
661 assert_eq!(result2.len(), 1);
662 assert_eq!(result2, expected2);
663
664 Ok(())
665 }
666
667 #[tokio::test]
668 async fn test_find_commands_no_filter() {
669 let storage = setup_ranking_storage().await;
670 let filter = SearchCommandsFilter::default();
671 let (commands, _) = storage
672 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
673 .await
674 .unwrap();
675 assert_eq!(commands.len(), 10, "Expected all sample commands");
676 }
677
678 #[tokio::test]
679 async fn test_find_commands_filter_by_category() {
680 let storage = setup_ranking_storage().await;
681 let filter = SearchCommandsFilter {
682 category: Some(vec!["git".to_string()]),
683 ..Default::default()
684 };
685 let (commands, _) = storage
686 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
687 .await
688 .unwrap();
689 assert_eq!(commands.len(), 2);
690 assert!(commands.iter().all(|c| c.category == "git"));
691
692 let filter_no_match = SearchCommandsFilter {
693 category: Some(vec!["nonexistent".to_string()]),
694 ..Default::default()
695 };
696 let (commands_no_match, _) = storage
697 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
698 .await
699 .unwrap();
700 assert!(commands_no_match.is_empty());
701 }
702
703 #[tokio::test]
704 async fn test_find_commands_filter_by_source() {
705 let storage = setup_ranking_storage().await;
706 let filter = SearchCommandsFilter {
707 source: Some(SOURCE_TLDR.to_string()),
708 ..Default::default()
709 };
710 let (commands, _) = storage
711 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
712 .await
713 .unwrap();
714 assert_eq!(commands.len(), 3);
715 assert!(commands.iter().all(|c| c.source == SOURCE_TLDR));
716 }
717
718 #[tokio::test]
719 async fn test_find_commands_filter_by_tags() {
720 let storage = setup_ranking_storage().await;
721 let filter_single_tag = SearchCommandsFilter {
722 tags: Some(vec!["#git".to_string()]),
723 ..Default::default()
724 };
725 let (commands_single_tag, _) = storage
726 .find_commands(filter_single_tag, "/some/path", &SearchCommandTuning::default())
727 .await
728 .unwrap();
729 assert_eq!(commands_single_tag.len(), 5);
730
731 let filter_multiple_tags = SearchCommandsFilter {
732 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
733 ..Default::default()
734 };
735 let (commands_multiple_tags, _) = storage
736 .find_commands(filter_multiple_tags, "/some/path", &SearchCommandTuning::default())
737 .await
738 .unwrap();
739 assert_eq!(commands_multiple_tags.len(), 1);
740
741 let filter_empty_tags = SearchCommandsFilter {
742 tags: Some(vec![]),
743 ..Default::default()
744 };
745 let (commands_empty_tags, _) = storage
746 .find_commands(filter_empty_tags, "/some/path", &SearchCommandTuning::default())
747 .await
748 .unwrap();
749 assert_eq!(commands_empty_tags.len(), 10);
750 }
751
752 #[tokio::test]
753 async fn test_find_commands_alias_precedence() {
754 let storage = setup_ranking_storage().await;
755 storage
756 .setup_command(
757 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
758 [("/some/path", 100)],
759 )
760 .await;
761
762 for mode in SearchMode::iter() {
763 let filter = SearchCommandsFilter {
764 search_term: Some("gc".to_string()),
765 search_mode: mode,
766 ..Default::default()
767 };
768 let (commands, alias_match) = storage
769 .find_commands(filter, "", &SearchCommandTuning::default())
770 .await
771 .unwrap();
772 assert!(alias_match, "Expected alias match for mode {mode:?}");
773 assert_eq!(commands.len(), 1, "Expected only alias match for mode {mode:?}");
774 assert_eq!(
775 commands[0].cmd, "git commit -m",
776 "Expected correct alias command for mode {mode:?}"
777 );
778 }
779 }
780
781 #[tokio::test]
782 async fn test_find_commands_search_mode_exact() {
783 let storage = setup_ranking_storage().await;
784 storage.setup_workspace_storage().await.unwrap();
785 let filter_token_match = SearchCommandsFilter {
786 search_term: Some("commit".to_string()),
787 search_mode: SearchMode::Exact,
788 ..Default::default()
789 };
790 let (commands_token_match, _) = storage
791 .find_commands(filter_token_match, "/some/path", &SearchCommandTuning::default())
792 .await
793 .unwrap();
794 assert_eq!(commands_token_match.len(), 2);
795 assert_eq!(commands_token_match[0].cmd, "git commit -m");
796 assert_eq!(commands_token_match[1].cmd, "git commit -m '{{message}}'");
797
798 let filter_no_match = SearchCommandsFilter {
799 search_term: Some("nonexistentterm".to_string()),
800 search_mode: SearchMode::Exact,
801 ..Default::default()
802 };
803 let (commands_no_match, _) = storage
804 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
805 .await
806 .unwrap();
807 assert!(commands_no_match.is_empty());
808 }
809
810 #[tokio::test]
811 async fn test_find_commands_search_mode_relaxed() {
812 let storage = setup_ranking_storage().await;
813 storage.setup_workspace_storage().await.unwrap();
814 let filter = SearchCommandsFilter {
815 search_term: Some("docker list".to_string()),
816 search_mode: SearchMode::Relaxed,
817 ..Default::default()
818 };
819 let (commands, _) = storage
820 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
821 .await
822 .unwrap();
823 assert_eq!(commands.len(), 2);
824 assert!(commands.iter().any(|c| c.cmd == "docker ps -a"));
825 assert!(commands.iter().any(|c| c.cmd == "ls -lha"));
826 }
827
828 #[tokio::test]
829 async fn test_find_commands_search_mode_regex() {
830 let storage = setup_ranking_storage().await;
831 storage.setup_workspace_storage().await.unwrap();
832 let filter = SearchCommandsFilter {
833 search_term: Some(r"git\s.*it".to_string()),
834 search_mode: SearchMode::Regex,
835 ..Default::default()
836 };
837 let (commands, _) = storage
838 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
839 .await
840 .unwrap();
841 assert_eq!(commands.len(), 2);
842 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
843 assert_eq!(commands[1].cmd, "git commit -m");
844
845 let filter_invalid = SearchCommandsFilter {
846 search_term: Some("[[invalid_regex".to_string()),
847 search_mode: SearchMode::Regex,
848 ..Default::default()
849 };
850 assert!(matches!(
851 storage
852 .find_commands(filter_invalid, "/some/path", &SearchCommandTuning::default())
853 .await,
854 Err(AppError::UserFacing(UserFacingError::InvalidRegex))
855 ));
856 }
857
858 #[tokio::test]
859 async fn test_find_commands_search_mode_fuzzy() {
860 let storage = setup_ranking_storage().await;
861 storage.setup_workspace_storage().await.unwrap();
862 let filter = SearchCommandsFilter {
863 search_term: Some("gtcomit".to_string()),
864 search_mode: SearchMode::Fuzzy,
865 ..Default::default()
866 };
867 let (commands, _) = storage
868 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
869 .await
870 .unwrap();
871 assert_eq!(commands.len(), 2);
872 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
873 assert_eq!(commands[1].cmd, "git commit -m");
874
875 let filter_empty_fuzzy = SearchCommandsFilter {
876 search_term: Some("'' | ^".to_string()),
877 search_mode: SearchMode::Fuzzy,
878 ..Default::default()
879 };
880 assert!(matches!(
881 storage
882 .find_commands(filter_empty_fuzzy, "/some/path", &SearchCommandTuning::default())
883 .await,
884 Err(AppError::UserFacing(UserFacingError::InvalidFuzzy))
885 ));
886 }
887
888 #[tokio::test]
889 async fn test_find_commands_search_mode_auto() {
890 let storage = setup_ranking_storage().await;
891 let default_tuning = SearchCommandTuning::default();
892
893 let run_search = |term: &'static str, path: &'static str| {
895 let storage = storage.clone();
896 async move {
897 let filter = SearchCommandsFilter {
898 search_term: Some(term.to_string()),
899 search_mode: SearchMode::Auto,
900 ..Default::default()
901 };
902 storage.find_commands(filter, path, &default_tuning).await.unwrap()
903 }
904 };
905
906 let (commands, _) = run_search("list containers", UNRELATED_PATH).await;
908 assert!(!commands.is_empty(), "Expected results for 'list containers'");
909 assert_eq!(
910 commands[0].cmd, "docker ps -a",
911 "Expected 'docker ps -a' to be the top result for 'list containers'"
912 );
913
914 let (commands, _) = run_search("git commit", PROJ_A_PATH).await;
916 assert!(commands.len() >= 2, "Expected at least two results for 'git commit'");
917 assert_eq!(
918 commands[0].cmd, "git commit -m",
919 "Expected 'git commit -m' to be the top result for 'git commit' due to usage"
920 );
921 assert_eq!(
922 commands[1].cmd, "git commit -m '{{message}}'",
923 "Expected template command to be second for 'git commit'"
924 );
925
926 let (commands, _) = run_search("git commit -m 'my new feature'", PROJ_A_PATH).await;
928 assert!(!commands.is_empty(), "Expected results for template match");
929 assert_eq!(
930 commands[0].cmd, "git commit -m '{{message}}'",
931 "Expected template command to be the top result for a matching search term"
932 );
933
934 let (commands, _) = run_search("build", PROJ_A_API_PATH).await;
936 assert!(!commands.is_empty(), "Expected results for 'build'");
937 assert_eq!(
938 commands[0].cmd, "npm run build:prod",
939 "Expected 'npm run build:prod' to be top result for 'build' in its project path"
940 );
941
942 let (commands, _) = run_search("gt sta", PROJ_A_PATH).await;
944 assert!(!commands.is_empty(), "Expected results for fuzzy search 'gt sta'");
945 assert_eq!(
946 commands[0].cmd, "git status",
947 "Expected 'git status' as top result for fuzzy search 'gt sta'"
948 );
949
950 let (commands, _) = run_search("get pod monitoring", UNRELATED_PATH).await;
952 assert!(!commands.is_empty(), "Expected results for 'get pod monitoring'");
953 assert_eq!(
954 commands[0].cmd, "kubectl get pod -n monitoring my-specific-pod-12345",
955 "Expected specific 'kubectl' command to be found"
956 );
957
958 let (commands, _) = run_search("status", PROJ_A_API_PATH).await;
960 assert!(!commands.is_empty(), "Expected results for 'status'");
961 assert_eq!(
962 commands[0].cmd, "git status",
963 "Expected 'git status' to be top due to high usage in parent path"
964 );
965 }
966
967 #[tokio::test]
968 async fn test_find_commands_search_mode_auto_hastag_only() {
969 let storage = setup_ranking_storage().await;
970
971 let filter = SearchCommandsFilter {
974 search_term: Some("#".to_string()),
975 search_mode: SearchMode::Auto,
976 ..Default::default()
977 };
978
979 let res = storage
980 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
981 .await;
982 assert!(res.is_ok(), "Expected a success response, got: {res:?}")
983 }
984
985 #[tokio::test]
986 async fn test_find_commands_including_workspace() {
987 let storage = setup_ranking_storage().await;
988
989 storage.setup_workspace_storage().await.unwrap();
990 let commands_to_import = vec![
991 ImportExportItem::Command(Command {
992 id: Uuid::now_v7(),
993 cmd: "cmd1".to_string(),
994 ..Default::default()
995 }),
996 ImportExportItem::Command(Command {
997 id: Uuid::now_v7(),
998 cmd: "cmd2".to_string(),
999 ..Default::default()
1000 }),
1001 ];
1002 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1003 storage.import_items(stream, false, true).await.unwrap();
1004
1005 let (commands, _) = storage
1006 .find_commands(
1007 SearchCommandsFilter::default(),
1008 "/some/path",
1009 &SearchCommandTuning::default(),
1010 )
1011 .await
1012 .unwrap();
1013 assert_eq!(commands.len(), 12, "Expected 12 commands including workspace");
1014 }
1015
1016 #[tokio::test]
1017 async fn test_find_commands_with_text_including_workspace() {
1018 let storage = setup_ranking_storage().await;
1019
1020 storage.setup_workspace_storage().await.unwrap();
1021 let commands_to_import = vec![ImportExportItem::Command(Command {
1022 id: Uuid::now_v7(),
1023 cmd: "git checkout -b feature/{{name:kebab}}".to_string(),
1024 ..Default::default()
1025 })];
1026 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1027 storage.import_items(stream, false, true).await.unwrap();
1028
1029 let filter = SearchCommandsFilter {
1030 search_term: Some("git".to_string()),
1031 ..Default::default()
1032 };
1033
1034 let (commands, _) = storage
1035 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1036 .await
1037 .unwrap();
1038 assert_eq!(commands.len(), 6, "Expected 6 git commands including workspace");
1039 assert!(
1040 commands
1041 .iter()
1042 .any(|c| c.cmd == "git checkout -b feature/{{name:kebab}}")
1043 );
1044 }
1045
1046 #[tokio::test]
1047 async fn test_delete_tldr_commands() {
1048 let storage = SqliteStorage::new_in_memory().await.unwrap();
1049
1050 let tldr_cmd1 = Command {
1052 id: Uuid::now_v7(),
1053 category: "git".to_string(),
1054 source: SOURCE_TLDR.to_string(),
1055 cmd: "git status".to_string(),
1056 ..Default::default()
1057 };
1058 let tldr_cmd2 = Command {
1059 id: Uuid::now_v7(),
1060 category: "docker".to_string(),
1061 source: SOURCE_TLDR.to_string(),
1062 cmd: "docker ps".to_string(),
1063 ..Default::default()
1064 };
1065 let user_cmd = Command {
1066 id: Uuid::now_v7(),
1067 category: "git".to_string(),
1068 source: SOURCE_USER.to_string(),
1069 cmd: "git log".to_string(),
1070 ..Default::default()
1071 };
1072
1073 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1074 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1075 storage.insert_command(user_cmd.clone()).await.unwrap();
1076
1077 let removed = storage.delete_tldr_commands(None).await.unwrap();
1079 assert_eq!(removed, 2, "Should remove both tldr commands");
1080
1081 let (remaining, _) = storage
1082 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1083 .await
1084 .unwrap();
1085 assert_eq!(remaining.len(), 1, "Only user command should remain");
1086 assert_eq!(remaining[0].cmd, user_cmd.cmd);
1087
1088 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1090 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1091
1092 let removed_git = storage.delete_tldr_commands(Some("git".to_string())).await.unwrap();
1094 assert_eq!(removed_git, 1, "Should remove one tldr command in 'git' category");
1095
1096 let (remaining, _) = storage
1097 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1098 .await
1099 .unwrap();
1100 let remaining_cmds: Vec<_> = remaining.iter().map(|c| &c.cmd).collect();
1101 assert!(remaining_cmds.contains(&&tldr_cmd2.cmd));
1102 assert!(remaining_cmds.contains(&&user_cmd.cmd));
1103 assert!(!remaining_cmds.contains(&&tldr_cmd1.cmd));
1104 }
1105
1106 #[tokio::test]
1107 async fn test_insert_command() {
1108 let storage = SqliteStorage::new_in_memory().await.unwrap();
1109
1110 let mut cmd = Command {
1111 id: Uuid::now_v7(),
1112 category: "test".to_string(),
1113 cmd: "test_cmd".to_string(),
1114 description: Some("test desc".to_string()),
1115 tags: Some(vec!["tag1".to_string()]),
1116 ..Default::default()
1117 };
1118
1119 let mut inserted = storage.insert_command(cmd.clone()).await.unwrap();
1120 assert_eq!(inserted.cmd, cmd.cmd);
1121
1122 inserted.cmd = "other_cmd".to_string();
1124 match storage.insert_command(inserted).await {
1125 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1126 _ => panic!("Expected CommandAlreadyExists error on duplicate id"),
1127 }
1128
1129 cmd.id = Uuid::now_v7();
1131 match storage.insert_command(cmd).await {
1132 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1133 _ => panic!("Expected CommandAlreadyExists error on duplicate cmd"),
1134 }
1135 }
1136
1137 #[tokio::test]
1138 async fn test_update_command() {
1139 let storage = SqliteStorage::new_in_memory().await.unwrap();
1140
1141 let cmd = Command {
1142 id: Uuid::now_v7(),
1143 cmd: "original".to_string(),
1144 description: Some("desc".to_string()),
1145 ..Default::default()
1146 };
1147
1148 storage.insert_command(cmd.clone()).await.unwrap();
1149
1150 let mut updated = cmd.clone();
1151 updated.cmd = "updated".to_string();
1152 updated.description = Some("new desc".to_string());
1153
1154 let result = storage.update_command(updated.clone()).await.unwrap();
1155 assert_eq!(result.cmd, "updated");
1156 assert_eq!(result.description, Some("new desc".to_string()));
1157
1158 let mut non_existent = cmd;
1160 non_existent.id = Uuid::now_v7();
1161 match storage.update_command(non_existent).await {
1162 Err(_) => (),
1163 _ => panic!("Expected error when updating non-existent command"),
1164 }
1165
1166 let another_cmd = Command {
1168 id: Uuid::now_v7(),
1169 cmd: "another".to_string(),
1170 ..Default::default()
1171 };
1172 let mut result = storage.insert_command(another_cmd.clone()).await.unwrap();
1173 result.cmd = "updated".to_string();
1174 match storage.update_command(result).await {
1175 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1176 _ => panic!("Expected CommandAlreadyExists error when updating to existing cmd"),
1177 }
1178 }
1179
1180 #[tokio::test]
1181 async fn test_increment_command_usage() {
1182 let storage = SqliteStorage::new_in_memory().await.unwrap();
1183
1184 let command = storage
1186 .setup_command(
1187 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
1188 [("/some/path", 100)],
1189 )
1190 .await;
1191
1192 let count = storage.increment_command_usage(command.id, "/path").await.unwrap();
1194 assert_eq!(count, 1);
1195
1196 let count = storage.increment_command_usage(command.id, "/some/path").await.unwrap();
1198 assert_eq!(count, 101);
1199 }
1200
1201 #[tokio::test]
1202 async fn test_delete_command() {
1203 let storage = SqliteStorage::new_in_memory().await.unwrap();
1204
1205 let cmd = Command {
1206 id: Uuid::now_v7(),
1207 cmd: "to_delete".to_string(),
1208 ..Default::default()
1209 };
1210
1211 let cmd = storage.insert_command(cmd).await.unwrap();
1212 let res = storage.delete_command(cmd.id).await;
1213 assert!(res.is_ok());
1214
1215 match storage.delete_command(cmd.id).await {
1217 Err(_) => (),
1218 _ => panic!("Expected error when deleting non-existent command"),
1219 }
1220 }
1221
1222 async fn setup_ranking_storage() -> SqliteStorage {
1224 let storage = SqliteStorage::new_in_memory().await.unwrap();
1225 storage
1226 .setup_command(
1227 Command::new(
1228 CATEGORY_USER,
1229 SOURCE_USER,
1230 "kubectl get pod -n monitoring my-specific-pod-12345",
1231 )
1232 .with_description(Some(
1233 "Get a very specific pod by its full name in the monitoring namespace".to_string(),
1234 ))
1235 .with_tags(Some(vec!["#k8s".to_string(), "#pod".to_string()])),
1236 [("/other/path", 1)],
1237 )
1238 .await;
1239 storage
1240 .setup_command(
1241 Command::new(CATEGORY_USER, SOURCE_USER, "git status")
1242 .with_description(Some("Check the status of the git repository".to_string()))
1243 .with_tags(Some(vec!["#git".to_string()])),
1244 [(PROJ_A_PATH, 50), (PROJ_B_PATH, 50), (UNRELATED_PATH, 100)],
1245 )
1246 .await;
1247 storage
1248 .setup_command(
1249 Command::new(CATEGORY_USER, SOURCE_USER, "npm run build:prod")
1250 .with_description(Some("Build the project for production".to_string()))
1251 .with_tags(Some(vec!["#npm".to_string(), "#build".to_string()])),
1252 [(PROJ_A_API_PATH, 25)],
1253 )
1254 .await;
1255 storage
1256 .setup_command(
1257 Command::new(CATEGORY_USER, SOURCE_USER, "container-image-build.sh")
1258 .with_description(Some("A generic script to build a container image".to_string()))
1259 .with_tags(Some(vec!["#docker".to_string(), "#build".to_string()])),
1260 [(UNRELATED_PATH, 35)],
1261 )
1262 .await;
1263 storage
1264 .setup_command(
1265 Command::new(CATEGORY_USER, SOURCE_USER, "git commit -m '{{message}}'")
1266 .with_description(Some("Commit with a message".to_string()))
1267 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1268 [(PROJ_A_PATH, 10), (PROJ_B_PATH, 10)],
1269 )
1270 .await;
1271 storage
1272 .setup_command(
1273 Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")
1274 .with_alias(Some("gco".to_string()))
1275 .with_description(Some("Checkout the main branch".to_string()))
1276 .with_tags(Some(vec!["#git".to_string()])),
1277 [(PROJ_A_PATH, 30), (PROJ_B_PATH, 30)],
1278 )
1279 .await;
1280 storage
1281 .setup_command(
1282 Command::new("git", SOURCE_TLDR, "git commit -m")
1283 .with_alias(Some("gc".to_string()))
1284 .with_description(Some("Commit changes".to_string()))
1285 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1286 [(PROJ_A_PATH, 15)],
1287 )
1288 .await;
1289 storage
1290 .setup_command(
1291 Command::new("docker", SOURCE_TLDR, "docker ps -a")
1292 .with_description(Some("List all containers".to_string()))
1293 .with_tags(Some(vec!["#docker".to_string(), "#list".to_string()])),
1294 [(PROJ_A_PATH, 5), (PROJ_B_PATH, 5)],
1295 )
1296 .await;
1297 storage
1298 .setup_command(
1299 Command::new("git", SOURCE_TLDR, "git push")
1300 .with_description(Some("Push changes".to_string()))
1301 .with_tags(Some(vec!["#git".to_string(), "#push".to_string()])),
1302 [(PROJ_A_PATH, 20), (PROJ_B_PATH, 20)],
1303 )
1304 .await;
1305 storage
1306 .setup_command(
1307 Command::new(CATEGORY_USER, SOURCE_IMPORT, "ls -lha")
1308 .with_description(Some("List files".to_string()))
1309 .with_tags(Some(vec!["#unix".to_string(), "#list".to_string()])),
1310 [(PROJ_A_PATH, 100), (PROJ_B_PATH, 100), (UNRELATED_PATH, 100)],
1311 )
1312 .await;
1313
1314 storage
1315 }
1316
1317 impl SqliteStorage {
1318 async fn check_sqlite_version(&self) {
1320 let version: String = self
1321 .client
1322 .conn_mut(|conn| {
1323 conn.query_row("SELECT sqlite_version()", [], |row| row.get(0))
1324 .map_err(Into::into)
1325 })
1326 .await
1327 .unwrap();
1328 println!("Running with SQLite version: {version}");
1329 }
1330
1331 async fn setup_command(
1334 &self,
1335 command: Command,
1336 usage: impl IntoIterator<Item = (&str, i32)> + Send + 'static,
1337 ) -> Command {
1338 let command = self.insert_command(command).await.unwrap();
1339 self.client
1340 .conn_mut(move |conn| {
1341 for (path, usage_count) in usage {
1342 conn.execute(
1343 r#"
1344 INSERT INTO command_usage (command_id, path, usage_count)
1345 VALUES (?1, ?2, ?3)
1346 ON CONFLICT(command_id, path) DO UPDATE SET
1347 usage_count = excluded.usage_count"#,
1348 (&command.id, path, usage_count),
1349 )?;
1350 }
1351 Ok(command)
1352 })
1353 .await
1354 .unwrap()
1355 }
1356 }
1357}