1use std::{cmp::Ordering, sync::atomic::Ordering as AtomicOrdering};
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{Row, fallible_iterator::FallibleIterator, ffi, types::Type};
5use sea_query::SqliteQueryBuilder;
6use sea_query_rusqlite::RusqliteBinder;
7use tracing::instrument;
8use uuid::Uuid;
9
10use super::{SqliteStorage, queries::*};
11use crate::{
12 config::SearchCommandTuning,
13 errors::{Result, UserFacingError},
14 model::{Command, SOURCE_TLDR, SearchCommandsFilter},
15};
16
17impl SqliteStorage {
18 #[instrument(skip_all)]
21 pub async fn setup_workspace_storage(&self) -> Result<()> {
22 tracing::trace!("Creating workspace-specific tables");
23 self.client
24 .conn_mut(|conn| {
25 let schemas: Vec<String> = conn
27 .prepare(
28 r"SELECT sql
29 FROM sqlite_master
30 WHERE (type = 'table' AND name = 'variable_completion')
31 OR (type = 'table' AND name = 'command')
32 OR (type = 'table' AND name LIKE 'command_%fts')
33 OR (type = 'trigger' AND name LIKE 'command_%_fts' AND tbl_name = 'command')",
34 )?
35 .query_map([], |row| row.get(0))?
36 .collect::<Result<Vec<String>, _>>()?;
37
38 let tx = conn.transaction()?;
39
40 for schema in schemas {
42 let temp_schema = schema
43 .replace("variable_completion", "workspace_variable_completion")
44 .replace("command", "workspace_command")
45 .replace("CREATE TABLE ", "CREATE TEMP TABLE ")
46 .replace("CREATE VIRTUAL TABLE ", "CREATE VIRTUAL TABLE temp.")
47 .replace("CREATE TRIGGER ", "CREATE TEMP TRIGGER ");
48 tracing::trace!("Executing:\n{temp_schema}");
49 tx.execute(&temp_schema, [])?;
50 }
51
52 tx.commit()?;
53 Ok(())
54 })
55 .await?;
56
57 self.workspace_tables_loaded.store(true, AtomicOrdering::SeqCst);
58
59 Ok(())
60 }
61
62 #[instrument(skip_all)]
64 pub async fn is_empty(&self) -> Result<bool> {
65 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
66 self.client
67 .conn(move |conn| {
68 let query = if workspace_tables_loaded {
69 "SELECT NOT EXISTS (SELECT 1 FROM command UNION ALL SELECT 1 FROM workspace_command)"
70 } else {
71 "SELECT NOT EXISTS(SELECT 1 FROM command)"
72 };
73 tracing::trace!("Checking if storage is empty:\n{query}");
74 Ok(conn.query_row(query, [], |r| r.get(0))?)
75 })
76 .await
77 }
78
79 #[instrument(skip_all)]
81 pub async fn find_tags(
82 &self,
83 filter: SearchCommandsFilter,
84 tag_prefix: Option<String>,
85 tuning: &SearchCommandTuning,
86 ) -> Result<Vec<(String, u64, bool)>> {
87 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
88 let query = query_find_tags(filter, tag_prefix, tuning, workspace_tables_loaded)?;
89 if tracing::enabled!(tracing::Level::TRACE) {
90 tracing::trace!("Querying tags:\n{}", query.to_string(SqliteQueryBuilder));
91 }
92 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
93 self.client
94 .conn(move |conn| {
95 conn.prepare(&stmt)?
96 .query(&*values.as_params())?
97 .and_then(|r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))
98 .collect()
99 })
100 .await
101 }
102
103 #[instrument(skip_all)]
108 pub async fn find_commands(
109 &self,
110 filter: SearchCommandsFilter,
111 working_path: impl Into<String>,
112 tuning: &SearchCommandTuning,
113 ) -> Result<(Vec<Command>, bool)> {
114 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
115 let cleaned_filter = filter.cleaned();
116
117 let mut query_alias = None;
119 if let Some(ref term) = cleaned_filter.search_term {
120 let (query, params) = if workspace_tables_loaded {
122 (
123 format!(
124 r#"WITH commands_unified AS (
125 SELECT rowid, *, 0 AS is_workspace FROM command
126 UNION ALL
127 SELECT rowid, *, 1 AS is_workspace FROM workspace_command
128 ),
129 commands_ranked AS (
130 SELECT *, ROW_NUMBER() OVER (PARTITION BY alias ORDER BY is_workspace ASC) as _rank
131 FROM commands_unified
132 )
133 SELECT *
134 FROM commands_ranked
135 WHERE alias IS NOT NULL AND alias = ?1 AND _rank = 1
136 LIMIT {QUERY_LIMIT}"#
137 ),
138 (term.clone(),),
139 )
140 } else {
141 (
142 format!(
143 r#"SELECT c.rowid, c.*
144 FROM command c
145 WHERE c.alias IS NOT NULL AND c.alias = ?1
146 LIMIT {QUERY_LIMIT}"#
147 ),
148 (term.clone(),),
149 )
150 };
151 query_alias = Some((query, params));
152 }
153
154 let query = query_find_commands(cleaned_filter, working_path, tuning, workspace_tables_loaded)?;
156 let query_trace = if tracing::enabled!(tracing::Level::TRACE) {
157 query.to_string(SqliteQueryBuilder)
158 } else {
159 String::default()
160 };
161 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
162
163 let tuning = *tuning;
165 self.client
166 .conn(move |conn| {
167 if let Some((query_alias, a_params)) = query_alias {
169 tracing::trace!("Querying aliased commands:\n{query_alias}");
170 let rows = conn
172 .prepare(&query_alias)?
173 .query(a_params)?
174 .map(|r| Command::try_from(r))
175 .collect::<Vec<_>>()?;
176 if !rows.is_empty() {
178 return Ok((rows, true));
179 }
180 }
181 if tracing::enabled!(tracing::Level::TRACE) {
183 tracing::trace!("Querying commands:\n{query_trace}");
184 }
185 Ok((
186 rerank_query_results(
187 conn.prepare(&stmt)?
188 .query(&*values.as_params())?
189 .and_then(|r| QueryResultItem::try_from(r))
190 .collect::<Result<Vec<_>, _>>()?,
191 &tuning,
192 ),
193 false,
194 ))
195 })
196 .await
197 }
198
199 #[instrument(skip_all)]
201 pub async fn delete_tldr_commands(&self, category: Option<String>) -> Result<u64> {
202 self.client
203 .conn_mut(move |conn| {
204 let mut query = String::from("DELETE FROM command WHERE source = ?1");
205 let mut params: Vec<String> = vec![SOURCE_TLDR.to_owned()];
206 if let Some(cat) = category {
207 query.push_str(" AND category = ?2");
208 params.push(cat);
209 }
210 tracing::trace!("Deleting tldr commands:\n{query}");
211 let affected = conn.execute(&query, rusqlite::params_from_iter(params))?;
212 Ok(affected as u64)
213 })
214 .await
215 }
216
217 #[instrument(skip_all)]
221 pub async fn insert_command(&self, command: Command) -> Result<Command> {
222 self.client
223 .conn(move |conn| {
224 let query = r#"INSERT INTO command (
225 id,
226 category,
227 source,
228 alias,
229 cmd,
230 flat_cmd,
231 description,
232 flat_description,
233 tags,
234 created_at,
235 updated_at
236 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)"#;
237 tracing::trace!("Inserting a command:\n{query}");
238 let res = conn.execute(
239 query,
240 (
241 &command.id,
242 &command.category,
243 &command.source,
244 &command.alias,
245 &command.cmd,
246 &command.flat_cmd,
247 &command.description,
248 &command.flat_description,
249 serde_json::to_value(&command.tags)?,
250 &command.created_at,
251 &command.updated_at,
252 ),
253 );
254 match res {
255 Ok(_) => Ok(command),
256 Err(err) => {
257 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
258 if code == ffi::SQLITE_CONSTRAINT_UNIQUE || code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY {
259 Err(UserFacingError::CommandAlreadyExists.into())
260 } else {
261 Err(Report::from(err).into())
262 }
263 }
264 }
265 })
266 .await
267 }
268
269 #[instrument(skip_all)]
273 pub async fn update_command(&self, command: Command) -> Result<Command> {
274 self.client
275 .conn(move |conn| {
276 let query = r#"UPDATE command SET
277 category = ?2,
278 source = ?3,
279 alias = ?4,
280 cmd = ?5,
281 flat_cmd = ?6,
282 description = ?7,
283 flat_description = ?8,
284 tags = ?9,
285 created_at = ?10,
286 updated_at = ?11
287 WHERE id = ?1"#;
288 tracing::trace!("Updating a command:\n{query}");
289 let res = conn.execute(
290 query,
291 (
292 &command.id,
293 &command.category,
294 &command.source,
295 &command.alias,
296 &command.cmd,
297 &command.flat_cmd,
298 &command.description,
299 &command.flat_description,
300 serde_json::to_value(&command.tags)?,
301 &command.created_at,
302 &command.updated_at,
303 ),
304 );
305 match res {
306 Ok(0) => Err(eyre!("Command not found: {}", command.id).into()),
307 Ok(_) => Ok(command),
308 Err(err) => {
309 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
310 if code == ffi::SQLITE_CONSTRAINT_UNIQUE {
311 Err(UserFacingError::CommandAlreadyExists.into())
312 } else {
313 Err(Report::from(err).into())
314 }
315 }
316 }
317 })
318 .await
319 }
320
321 #[instrument(skip_all)]
323 pub async fn increment_command_usage(
324 &self,
325 command_id: Uuid,
326 path: impl AsRef<str> + Send + 'static,
327 ) -> Result<i32> {
328 self.client
329 .conn_mut(move |conn| {
330 let query = r#"
331 INSERT INTO command_usage (command_id, path, usage_count)
332 VALUES (?1, ?2, 1)
333 ON CONFLICT(command_id, path) DO UPDATE SET
334 usage_count = usage_count + 1
335 RETURNING usage_count;"#;
336 tracing::trace!("Incrementing command usage:\n{query}");
337 Ok(conn.query_row(query, (&command_id, &path.as_ref()), |r| r.get(0))?)
338 })
339 .await
340 }
341
342 #[instrument(skip_all)]
346 pub async fn delete_command(&self, command_id: Uuid) -> Result<()> {
347 self.client
348 .conn(move |conn| {
349 let query = "DELETE FROM command WHERE id = ?1";
350 tracing::trace!("Deleting command:\n{query}");
351 let res = conn.execute(query, (&command_id,));
352 match res {
353 Ok(0) => Err(eyre!("Command not found: {command_id}").into()),
354 Ok(_) => Ok(()),
355 Err(err) => Err(Report::from(err).into()),
356 }
357 })
358 .await
359 }
360}
361
362fn rerank_query_results(items: Vec<QueryResultItem>, tuning: &SearchCommandTuning) -> Vec<Command> {
371 if items.is_empty() {
373 return Vec::new();
374 }
375 if items.len() == 1 {
376 return items.into_iter().map(|item| item.command).collect();
377 }
378
379 let (template_matches, mut other_items): (Vec<_>, Vec<_>) = items
382 .into_iter()
383 .partition(|item| item.text_score >= TEMPLATE_MATCH_RANK);
384 if !template_matches.is_empty() {
385 tracing::trace!("Found {} template matches", template_matches.len());
386 }
387
388 let mut final_commands: Vec<Command> = template_matches.into_iter().map(|item| item.command).collect();
390
391 if other_items.len() <= 1 {
393 final_commands.extend(other_items.into_iter().map(|item| item.command));
394 return final_commands;
395 }
396
397 let mut min_text = 0f64;
400 let mut min_path = 0f64;
401 let mut min_usage = 0f64;
402 let mut max_text = f64::NEG_INFINITY;
403 let mut max_path = f64::NEG_INFINITY;
404 let mut max_usage = f64::NEG_INFINITY;
405 for item in &other_items {
406 min_text = min_text.min(item.text_score);
407 max_text = max_text.max(item.text_score);
408 min_path = min_path.min(item.path_score);
409 max_path = max_path.max(item.path_score);
410 min_usage = min_usage.min(item.usage_score);
411 max_usage = max_usage.max(item.usage_score);
412 }
413
414 let range_text = (max_text > min_text).then_some(max_text - min_text);
416 let range_path = (max_path > min_path).then_some(max_path - min_path);
417 let range_usage = (max_usage > min_usage).then_some(max_usage - min_usage);
418
419 other_items.sort_by(|a, b| {
421 match b.is_workspace_command.cmp(&a.is_workspace_command) {
423 Ordering::Equal => {
424 let calculate_score = |item: &QueryResultItem| -> f64 {
426 let norm_text = range_text.map_or(0.5, |range| (item.text_score - min_text) / range);
429 let norm_path = range_path.map_or(0.5, |range| (item.path_score - min_path) / range);
430 let norm_usage = range_usage.map_or(0.5, |range| (item.usage_score - min_usage) / range);
431
432 (norm_text * tuning.text.points as f64)
434 + (norm_path * tuning.path.points as f64)
435 + (norm_usage * tuning.usage.points as f64)
436 };
437
438 let final_score_a = calculate_score(a);
439 let final_score_b = calculate_score(b);
440
441 final_score_b.partial_cmp(&final_score_a).unwrap_or(Ordering::Equal)
443 }
444 other => other,
446 }
447 });
448
449 final_commands.extend(other_items.into_iter().map(|item| item.command));
451 final_commands
452}
453
454impl<'a> TryFrom<&'a Row<'a>> for Command {
455 type Error = rusqlite::Error;
456
457 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
458 Ok(Self {
459 id: row.get(1)?,
461 category: row.get(2)?,
462 source: row.get(3)?,
463 alias: row.get(4)?,
464 cmd: row.get(5)?,
465 flat_cmd: row.get(6)?,
466 description: row.get(7)?,
467 flat_description: row.get(8)?,
468 tags: serde_json::from_value(row.get::<_, serde_json::Value>(9)?)
469 .map_err(|e| rusqlite::Error::FromSqlConversionFailure(9, Type::Text, Box::new(e)))?,
470 created_at: row.get(10)?,
471 updated_at: row.get(11)?,
472 is_destructive: false,
473 })
474 }
475}
476
477struct QueryResultItem {
479 command: Command,
481 is_workspace_command: bool,
483 usage_score: f64,
485 path_score: f64,
487 text_score: f64,
489}
490
491impl<'a> TryFrom<&'a Row<'a>> for QueryResultItem {
492 type Error = rusqlite::Error;
493
494 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
495 Ok(Self {
496 command: Command::try_from(row)?,
497 is_workspace_command: row.get(12)?,
498 usage_score: row.get(13)?,
499 path_score: row.get(14)?,
500 text_score: row.get(15)?,
501 })
502 }
503}
504
505#[cfg(test)]
506mod tests {
507 use pretty_assertions::assert_eq;
508 use strum::IntoEnumIterator;
509 use tokio_stream::iter;
510 use uuid::Uuid;
511
512 use super::*;
513 use crate::{
514 errors::AppError,
515 model::{CATEGORY_USER, ImportExportItem, SOURCE_IMPORT, SOURCE_USER, SearchMode},
516 };
517
518 const PROJ_A_PATH: &str = "/home/user/project-a";
519 const PROJ_A_API_PATH: &str = "/home/user/project-a/api";
520 const PROJ_B_PATH: &str = "/home/user/project-b";
521 const UNRELATED_PATH: &str = "/var/log";
522
523 #[tokio::test]
524 async fn test_setup_workspace_storage() {
525 let storage = SqliteStorage::new_in_memory().await.unwrap();
526 storage.check_sqlite_version().await;
527 let res = storage.setup_workspace_storage().await;
528 assert!(res.is_ok(), "Expected workspace storage setup to succeed: {res:?}");
529 }
530
531 #[tokio::test]
532 async fn test_is_empty() {
533 let storage = SqliteStorage::new_in_memory().await.unwrap();
534 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
535
536 let cmd = Command {
537 id: Uuid::now_v7(),
538 cmd: "test_cmd".to_string(),
539 ..Default::default()
540 };
541 storage.insert_command(cmd).await.unwrap();
542
543 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
544 }
545
546 #[tokio::test]
547 async fn test_is_empty_with_workspace() {
548 let storage = SqliteStorage::new_in_memory().await.unwrap();
549 storage.setup_workspace_storage().await.unwrap();
550 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
551
552 let cmd = Command {
553 id: Uuid::now_v7(),
554 cmd: "test_cmd".to_string(),
555 ..Default::default()
556 };
557 storage.insert_command(cmd).await.unwrap();
558
559 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
560 }
561
562 #[tokio::test]
563 async fn test_find_tags_no_filters() -> Result<()> {
564 let storage = setup_ranking_storage().await;
565
566 let result = storage
567 .find_tags(SearchCommandsFilter::default(), None, &SearchCommandTuning::default())
568 .await?;
569
570 let expected = vec![
571 ("#git".to_string(), 5, false),
572 ("#build".to_string(), 2, false),
573 ("#commit".to_string(), 2, false),
574 ("#docker".to_string(), 2, false),
575 ("#list".to_string(), 2, false),
576 ("#k8s".to_string(), 1, false),
577 ("#npm".to_string(), 1, false),
578 ("#pod".to_string(), 1, false),
579 ("#push".to_string(), 1, false),
580 ("#unix".to_string(), 1, false),
581 ];
582
583 assert_eq!(result.len(), 10, "Expected 10 unique tags");
584 assert_eq!(result, expected, "Tags list or order mismatch");
585
586 Ok(())
587 }
588
589 #[tokio::test]
590 async fn test_find_tags_filter_by_tags_only() -> Result<()> {
591 let storage = setup_ranking_storage().await;
592
593 let filter1 = SearchCommandsFilter {
594 tags: Some(vec!["#git".to_string()]),
595 ..Default::default()
596 };
597 let result1 = storage
598 .find_tags(filter1, None, &SearchCommandTuning::default())
599 .await?;
600 let expected1 = vec![("#commit".to_string(), 2, false), ("#push".to_string(), 1, false)];
601 assert_eq!(result1.len(), 2,);
602 assert_eq!(result1, expected1);
603
604 let filter2 = SearchCommandsFilter {
605 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
606 ..Default::default()
607 };
608 let result2 = storage
609 .find_tags(filter2, None, &SearchCommandTuning::default())
610 .await?;
611 assert!(result2.is_empty());
612
613 let filter3 = SearchCommandsFilter {
614 tags: Some(vec!["#list".to_string()]),
615 ..Default::default()
616 };
617 let result3 = storage
618 .find_tags(filter3, None, &SearchCommandTuning::default())
619 .await?;
620 let expected3 = vec![("#docker".to_string(), 1, false), ("#unix".to_string(), 1, false)];
621 assert_eq!(result3.len(), 2);
622 assert_eq!(result3, expected3);
623
624 Ok(())
625 }
626
627 #[tokio::test]
628 async fn test_find_tags_filter_by_prefix_only() -> Result<()> {
629 let storage = setup_ranking_storage().await;
630
631 let result = storage
632 .find_tags(
633 SearchCommandsFilter::default(),
634 Some("#comm".to_string()),
635 &SearchCommandTuning::default(),
636 )
637 .await?;
638 let expected = vec![("#commit".to_string(), 2, false)];
639 assert_eq!(result.len(), 1);
640 assert_eq!(result, expected);
641
642 Ok(())
643 }
644
645 #[tokio::test]
646 async fn test_find_tags_filter_by_tags_and_prefix() -> Result<()> {
647 let storage = setup_ranking_storage().await;
648
649 let filter1 = SearchCommandsFilter {
650 tags: Some(vec!["#git".to_string()]),
651 ..Default::default()
652 };
653 let result1 = storage
654 .find_tags(filter1, Some("#comm".to_string()), &SearchCommandTuning::default())
655 .await?;
656 let expected1 = vec![("#commit".to_string(), 2, false)];
657 assert_eq!(result1.len(), 1);
658 assert_eq!(result1, expected1);
659
660 let filter2 = SearchCommandsFilter {
661 tags: Some(vec!["#git".to_string()]),
662 ..Default::default()
663 };
664 let result2 = storage
665 .find_tags(filter2, Some("#push".to_string()), &SearchCommandTuning::default())
666 .await?;
667 let expected2 = vec![("#push".to_string(), 1, true)];
668 assert_eq!(result2.len(), 1);
669 assert_eq!(result2, expected2);
670
671 Ok(())
672 }
673
674 #[tokio::test]
675 async fn test_find_commands_no_filter() {
676 let storage = setup_ranking_storage().await;
677 let filter = SearchCommandsFilter::default();
678 let (commands, _) = storage
679 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
680 .await
681 .unwrap();
682 assert_eq!(commands.len(), 10, "Expected all sample commands");
683 }
684
685 #[tokio::test]
686 async fn test_find_commands_filter_by_category() {
687 let storage = setup_ranking_storage().await;
688 let filter = SearchCommandsFilter {
689 category: Some(vec!["git".to_string()]),
690 ..Default::default()
691 };
692 let (commands, _) = storage
693 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
694 .await
695 .unwrap();
696 assert_eq!(commands.len(), 2);
697 assert!(commands.iter().all(|c| c.category == "git"));
698
699 let filter_no_match = SearchCommandsFilter {
700 category: Some(vec!["nonexistent".to_string()]),
701 ..Default::default()
702 };
703 let (commands_no_match, _) = storage
704 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
705 .await
706 .unwrap();
707 assert!(commands_no_match.is_empty());
708 }
709
710 #[tokio::test]
711 async fn test_find_commands_filter_by_source() {
712 let storage = setup_ranking_storage().await;
713 let filter = SearchCommandsFilter {
714 source: Some(SOURCE_TLDR.to_string()),
715 ..Default::default()
716 };
717 let (commands, _) = storage
718 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
719 .await
720 .unwrap();
721 assert_eq!(commands.len(), 3);
722 assert!(commands.iter().all(|c| c.source == SOURCE_TLDR));
723 }
724
725 #[tokio::test]
726 async fn test_find_commands_filter_by_tags() {
727 let storage = setup_ranking_storage().await;
728 let filter_single_tag = SearchCommandsFilter {
729 tags: Some(vec!["#git".to_string()]),
730 ..Default::default()
731 };
732 let (commands_single_tag, _) = storage
733 .find_commands(filter_single_tag, "/some/path", &SearchCommandTuning::default())
734 .await
735 .unwrap();
736 assert_eq!(commands_single_tag.len(), 5);
737
738 let filter_multiple_tags = SearchCommandsFilter {
739 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
740 ..Default::default()
741 };
742 let (commands_multiple_tags, _) = storage
743 .find_commands(filter_multiple_tags, "/some/path", &SearchCommandTuning::default())
744 .await
745 .unwrap();
746 assert_eq!(commands_multiple_tags.len(), 1);
747
748 let filter_empty_tags = SearchCommandsFilter {
749 tags: Some(vec![]),
750 ..Default::default()
751 };
752 let (commands_empty_tags, _) = storage
753 .find_commands(filter_empty_tags, "/some/path", &SearchCommandTuning::default())
754 .await
755 .unwrap();
756 assert_eq!(commands_empty_tags.len(), 10);
757 }
758
759 #[tokio::test]
760 async fn test_find_commands_alias_precedence() {
761 let storage = setup_ranking_storage().await;
762 storage
763 .setup_command(
764 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
765 [("/some/path", 100)],
766 )
767 .await;
768
769 for mode in SearchMode::iter() {
770 let filter = SearchCommandsFilter {
771 search_term: Some("gc".to_string()),
772 search_mode: mode,
773 ..Default::default()
774 };
775 let (commands, alias_match) = storage
776 .find_commands(filter, "", &SearchCommandTuning::default())
777 .await
778 .unwrap();
779 assert!(alias_match, "Expected alias match for mode {mode:?}");
780 assert_eq!(commands.len(), 1, "Expected only alias match for mode {mode:?}");
781 assert_eq!(
782 commands[0].cmd, "git commit -m",
783 "Expected correct alias command for mode {mode:?}"
784 );
785 }
786 }
787
788 #[tokio::test]
789 async fn test_find_commands_search_mode_exact() {
790 let storage = setup_ranking_storage().await;
791 storage.setup_workspace_storage().await.unwrap();
792 let filter_token_match = SearchCommandsFilter {
793 search_term: Some("commit".to_string()),
794 search_mode: SearchMode::Exact,
795 ..Default::default()
796 };
797 let (commands_token_match, _) = storage
798 .find_commands(filter_token_match, "/some/path", &SearchCommandTuning::default())
799 .await
800 .unwrap();
801 assert_eq!(commands_token_match.len(), 2);
802 assert_eq!(commands_token_match[0].cmd, "git commit -m");
803 assert_eq!(commands_token_match[1].cmd, "git commit -m '{{message}}'");
804
805 let filter_no_match = SearchCommandsFilter {
806 search_term: Some("nonexistentterm".to_string()),
807 search_mode: SearchMode::Exact,
808 ..Default::default()
809 };
810 let (commands_no_match, _) = storage
811 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
812 .await
813 .unwrap();
814 assert!(commands_no_match.is_empty());
815 }
816
817 #[tokio::test]
818 async fn test_find_commands_search_mode_relaxed() {
819 let storage = setup_ranking_storage().await;
820 storage.setup_workspace_storage().await.unwrap();
821 let filter = SearchCommandsFilter {
822 search_term: Some("docker list".to_string()),
823 search_mode: SearchMode::Relaxed,
824 ..Default::default()
825 };
826 let (commands, _) = storage
827 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
828 .await
829 .unwrap();
830 assert_eq!(commands.len(), 2);
831 assert!(commands.iter().any(|c| c.cmd == "docker ps -a"));
832 assert!(commands.iter().any(|c| c.cmd == "ls -lha"));
833 }
834
835 #[tokio::test]
836 async fn test_find_commands_search_mode_regex() {
837 let storage = setup_ranking_storage().await;
838 storage.setup_workspace_storage().await.unwrap();
839 let filter = SearchCommandsFilter {
840 search_term: Some(r"git\s.*it".to_string()),
841 search_mode: SearchMode::Regex,
842 ..Default::default()
843 };
844 let (commands, _) = storage
845 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
846 .await
847 .unwrap();
848 assert_eq!(commands.len(), 2);
849 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
850 assert_eq!(commands[1].cmd, "git commit -m");
851
852 let filter_invalid = SearchCommandsFilter {
853 search_term: Some("[[invalid_regex".to_string()),
854 search_mode: SearchMode::Regex,
855 ..Default::default()
856 };
857 assert!(matches!(
858 storage
859 .find_commands(filter_invalid, "/some/path", &SearchCommandTuning::default())
860 .await,
861 Err(AppError::UserFacing(UserFacingError::InvalidRegex))
862 ));
863 }
864
865 #[tokio::test]
866 async fn test_find_commands_search_mode_fuzzy() {
867 let storage = setup_ranking_storage().await;
868 storage.setup_workspace_storage().await.unwrap();
869 let filter = SearchCommandsFilter {
870 search_term: Some("gtcomit".to_string()),
871 search_mode: SearchMode::Fuzzy,
872 ..Default::default()
873 };
874 let (commands, _) = storage
875 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
876 .await
877 .unwrap();
878 assert_eq!(commands.len(), 2);
879 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
880 assert_eq!(commands[1].cmd, "git commit -m");
881
882 let filter_empty_fuzzy = SearchCommandsFilter {
883 search_term: Some("'' | ^".to_string()),
884 search_mode: SearchMode::Fuzzy,
885 ..Default::default()
886 };
887 assert!(matches!(
888 storage
889 .find_commands(filter_empty_fuzzy, "/some/path", &SearchCommandTuning::default())
890 .await,
891 Err(AppError::UserFacing(UserFacingError::InvalidFuzzy))
892 ));
893 }
894
895 #[tokio::test]
896 async fn test_find_commands_search_mode_auto() {
897 let storage = setup_ranking_storage().await;
898 let default_tuning = SearchCommandTuning::default();
899
900 let run_search = |term: &'static str, path: &'static str| {
902 let storage = storage.clone();
903 async move {
904 let filter = SearchCommandsFilter {
905 search_term: Some(term.to_string()),
906 search_mode: SearchMode::Auto,
907 ..Default::default()
908 };
909 storage.find_commands(filter, path, &default_tuning).await.unwrap()
910 }
911 };
912
913 let (commands, _) = run_search("list containers", UNRELATED_PATH).await;
915 assert!(!commands.is_empty(), "Expected results for 'list containers'");
916 assert_eq!(
917 commands[0].cmd, "docker ps -a",
918 "Expected 'docker ps -a' to be the top result for 'list containers'"
919 );
920
921 let (commands, _) = run_search("git commit", PROJ_A_PATH).await;
923 assert!(commands.len() >= 2, "Expected at least two results for 'git commit'");
924 assert_eq!(
925 commands[0].cmd, "git commit -m",
926 "Expected 'git commit -m' to be the top result for 'git commit' due to usage"
927 );
928 assert_eq!(
929 commands[1].cmd, "git commit -m '{{message}}'",
930 "Expected template command to be second for 'git commit'"
931 );
932
933 let (commands, _) = run_search("git commit -m 'my new feature'", PROJ_A_PATH).await;
935 assert!(!commands.is_empty(), "Expected results for template match");
936 assert_eq!(
937 commands[0].cmd, "git commit -m '{{message}}'",
938 "Expected template command to be the top result for a matching search term"
939 );
940
941 let (commands, _) = run_search("build", PROJ_A_API_PATH).await;
943 assert!(!commands.is_empty(), "Expected results for 'build'");
944 assert_eq!(
945 commands[0].cmd, "npm run build:prod",
946 "Expected 'npm run build:prod' to be top result for 'build' in its project path"
947 );
948
949 let (commands, _) = run_search("gt sta", PROJ_A_PATH).await;
951 assert!(!commands.is_empty(), "Expected results for fuzzy search 'gt sta'");
952 assert_eq!(
953 commands[0].cmd, "git status",
954 "Expected 'git status' as top result for fuzzy search 'gt sta'"
955 );
956
957 let (commands, _) = run_search("get pod monitoring", UNRELATED_PATH).await;
959 assert!(!commands.is_empty(), "Expected results for 'get pod monitoring'");
960 assert_eq!(
961 commands[0].cmd, "kubectl get pod -n monitoring my-specific-pod-12345",
962 "Expected specific 'kubectl' command to be found"
963 );
964
965 let (commands, _) = run_search("status", PROJ_A_API_PATH).await;
967 assert!(!commands.is_empty(), "Expected results for 'status'");
968 assert_eq!(
969 commands[0].cmd, "git status",
970 "Expected 'git status' to be top due to high usage in parent path"
971 );
972 }
973
974 #[tokio::test]
975 async fn test_find_commands_search_mode_auto_hastag_only() {
976 let storage = setup_ranking_storage().await;
977
978 let filter = SearchCommandsFilter {
981 search_term: Some("#".to_string()),
982 search_mode: SearchMode::Auto,
983 ..Default::default()
984 };
985
986 let res = storage
987 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
988 .await;
989 assert!(res.is_ok(), "Expected a success response, got: {res:?}")
990 }
991
992 #[tokio::test]
993 async fn test_find_commands_including_workspace() {
994 let storage = setup_ranking_storage().await;
995
996 storage.setup_workspace_storage().await.unwrap();
997 let commands_to_import = vec![
998 ImportExportItem::Command(Command {
999 id: Uuid::now_v7(),
1000 cmd: "cmd1".to_string(),
1001 ..Default::default()
1002 }),
1003 ImportExportItem::Command(Command {
1004 id: Uuid::now_v7(),
1005 cmd: "cmd2".to_string(),
1006 ..Default::default()
1007 }),
1008 ];
1009 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1010 storage.import_items(stream, false, true).await.unwrap();
1011
1012 let (commands, _) = storage
1013 .find_commands(
1014 SearchCommandsFilter::default(),
1015 "/some/path",
1016 &SearchCommandTuning::default(),
1017 )
1018 .await
1019 .unwrap();
1020 assert_eq!(commands.len(), 12, "Expected 12 commands including workspace");
1021 }
1022
1023 #[tokio::test]
1024 async fn test_find_commands_with_text_including_workspace() {
1025 let storage = setup_ranking_storage().await;
1026
1027 storage.setup_workspace_storage().await.unwrap();
1028 let commands_to_import = vec![ImportExportItem::Command(Command {
1029 id: Uuid::now_v7(),
1030 cmd: "git checkout -b feature/{{name:kebab}}".to_string(),
1031 ..Default::default()
1032 })];
1033 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1034 storage.import_items(stream, false, true).await.unwrap();
1035
1036 let filter = SearchCommandsFilter {
1037 search_term: Some("git".to_string()),
1038 ..Default::default()
1039 };
1040
1041 let (commands, _) = storage
1042 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1043 .await
1044 .unwrap();
1045 assert_eq!(commands.len(), 6, "Expected 6 git commands including workspace");
1046 assert!(
1047 commands
1048 .iter()
1049 .any(|c| c.cmd == "git checkout -b feature/{{name:kebab}}")
1050 );
1051 }
1052
1053 #[tokio::test]
1054 async fn test_delete_tldr_commands() {
1055 let storage = SqliteStorage::new_in_memory().await.unwrap();
1056
1057 let tldr_cmd1 = Command {
1059 id: Uuid::now_v7(),
1060 category: "git".to_string(),
1061 source: SOURCE_TLDR.to_string(),
1062 cmd: "git status".to_string(),
1063 ..Default::default()
1064 };
1065 let tldr_cmd2 = Command {
1066 id: Uuid::now_v7(),
1067 category: "docker".to_string(),
1068 source: SOURCE_TLDR.to_string(),
1069 cmd: "docker ps".to_string(),
1070 ..Default::default()
1071 };
1072 let user_cmd = Command {
1073 id: Uuid::now_v7(),
1074 category: "git".to_string(),
1075 source: SOURCE_USER.to_string(),
1076 cmd: "git log".to_string(),
1077 ..Default::default()
1078 };
1079
1080 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1081 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1082 storage.insert_command(user_cmd.clone()).await.unwrap();
1083
1084 let removed = storage.delete_tldr_commands(None).await.unwrap();
1086 assert_eq!(removed, 2, "Should remove both tldr commands");
1087
1088 let (remaining, _) = storage
1089 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1090 .await
1091 .unwrap();
1092 assert_eq!(remaining.len(), 1, "Only user command should remain");
1093 assert_eq!(remaining[0].cmd, user_cmd.cmd);
1094
1095 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1097 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1098
1099 let removed_git = storage.delete_tldr_commands(Some("git".to_string())).await.unwrap();
1101 assert_eq!(removed_git, 1, "Should remove one tldr command in 'git' category");
1102
1103 let (remaining, _) = storage
1104 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1105 .await
1106 .unwrap();
1107 let remaining_cmds: Vec<_> = remaining.iter().map(|c| &c.cmd).collect();
1108 assert!(remaining_cmds.contains(&&tldr_cmd2.cmd));
1109 assert!(remaining_cmds.contains(&&user_cmd.cmd));
1110 assert!(!remaining_cmds.contains(&&tldr_cmd1.cmd));
1111 }
1112
1113 #[tokio::test]
1114 async fn test_insert_command() {
1115 let storage = SqliteStorage::new_in_memory().await.unwrap();
1116
1117 let mut cmd = Command {
1118 id: Uuid::now_v7(),
1119 category: "test".to_string(),
1120 cmd: "test_cmd".to_string(),
1121 description: Some("test desc".to_string()),
1122 tags: Some(vec!["tag1".to_string()]),
1123 ..Default::default()
1124 };
1125
1126 let mut inserted = storage.insert_command(cmd.clone()).await.unwrap();
1127 assert_eq!(inserted.cmd, cmd.cmd);
1128
1129 inserted.cmd = "other_cmd".to_string();
1131 match storage.insert_command(inserted).await {
1132 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1133 _ => panic!("Expected CommandAlreadyExists error on duplicate id"),
1134 }
1135
1136 cmd.id = Uuid::now_v7();
1138 match storage.insert_command(cmd).await {
1139 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1140 _ => panic!("Expected CommandAlreadyExists error on duplicate cmd"),
1141 }
1142 }
1143
1144 #[tokio::test]
1145 async fn test_update_command() {
1146 let storage = SqliteStorage::new_in_memory().await.unwrap();
1147
1148 let cmd = Command {
1149 id: Uuid::now_v7(),
1150 cmd: "original".to_string(),
1151 description: Some("desc".to_string()),
1152 ..Default::default()
1153 };
1154
1155 storage.insert_command(cmd.clone()).await.unwrap();
1156
1157 let mut updated = cmd.clone();
1158 updated.cmd = "updated".to_string();
1159 updated.description = Some("new desc".to_string());
1160
1161 let result = storage.update_command(updated.clone()).await.unwrap();
1162 assert_eq!(result.cmd, "updated");
1163 assert_eq!(result.description, Some("new desc".to_string()));
1164
1165 let mut non_existent = cmd;
1167 non_existent.id = Uuid::now_v7();
1168 match storage.update_command(non_existent).await {
1169 Err(_) => (),
1170 _ => panic!("Expected error when updating non-existent command"),
1171 }
1172
1173 let another_cmd = Command {
1175 id: Uuid::now_v7(),
1176 cmd: "another".to_string(),
1177 ..Default::default()
1178 };
1179 let mut result = storage.insert_command(another_cmd.clone()).await.unwrap();
1180 result.cmd = "updated".to_string();
1181 match storage.update_command(result).await {
1182 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1183 _ => panic!("Expected CommandAlreadyExists error when updating to existing cmd"),
1184 }
1185 }
1186
1187 #[tokio::test]
1188 async fn test_increment_command_usage() {
1189 let storage = SqliteStorage::new_in_memory().await.unwrap();
1190
1191 let command = storage
1193 .setup_command(
1194 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
1195 [("/some/path", 100)],
1196 )
1197 .await;
1198
1199 let count = storage.increment_command_usage(command.id, "/path").await.unwrap();
1201 assert_eq!(count, 1);
1202
1203 let count = storage.increment_command_usage(command.id, "/some/path").await.unwrap();
1205 assert_eq!(count, 101);
1206 }
1207
1208 #[tokio::test]
1209 async fn test_delete_command() {
1210 let storage = SqliteStorage::new_in_memory().await.unwrap();
1211
1212 let cmd = Command {
1213 id: Uuid::now_v7(),
1214 cmd: "to_delete".to_string(),
1215 ..Default::default()
1216 };
1217
1218 let cmd = storage.insert_command(cmd).await.unwrap();
1219 let res = storage.delete_command(cmd.id).await;
1220 assert!(res.is_ok());
1221
1222 match storage.delete_command(cmd.id).await {
1224 Err(_) => (),
1225 _ => panic!("Expected error when deleting non-existent command"),
1226 }
1227 }
1228
1229 async fn setup_ranking_storage() -> SqliteStorage {
1231 let storage = SqliteStorage::new_in_memory().await.unwrap();
1232 storage
1233 .setup_command(
1234 Command::new(
1235 CATEGORY_USER,
1236 SOURCE_USER,
1237 "kubectl get pod -n monitoring my-specific-pod-12345",
1238 )
1239 .with_description(Some(
1240 "Get a very specific pod by its full name in the monitoring namespace".to_string(),
1241 ))
1242 .with_tags(Some(vec!["#k8s".to_string(), "#pod".to_string()])),
1243 [("/other/path", 1)],
1244 )
1245 .await;
1246 storage
1247 .setup_command(
1248 Command::new(CATEGORY_USER, SOURCE_USER, "git status")
1249 .with_description(Some("Check the status of the git repository".to_string()))
1250 .with_tags(Some(vec!["#git".to_string()])),
1251 [(PROJ_A_PATH, 50), (PROJ_B_PATH, 50), (UNRELATED_PATH, 100)],
1252 )
1253 .await;
1254 storage
1255 .setup_command(
1256 Command::new(CATEGORY_USER, SOURCE_USER, "npm run build:prod")
1257 .with_description(Some("Build the project for production".to_string()))
1258 .with_tags(Some(vec!["#npm".to_string(), "#build".to_string()])),
1259 [(PROJ_A_API_PATH, 25)],
1260 )
1261 .await;
1262 storage
1263 .setup_command(
1264 Command::new(CATEGORY_USER, SOURCE_USER, "container-image-build.sh")
1265 .with_description(Some("A generic script to build a container image".to_string()))
1266 .with_tags(Some(vec!["#docker".to_string(), "#build".to_string()])),
1267 [(UNRELATED_PATH, 35)],
1268 )
1269 .await;
1270 storage
1271 .setup_command(
1272 Command::new(CATEGORY_USER, SOURCE_USER, "git commit -m '{{message}}'")
1273 .with_description(Some("Commit with a message".to_string()))
1274 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1275 [(PROJ_A_PATH, 10), (PROJ_B_PATH, 10)],
1276 )
1277 .await;
1278 storage
1279 .setup_command(
1280 Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")
1281 .with_alias(Some("gco".to_string()))
1282 .with_description(Some("Checkout the main branch".to_string()))
1283 .with_tags(Some(vec!["#git".to_string()])),
1284 [(PROJ_A_PATH, 30), (PROJ_B_PATH, 30)],
1285 )
1286 .await;
1287 storage
1288 .setup_command(
1289 Command::new("git", SOURCE_TLDR, "git commit -m")
1290 .with_alias(Some("gc".to_string()))
1291 .with_description(Some("Commit changes".to_string()))
1292 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1293 [(PROJ_A_PATH, 15)],
1294 )
1295 .await;
1296 storage
1297 .setup_command(
1298 Command::new("docker", SOURCE_TLDR, "docker ps -a")
1299 .with_description(Some("List all containers".to_string()))
1300 .with_tags(Some(vec!["#docker".to_string(), "#list".to_string()])),
1301 [(PROJ_A_PATH, 5), (PROJ_B_PATH, 5)],
1302 )
1303 .await;
1304 storage
1305 .setup_command(
1306 Command::new("git", SOURCE_TLDR, "git push")
1307 .with_description(Some("Push changes".to_string()))
1308 .with_tags(Some(vec!["#git".to_string(), "#push".to_string()])),
1309 [(PROJ_A_PATH, 20), (PROJ_B_PATH, 20)],
1310 )
1311 .await;
1312 storage
1313 .setup_command(
1314 Command::new(CATEGORY_USER, SOURCE_IMPORT, "ls -lha")
1315 .with_description(Some("List files".to_string()))
1316 .with_tags(Some(vec!["#unix".to_string(), "#list".to_string()])),
1317 [(PROJ_A_PATH, 100), (PROJ_B_PATH, 100), (UNRELATED_PATH, 100)],
1318 )
1319 .await;
1320
1321 storage
1322 }
1323
1324 impl SqliteStorage {
1325 async fn check_sqlite_version(&self) {
1327 let version: String = self
1328 .client
1329 .conn_mut(|conn| {
1330 conn.query_row("SELECT sqlite_version()", [], |row| row.get(0))
1331 .map_err(Into::into)
1332 })
1333 .await
1334 .unwrap();
1335 println!("Running with SQLite version: {version}");
1336 }
1337
1338 async fn setup_command(
1341 &self,
1342 command: Command,
1343 usage: impl IntoIterator<Item = (&str, i32)> + Send + 'static,
1344 ) -> Command {
1345 let command = self.insert_command(command).await.unwrap();
1346 self.client
1347 .conn_mut(move |conn| {
1348 for (path, usage_count) in usage {
1349 conn.execute(
1350 r#"
1351 INSERT INTO command_usage (command_id, path, usage_count)
1352 VALUES (?1, ?2, ?3)
1353 ON CONFLICT(command_id, path) DO UPDATE SET
1354 usage_count = excluded.usage_count"#,
1355 (&command.id, path, usage_count),
1356 )?;
1357 }
1358 Ok(command)
1359 })
1360 .await
1361 .unwrap()
1362 }
1363 }
1364}