1use std::{cmp::Ordering, sync::atomic::Ordering as AtomicOrdering};
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{Row, fallible_iterator::FallibleIterator, ffi, types::Type};
5use sea_query::SqliteQueryBuilder;
6use sea_query_rusqlite::RusqliteBinder;
7use tracing::instrument;
8use uuid::Uuid;
9
10use super::{SqliteStorage, queries::*};
11use crate::{
12 config::SearchCommandTuning,
13 errors::{Result, UserFacingError},
14 model::{Command, SOURCE_TLDR, SearchCommandsFilter},
15};
16
17impl SqliteStorage {
18 #[instrument(skip_all)]
21 pub async fn setup_workspace_storage(&self) -> Result<()> {
22 tracing::trace!("Creating workspace-specific tables");
23 self.client
24 .conn_mut(|conn| {
25 let schemas: Vec<String> = conn
27 .prepare(
28 r"SELECT sql
29 FROM sqlite_master
30 WHERE (type = 'table' AND name = 'variable_completion')
31 OR (type = 'table' AND name = 'command')
32 OR (type = 'table' AND name LIKE 'command_%fts')
33 OR (type = 'trigger' AND name LIKE 'command_%_fts' AND tbl_name = 'command')",
34 )?
35 .query_map([], |row| row.get(0))?
36 .collect::<Result<Vec<String>, _>>()?;
37
38 let tx = conn.transaction()?;
39
40 for schema in schemas {
42 let temp_schema = schema
43 .replace("variable_completion", "workspace_variable_completion")
44 .replace("command", "workspace_command")
45 .replace("CREATE TABLE ", "CREATE TEMP TABLE ")
46 .replace("CREATE VIRTUAL TABLE ", "CREATE VIRTUAL TABLE temp.")
47 .replace("CREATE TRIGGER ", "CREATE TEMP TRIGGER ");
48 tracing::trace!("Executing:\n{temp_schema}");
49 tx.execute(&temp_schema, [])?;
50 }
51
52 tx.commit()?;
53 Ok(())
54 })
55 .await?;
56
57 self.workspace_tables_loaded.store(true, AtomicOrdering::SeqCst);
58
59 Ok(())
60 }
61
62 #[instrument(skip_all)]
64 pub async fn is_empty(&self) -> Result<bool> {
65 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
66 self.client
67 .conn(move |conn| {
68 let query = if workspace_tables_loaded {
69 "SELECT NOT EXISTS (SELECT 1 FROM command UNION ALL SELECT 1 FROM workspace_command)"
70 } else {
71 "SELECT NOT EXISTS(SELECT 1 FROM command)"
72 };
73 tracing::trace!("Checking if storage is empty:\n{query}");
74 Ok(conn.query_row(query, [], |r| r.get(0))?)
75 })
76 .await
77 }
78
79 #[instrument(skip_all)]
81 pub async fn find_tags(
82 &self,
83 filter: SearchCommandsFilter,
84 tag_prefix: Option<String>,
85 tuning: &SearchCommandTuning,
86 ) -> Result<Vec<(String, u64, bool)>> {
87 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
88 let query = query_find_tags(filter, tag_prefix, tuning, workspace_tables_loaded)?;
89 if tracing::enabled!(tracing::Level::TRACE) {
90 tracing::trace!("Querying tags:\n{}", query.to_string(SqliteQueryBuilder));
91 }
92 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
93 self.client
94 .conn(move |conn| {
95 conn.prepare(&stmt)?
96 .query(&*values.as_params())?
97 .and_then(|r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))
98 .collect()
99 })
100 .await
101 }
102
103 #[instrument(skip_all)]
108 pub async fn find_commands(
109 &self,
110 filter: SearchCommandsFilter,
111 working_path: impl Into<String>,
112 tuning: &SearchCommandTuning,
113 ) -> Result<(Vec<Command>, bool)> {
114 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
115 let cleaned_filter = filter.cleaned();
116
117 let mut query_alias = None;
119 if let Some(ref term) = cleaned_filter.search_term {
120 let (query, params) = if workspace_tables_loaded {
122 (
123 format!(
124 r#"WITH commands_unified AS (
125 SELECT rowid, *, 0 AS is_workspace FROM command
126 UNION ALL
127 SELECT rowid, *, 1 AS is_workspace FROM workspace_command
128 ),
129 commands_ranked AS (
130 SELECT *, ROW_NUMBER() OVER (PARTITION BY alias ORDER BY is_workspace ASC) as _rank
131 FROM commands_unified
132 )
133 SELECT *
134 FROM commands_ranked
135 WHERE alias IS NOT NULL AND alias = ?1 AND _rank = 1
136 LIMIT {QUERY_LIMIT}"#
137 ),
138 (term.clone(),),
139 )
140 } else {
141 (
142 format!(
143 r#"SELECT c.rowid, c.*
144 FROM command c
145 WHERE c.alias IS NOT NULL AND c.alias = ?1
146 LIMIT {QUERY_LIMIT}"#
147 ),
148 (term.clone(),),
149 )
150 };
151 query_alias = Some((query, params));
152 }
153
154 let query = query_find_commands(cleaned_filter, working_path, tuning, workspace_tables_loaded)?;
156 let query_trace = if tracing::enabled!(tracing::Level::TRACE) {
157 query.to_string(SqliteQueryBuilder)
158 } else {
159 String::default()
160 };
161 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
162
163 let tuning = *tuning;
165 self.client
166 .conn(move |conn| {
167 if let Some((query_alias, a_params)) = query_alias {
169 tracing::trace!("Querying aliased commands:\n{query_alias}");
170 let rows = conn
172 .prepare(&query_alias)?
173 .query(a_params)?
174 .map(|r| Command::try_from(r))
175 .collect::<Vec<_>>()?;
176 if !rows.is_empty() {
178 return Ok((rows, true));
179 }
180 }
181 if tracing::enabled!(tracing::Level::TRACE) {
183 tracing::trace!("Querying commands:\n{query_trace}");
184 }
185 Ok((
186 rerank_query_results(
187 conn.prepare(&stmt)?
188 .query(&*values.as_params())?
189 .and_then(|r| QueryResultItem::try_from(r))
190 .collect::<Result<Vec<_>, _>>()?,
191 &tuning,
192 ),
193 false,
194 ))
195 })
196 .await
197 }
198
199 #[instrument(skip_all)]
201 pub async fn delete_tldr_commands(&self, category: Option<String>) -> Result<u64> {
202 self.client
203 .conn_mut(move |conn| {
204 let mut query = String::from("DELETE FROM command WHERE source = ?1");
205 let mut params: Vec<String> = vec![SOURCE_TLDR.to_owned()];
206 if let Some(cat) = category {
207 query.push_str(" AND category = ?2");
208 params.push(cat);
209 }
210 tracing::trace!("Deleting tldr commands:\n{query}");
211 let affected = conn.execute(&query, rusqlite::params_from_iter(params))?;
212 Ok(affected as u64)
213 })
214 .await
215 }
216
217 #[instrument(skip_all)]
221 pub async fn insert_command(&self, command: Command) -> Result<Command> {
222 self.client
223 .conn(move |conn| {
224 let query = r#"INSERT INTO command (
225 id,
226 category,
227 source,
228 alias,
229 cmd,
230 flat_cmd,
231 description,
232 flat_description,
233 tags,
234 created_at,
235 updated_at
236 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)"#;
237 tracing::trace!("Inserting a command:\n{query}");
238 let res = conn.execute(
239 query,
240 (
241 &command.id,
242 &command.category,
243 &command.source,
244 &command.alias,
245 &command.cmd,
246 &command.flat_cmd,
247 &command.description,
248 &command.flat_description,
249 serde_json::to_value(&command.tags)?,
250 &command.created_at,
251 &command.updated_at,
252 ),
253 );
254 match res {
255 Ok(_) => Ok(command),
256 Err(err) => {
257 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
258 if code == ffi::SQLITE_CONSTRAINT_UNIQUE || code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY {
259 Err(UserFacingError::CommandAlreadyExists.into())
260 } else {
261 Err(Report::from(err).into())
262 }
263 }
264 }
265 })
266 .await
267 }
268
269 #[instrument(skip_all)]
273 pub async fn update_command(&self, command: Command) -> Result<Command> {
274 self.client
275 .conn(move |conn| {
276 let query = r#"UPDATE command SET
277 category = ?2,
278 source = ?3,
279 alias = ?4,
280 cmd = ?5,
281 flat_cmd = ?6,
282 description = ?7,
283 flat_description = ?8,
284 tags = ?9,
285 created_at = ?10,
286 updated_at = ?11
287 WHERE id = ?1"#;
288 tracing::trace!("Updating a command:\n{query}");
289 let res = conn.execute(
290 query,
291 (
292 &command.id,
293 &command.category,
294 &command.source,
295 &command.alias,
296 &command.cmd,
297 &command.flat_cmd,
298 &command.description,
299 &command.flat_description,
300 serde_json::to_value(&command.tags)?,
301 &command.created_at,
302 &command.updated_at,
303 ),
304 );
305 match res {
306 Ok(0) => Err(eyre!("Command not found: {}", command.id).into()),
307 Ok(_) => Ok(command),
308 Err(err) => {
309 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
310 if code == ffi::SQLITE_CONSTRAINT_UNIQUE {
311 Err(UserFacingError::CommandAlreadyExists.into())
312 } else {
313 Err(Report::from(err).into())
314 }
315 }
316 }
317 })
318 .await
319 }
320
321 #[instrument(skip_all)]
323 pub async fn increment_command_usage(
324 &self,
325 command_id: Uuid,
326 path: impl AsRef<str> + Send + 'static,
327 ) -> Result<i32> {
328 self.client
329 .conn_mut(move |conn| {
330 let query = r#"
331 INSERT INTO command_usage (command_id, path, usage_count)
332 VALUES (?1, ?2, 1)
333 ON CONFLICT(command_id, path) DO UPDATE SET
334 usage_count = usage_count + 1
335 RETURNING usage_count;"#;
336 tracing::trace!("Incrementing command usage:\n{query}");
337 Ok(conn.query_row(query, (&command_id, &path.as_ref()), |r| r.get(0))?)
338 })
339 .await
340 }
341
342 #[instrument(skip_all)]
346 pub async fn delete_command(&self, command_id: Uuid) -> Result<()> {
347 self.client
348 .conn(move |conn| {
349 let query = "DELETE FROM command WHERE id = ?1";
350 tracing::trace!("Deleting command:\n{query}");
351 let res = conn.execute(query, (&command_id,));
352 match res {
353 Ok(0) => Err(eyre!("Command not found: {command_id}").into()),
354 Ok(_) => Ok(()),
355 Err(err) => Err(Report::from(err).into()),
356 }
357 })
358 .await
359 }
360}
361
362fn rerank_query_results(items: Vec<QueryResultItem>, tuning: &SearchCommandTuning) -> Vec<Command> {
371 if items.is_empty() {
373 return Vec::new();
374 }
375 if items.len() == 1 {
376 return items.into_iter().map(|item| item.command).collect();
377 }
378
379 let (template_matches, mut other_items): (Vec<_>, Vec<_>) = items
382 .into_iter()
383 .partition(|item| item.text_score >= TEMPLATE_MATCH_RANK);
384 if !template_matches.is_empty() {
385 tracing::trace!("Found {} template matches", template_matches.len());
386 }
387
388 let mut final_commands: Vec<Command> = template_matches.into_iter().map(|item| item.command).collect();
390
391 if other_items.len() <= 1 {
393 final_commands.extend(other_items.into_iter().map(|item| item.command));
394 return final_commands;
395 }
396
397 let mut min_text = 0f64;
400 let mut min_path = 0f64;
401 let mut min_usage = 0f64;
402 let mut max_text = f64::NEG_INFINITY;
403 let mut max_path = f64::NEG_INFINITY;
404 let mut max_usage = f64::NEG_INFINITY;
405 for item in &other_items {
406 min_text = min_text.min(item.text_score);
407 max_text = max_text.max(item.text_score);
408 min_path = min_path.min(item.path_score);
409 max_path = max_path.max(item.path_score);
410 min_usage = min_usage.min(item.usage_score);
411 max_usage = max_usage.max(item.usage_score);
412 }
413
414 let range_text = (max_text > min_text).then_some(max_text - min_text);
416 let range_path = (max_path > min_path).then_some(max_path - min_path);
417 let range_usage = (max_usage > min_usage).then_some(max_usage - min_usage);
418
419 other_items.sort_by(|a, b| {
421 match b.is_workspace_command.cmp(&a.is_workspace_command) {
423 Ordering::Equal => {
424 let calculate_score = |item: &QueryResultItem| -> f64 {
426 let norm_text = range_text.map_or(0.5, |range| (item.text_score - min_text) / range);
429 let norm_path = range_path.map_or(0.5, |range| (item.path_score - min_path) / range);
430 let norm_usage = range_usage.map_or(0.5, |range| (item.usage_score - min_usage) / range);
431
432 (norm_text * tuning.text.points as f64)
434 + (norm_path * tuning.path.points as f64)
435 + (norm_usage * tuning.usage.points as f64)
436 };
437
438 let final_score_a = calculate_score(a);
439 let final_score_b = calculate_score(b);
440
441 final_score_b.partial_cmp(&final_score_a).unwrap_or(Ordering::Equal)
443 }
444 other => other,
446 }
447 });
448
449 final_commands.extend(other_items.into_iter().map(|item| item.command));
451 final_commands
452}
453
454impl<'a> TryFrom<&'a Row<'a>> for Command {
455 type Error = rusqlite::Error;
456
457 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
458 Ok(Self {
459 id: row.get(1)?,
461 category: row.get(2)?,
462 source: row.get(3)?,
463 alias: row.get(4)?,
464 cmd: row.get(5)?,
465 flat_cmd: row.get(6)?,
466 description: row.get(7)?,
467 flat_description: row.get(8)?,
468 tags: serde_json::from_value(row.get::<_, serde_json::Value>(9)?)
469 .map_err(|e| rusqlite::Error::FromSqlConversionFailure(9, Type::Text, Box::new(e)))?,
470 created_at: row.get(10)?,
471 updated_at: row.get(11)?,
472 })
473 }
474}
475
/// One row of the main command search query: the command plus the values that
/// `rerank_query_results` uses to order it.
struct QueryResultItem {
    /// The command, parsed from the row by `Command::try_from`.
    command: Command,
    /// Result column 12. `rerank_query_results` sorts rows with `true` ahead of `false`.
    /// NOTE(review): presumably set for rows from the workspace tables; confirm in `query_find_commands`.
    is_workspace_command: bool,
    /// Result column 13; the usage component of the weighted score.
    usage_score: f64,
    /// Result column 14; the path component of the weighted score.
    path_score: f64,
    /// Result column 15; the text component of the weighted score. Rows at or above
    /// `TEMPLATE_MATCH_RANK` are treated as template matches.
    text_score: f64,
}
489
490impl<'a> TryFrom<&'a Row<'a>> for QueryResultItem {
491 type Error = rusqlite::Error;
492
493 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
494 Ok(Self {
495 command: Command::try_from(row)?,
496 is_workspace_command: row.get(12)?,
497 usage_score: row.get(13)?,
498 path_score: row.get(14)?,
499 text_score: row.get(15)?,
500 })
501 }
502}
503
504#[cfg(test)]
505mod tests {
506 use pretty_assertions::assert_eq;
507 use strum::IntoEnumIterator;
508 use tokio_stream::iter;
509 use uuid::Uuid;
510
511 use super::*;
512 use crate::{
513 errors::AppError,
514 model::{CATEGORY_USER, ImportExportItem, SOURCE_IMPORT, SOURCE_USER, SearchMode},
515 };
516
517 const PROJ_A_PATH: &str = "/home/user/project-a";
518 const PROJ_A_API_PATH: &str = "/home/user/project-a/api";
519 const PROJ_B_PATH: &str = "/home/user/project-b";
520 const UNRELATED_PATH: &str = "/var/log";
521
522 #[tokio::test]
523 async fn test_setup_workspace_storage() {
524 let storage = SqliteStorage::new_in_memory().await.unwrap();
525 storage.check_sqlite_version().await;
526 let res = storage.setup_workspace_storage().await;
527 assert!(res.is_ok(), "Expected workspace storage setup to succeed: {res:?}");
528 }
529
530 #[tokio::test]
531 async fn test_is_empty() {
532 let storage = SqliteStorage::new_in_memory().await.unwrap();
533 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
534
535 let cmd = Command {
536 id: Uuid::now_v7(),
537 cmd: "test_cmd".to_string(),
538 ..Default::default()
539 };
540 storage.insert_command(cmd).await.unwrap();
541
542 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
543 }
544
545 #[tokio::test]
546 async fn test_is_empty_with_workspace() {
547 let storage = SqliteStorage::new_in_memory().await.unwrap();
548 storage.setup_workspace_storage().await.unwrap();
549 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
550
551 let cmd = Command {
552 id: Uuid::now_v7(),
553 cmd: "test_cmd".to_string(),
554 ..Default::default()
555 };
556 storage.insert_command(cmd).await.unwrap();
557
558 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
559 }
560
561 #[tokio::test]
562 async fn test_find_tags_no_filters() -> Result<()> {
563 let storage = setup_ranking_storage().await;
564
565 let result = storage
566 .find_tags(SearchCommandsFilter::default(), None, &SearchCommandTuning::default())
567 .await?;
568
569 let expected = vec![
570 ("#git".to_string(), 5, false),
571 ("#build".to_string(), 2, false),
572 ("#commit".to_string(), 2, false),
573 ("#docker".to_string(), 2, false),
574 ("#list".to_string(), 2, false),
575 ("#k8s".to_string(), 1, false),
576 ("#npm".to_string(), 1, false),
577 ("#pod".to_string(), 1, false),
578 ("#push".to_string(), 1, false),
579 ("#unix".to_string(), 1, false),
580 ];
581
582 assert_eq!(result.len(), 10, "Expected 10 unique tags");
583 assert_eq!(result, expected, "Tags list or order mismatch");
584
585 Ok(())
586 }
587
588 #[tokio::test]
589 async fn test_find_tags_filter_by_tags_only() -> Result<()> {
590 let storage = setup_ranking_storage().await;
591
592 let filter1 = SearchCommandsFilter {
593 tags: Some(vec!["#git".to_string()]),
594 ..Default::default()
595 };
596 let result1 = storage
597 .find_tags(filter1, None, &SearchCommandTuning::default())
598 .await?;
599 let expected1 = vec![("#commit".to_string(), 2, false), ("#push".to_string(), 1, false)];
600 assert_eq!(result1.len(), 2,);
601 assert_eq!(result1, expected1);
602
603 let filter2 = SearchCommandsFilter {
604 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
605 ..Default::default()
606 };
607 let result2 = storage
608 .find_tags(filter2, None, &SearchCommandTuning::default())
609 .await?;
610 assert!(result2.is_empty());
611
612 let filter3 = SearchCommandsFilter {
613 tags: Some(vec!["#list".to_string()]),
614 ..Default::default()
615 };
616 let result3 = storage
617 .find_tags(filter3, None, &SearchCommandTuning::default())
618 .await?;
619 let expected3 = vec![("#docker".to_string(), 1, false), ("#unix".to_string(), 1, false)];
620 assert_eq!(result3.len(), 2);
621 assert_eq!(result3, expected3);
622
623 Ok(())
624 }
625
626 #[tokio::test]
627 async fn test_find_tags_filter_by_prefix_only() -> Result<()> {
628 let storage = setup_ranking_storage().await;
629
630 let result = storage
631 .find_tags(
632 SearchCommandsFilter::default(),
633 Some("#comm".to_string()),
634 &SearchCommandTuning::default(),
635 )
636 .await?;
637 let expected = vec![("#commit".to_string(), 2, false)];
638 assert_eq!(result.len(), 1);
639 assert_eq!(result, expected);
640
641 Ok(())
642 }
643
644 #[tokio::test]
645 async fn test_find_tags_filter_by_tags_and_prefix() -> Result<()> {
646 let storage = setup_ranking_storage().await;
647
648 let filter1 = SearchCommandsFilter {
649 tags: Some(vec!["#git".to_string()]),
650 ..Default::default()
651 };
652 let result1 = storage
653 .find_tags(filter1, Some("#comm".to_string()), &SearchCommandTuning::default())
654 .await?;
655 let expected1 = vec![("#commit".to_string(), 2, false)];
656 assert_eq!(result1.len(), 1);
657 assert_eq!(result1, expected1);
658
659 let filter2 = SearchCommandsFilter {
660 tags: Some(vec!["#git".to_string()]),
661 ..Default::default()
662 };
663 let result2 = storage
664 .find_tags(filter2, Some("#push".to_string()), &SearchCommandTuning::default())
665 .await?;
666 let expected2 = vec![("#push".to_string(), 1, true)];
667 assert_eq!(result2.len(), 1);
668 assert_eq!(result2, expected2);
669
670 Ok(())
671 }
672
673 #[tokio::test]
674 async fn test_find_commands_no_filter() {
675 let storage = setup_ranking_storage().await;
676 let filter = SearchCommandsFilter::default();
677 let (commands, _) = storage
678 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
679 .await
680 .unwrap();
681 assert_eq!(commands.len(), 10, "Expected all sample commands");
682 }
683
684 #[tokio::test]
685 async fn test_find_commands_filter_by_category() {
686 let storage = setup_ranking_storage().await;
687 let filter = SearchCommandsFilter {
688 category: Some(vec!["git".to_string()]),
689 ..Default::default()
690 };
691 let (commands, _) = storage
692 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
693 .await
694 .unwrap();
695 assert_eq!(commands.len(), 2);
696 assert!(commands.iter().all(|c| c.category == "git"));
697
698 let filter_no_match = SearchCommandsFilter {
699 category: Some(vec!["nonexistent".to_string()]),
700 ..Default::default()
701 };
702 let (commands_no_match, _) = storage
703 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
704 .await
705 .unwrap();
706 assert!(commands_no_match.is_empty());
707 }
708
709 #[tokio::test]
710 async fn test_find_commands_filter_by_source() {
711 let storage = setup_ranking_storage().await;
712 let filter = SearchCommandsFilter {
713 source: Some(SOURCE_TLDR.to_string()),
714 ..Default::default()
715 };
716 let (commands, _) = storage
717 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
718 .await
719 .unwrap();
720 assert_eq!(commands.len(), 3);
721 assert!(commands.iter().all(|c| c.source == SOURCE_TLDR));
722 }
723
724 #[tokio::test]
725 async fn test_find_commands_filter_by_tags() {
726 let storage = setup_ranking_storage().await;
727 let filter_single_tag = SearchCommandsFilter {
728 tags: Some(vec!["#git".to_string()]),
729 ..Default::default()
730 };
731 let (commands_single_tag, _) = storage
732 .find_commands(filter_single_tag, "/some/path", &SearchCommandTuning::default())
733 .await
734 .unwrap();
735 assert_eq!(commands_single_tag.len(), 5);
736
737 let filter_multiple_tags = SearchCommandsFilter {
738 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
739 ..Default::default()
740 };
741 let (commands_multiple_tags, _) = storage
742 .find_commands(filter_multiple_tags, "/some/path", &SearchCommandTuning::default())
743 .await
744 .unwrap();
745 assert_eq!(commands_multiple_tags.len(), 1);
746
747 let filter_empty_tags = SearchCommandsFilter {
748 tags: Some(vec![]),
749 ..Default::default()
750 };
751 let (commands_empty_tags, _) = storage
752 .find_commands(filter_empty_tags, "/some/path", &SearchCommandTuning::default())
753 .await
754 .unwrap();
755 assert_eq!(commands_empty_tags.len(), 10);
756 }
757
758 #[tokio::test]
759 async fn test_find_commands_alias_precedence() {
760 let storage = setup_ranking_storage().await;
761 storage
762 .setup_command(
763 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
764 [("/some/path", 100)],
765 )
766 .await;
767
768 for mode in SearchMode::iter() {
769 let filter = SearchCommandsFilter {
770 search_term: Some("gc".to_string()),
771 search_mode: mode,
772 ..Default::default()
773 };
774 let (commands, alias_match) = storage
775 .find_commands(filter, "", &SearchCommandTuning::default())
776 .await
777 .unwrap();
778 assert!(alias_match, "Expected alias match for mode {mode:?}");
779 assert_eq!(commands.len(), 1, "Expected only alias match for mode {mode:?}");
780 assert_eq!(
781 commands[0].cmd, "git commit -m",
782 "Expected correct alias command for mode {mode:?}"
783 );
784 }
785 }
786
787 #[tokio::test]
788 async fn test_find_commands_search_mode_exact() {
789 let storage = setup_ranking_storage().await;
790 storage.setup_workspace_storage().await.unwrap();
791 let filter_token_match = SearchCommandsFilter {
792 search_term: Some("commit".to_string()),
793 search_mode: SearchMode::Exact,
794 ..Default::default()
795 };
796 let (commands_token_match, _) = storage
797 .find_commands(filter_token_match, "/some/path", &SearchCommandTuning::default())
798 .await
799 .unwrap();
800 assert_eq!(commands_token_match.len(), 2);
801 assert_eq!(commands_token_match[0].cmd, "git commit -m");
802 assert_eq!(commands_token_match[1].cmd, "git commit -m '{{message}}'");
803
804 let filter_no_match = SearchCommandsFilter {
805 search_term: Some("nonexistentterm".to_string()),
806 search_mode: SearchMode::Exact,
807 ..Default::default()
808 };
809 let (commands_no_match, _) = storage
810 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
811 .await
812 .unwrap();
813 assert!(commands_no_match.is_empty());
814 }
815
816 #[tokio::test]
817 async fn test_find_commands_search_mode_relaxed() {
818 let storage = setup_ranking_storage().await;
819 storage.setup_workspace_storage().await.unwrap();
820 let filter = SearchCommandsFilter {
821 search_term: Some("docker list".to_string()),
822 search_mode: SearchMode::Relaxed,
823 ..Default::default()
824 };
825 let (commands, _) = storage
826 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
827 .await
828 .unwrap();
829 assert_eq!(commands.len(), 2);
830 assert!(commands.iter().any(|c| c.cmd == "docker ps -a"));
831 assert!(commands.iter().any(|c| c.cmd == "ls -lha"));
832 }
833
834 #[tokio::test]
835 async fn test_find_commands_search_mode_regex() {
836 let storage = setup_ranking_storage().await;
837 storage.setup_workspace_storage().await.unwrap();
838 let filter = SearchCommandsFilter {
839 search_term: Some(r"git\s.*it".to_string()),
840 search_mode: SearchMode::Regex,
841 ..Default::default()
842 };
843 let (commands, _) = storage
844 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
845 .await
846 .unwrap();
847 assert_eq!(commands.len(), 2);
848 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
849 assert_eq!(commands[1].cmd, "git commit -m");
850
851 let filter_invalid = SearchCommandsFilter {
852 search_term: Some("[[invalid_regex".to_string()),
853 search_mode: SearchMode::Regex,
854 ..Default::default()
855 };
856 assert!(matches!(
857 storage
858 .find_commands(filter_invalid, "/some/path", &SearchCommandTuning::default())
859 .await,
860 Err(AppError::UserFacing(UserFacingError::InvalidRegex))
861 ));
862 }
863
864 #[tokio::test]
865 async fn test_find_commands_search_mode_fuzzy() {
866 let storage = setup_ranking_storage().await;
867 storage.setup_workspace_storage().await.unwrap();
868 let filter = SearchCommandsFilter {
869 search_term: Some("gtcomit".to_string()),
870 search_mode: SearchMode::Fuzzy,
871 ..Default::default()
872 };
873 let (commands, _) = storage
874 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
875 .await
876 .unwrap();
877 assert_eq!(commands.len(), 2);
878 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
879 assert_eq!(commands[1].cmd, "git commit -m");
880
881 let filter_empty_fuzzy = SearchCommandsFilter {
882 search_term: Some("'' | ^".to_string()),
883 search_mode: SearchMode::Fuzzy,
884 ..Default::default()
885 };
886 assert!(matches!(
887 storage
888 .find_commands(filter_empty_fuzzy, "/some/path", &SearchCommandTuning::default())
889 .await,
890 Err(AppError::UserFacing(UserFacingError::InvalidFuzzy))
891 ));
892 }
893
894 #[tokio::test]
895 async fn test_find_commands_search_mode_auto() {
896 let storage = setup_ranking_storage().await;
897 let default_tuning = SearchCommandTuning::default();
898
899 let run_search = |term: &'static str, path: &'static str| {
901 let storage = storage.clone();
902 async move {
903 let filter = SearchCommandsFilter {
904 search_term: Some(term.to_string()),
905 search_mode: SearchMode::Auto,
906 ..Default::default()
907 };
908 storage.find_commands(filter, path, &default_tuning).await.unwrap()
909 }
910 };
911
912 let (commands, _) = run_search("list containers", UNRELATED_PATH).await;
914 assert!(!commands.is_empty(), "Expected results for 'list containers'");
915 assert_eq!(
916 commands[0].cmd, "docker ps -a",
917 "Expected 'docker ps -a' to be the top result for 'list containers'"
918 );
919
920 let (commands, _) = run_search("git commit", PROJ_A_PATH).await;
922 assert!(commands.len() >= 2, "Expected at least two results for 'git commit'");
923 assert_eq!(
924 commands[0].cmd, "git commit -m",
925 "Expected 'git commit -m' to be the top result for 'git commit' due to usage"
926 );
927 assert_eq!(
928 commands[1].cmd, "git commit -m '{{message}}'",
929 "Expected template command to be second for 'git commit'"
930 );
931
932 let (commands, _) = run_search("git commit -m 'my new feature'", PROJ_A_PATH).await;
934 assert!(!commands.is_empty(), "Expected results for template match");
935 assert_eq!(
936 commands[0].cmd, "git commit -m '{{message}}'",
937 "Expected template command to be the top result for a matching search term"
938 );
939
940 let (commands, _) = run_search("build", PROJ_A_API_PATH).await;
942 assert!(!commands.is_empty(), "Expected results for 'build'");
943 assert_eq!(
944 commands[0].cmd, "npm run build:prod",
945 "Expected 'npm run build:prod' to be top result for 'build' in its project path"
946 );
947
948 let (commands, _) = run_search("gt sta", PROJ_A_PATH).await;
950 assert!(!commands.is_empty(), "Expected results for fuzzy search 'gt sta'");
951 assert_eq!(
952 commands[0].cmd, "git status",
953 "Expected 'git status' as top result for fuzzy search 'gt sta'"
954 );
955
956 let (commands, _) = run_search("get pod monitoring", UNRELATED_PATH).await;
958 assert!(!commands.is_empty(), "Expected results for 'get pod monitoring'");
959 assert_eq!(
960 commands[0].cmd, "kubectl get pod -n monitoring my-specific-pod-12345",
961 "Expected specific 'kubectl' command to be found"
962 );
963
964 let (commands, _) = run_search("status", PROJ_A_API_PATH).await;
966 assert!(!commands.is_empty(), "Expected results for 'status'");
967 assert_eq!(
968 commands[0].cmd, "git status",
969 "Expected 'git status' to be top due to high usage in parent path"
970 );
971 }
972
973 #[tokio::test]
974 async fn test_find_commands_search_mode_auto_hastag_only() {
975 let storage = setup_ranking_storage().await;
976
977 let filter = SearchCommandsFilter {
980 search_term: Some("#".to_string()),
981 search_mode: SearchMode::Auto,
982 ..Default::default()
983 };
984
985 let res = storage
986 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
987 .await;
988 assert!(res.is_ok(), "Expected a success response, got: {res:?}")
989 }
990
991 #[tokio::test]
992 async fn test_find_commands_including_workspace() {
993 let storage = setup_ranking_storage().await;
994
995 storage.setup_workspace_storage().await.unwrap();
996 let commands_to_import = vec![
997 ImportExportItem::Command(Command {
998 id: Uuid::now_v7(),
999 cmd: "cmd1".to_string(),
1000 ..Default::default()
1001 }),
1002 ImportExportItem::Command(Command {
1003 id: Uuid::now_v7(),
1004 cmd: "cmd2".to_string(),
1005 ..Default::default()
1006 }),
1007 ];
1008 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1009 storage.import_items(stream, false, true).await.unwrap();
1010
1011 let (commands, _) = storage
1012 .find_commands(
1013 SearchCommandsFilter::default(),
1014 "/some/path",
1015 &SearchCommandTuning::default(),
1016 )
1017 .await
1018 .unwrap();
1019 assert_eq!(commands.len(), 12, "Expected 12 commands including workspace");
1020 }
1021
1022 #[tokio::test]
1023 async fn test_find_commands_with_text_including_workspace() {
1024 let storage = setup_ranking_storage().await;
1025
1026 storage.setup_workspace_storage().await.unwrap();
1027 let commands_to_import = vec![ImportExportItem::Command(Command {
1028 id: Uuid::now_v7(),
1029 cmd: "git checkout -b feature/{{name:kebab}}".to_string(),
1030 ..Default::default()
1031 })];
1032 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1033 storage.import_items(stream, false, true).await.unwrap();
1034
1035 let filter = SearchCommandsFilter {
1036 search_term: Some("git".to_string()),
1037 ..Default::default()
1038 };
1039
1040 let (commands, _) = storage
1041 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1042 .await
1043 .unwrap();
1044 assert_eq!(commands.len(), 6, "Expected 6 git commands including workspace");
1045 assert!(
1046 commands
1047 .iter()
1048 .any(|c| c.cmd == "git checkout -b feature/{{name:kebab}}")
1049 );
1050 }
1051
1052 #[tokio::test]
1053 async fn test_delete_tldr_commands() {
1054 let storage = SqliteStorage::new_in_memory().await.unwrap();
1055
1056 let tldr_cmd1 = Command {
1058 id: Uuid::now_v7(),
1059 category: "git".to_string(),
1060 source: SOURCE_TLDR.to_string(),
1061 cmd: "git status".to_string(),
1062 ..Default::default()
1063 };
1064 let tldr_cmd2 = Command {
1065 id: Uuid::now_v7(),
1066 category: "docker".to_string(),
1067 source: SOURCE_TLDR.to_string(),
1068 cmd: "docker ps".to_string(),
1069 ..Default::default()
1070 };
1071 let user_cmd = Command {
1072 id: Uuid::now_v7(),
1073 category: "git".to_string(),
1074 source: SOURCE_USER.to_string(),
1075 cmd: "git log".to_string(),
1076 ..Default::default()
1077 };
1078
1079 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1080 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1081 storage.insert_command(user_cmd.clone()).await.unwrap();
1082
1083 let removed = storage.delete_tldr_commands(None).await.unwrap();
1085 assert_eq!(removed, 2, "Should remove both tldr commands");
1086
1087 let (remaining, _) = storage
1088 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1089 .await
1090 .unwrap();
1091 assert_eq!(remaining.len(), 1, "Only user command should remain");
1092 assert_eq!(remaining[0].cmd, user_cmd.cmd);
1093
1094 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1096 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1097
1098 let removed_git = storage.delete_tldr_commands(Some("git".to_string())).await.unwrap();
1100 assert_eq!(removed_git, 1, "Should remove one tldr command in 'git' category");
1101
1102 let (remaining, _) = storage
1103 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1104 .await
1105 .unwrap();
1106 let remaining_cmds: Vec<_> = remaining.iter().map(|c| &c.cmd).collect();
1107 assert!(remaining_cmds.contains(&&tldr_cmd2.cmd));
1108 assert!(remaining_cmds.contains(&&user_cmd.cmd));
1109 assert!(!remaining_cmds.contains(&&tldr_cmd1.cmd));
1110 }
1111
1112 #[tokio::test]
1113 async fn test_insert_command() {
1114 let storage = SqliteStorage::new_in_memory().await.unwrap();
1115
1116 let mut cmd = Command {
1117 id: Uuid::now_v7(),
1118 category: "test".to_string(),
1119 cmd: "test_cmd".to_string(),
1120 description: Some("test desc".to_string()),
1121 tags: Some(vec!["tag1".to_string()]),
1122 ..Default::default()
1123 };
1124
1125 let mut inserted = storage.insert_command(cmd.clone()).await.unwrap();
1126 assert_eq!(inserted.cmd, cmd.cmd);
1127
1128 inserted.cmd = "other_cmd".to_string();
1130 match storage.insert_command(inserted).await {
1131 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1132 _ => panic!("Expected CommandAlreadyExists error on duplicate id"),
1133 }
1134
1135 cmd.id = Uuid::now_v7();
1137 match storage.insert_command(cmd).await {
1138 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1139 _ => panic!("Expected CommandAlreadyExists error on duplicate cmd"),
1140 }
1141 }
1142
1143 #[tokio::test]
1144 async fn test_update_command() {
1145 let storage = SqliteStorage::new_in_memory().await.unwrap();
1146
1147 let cmd = Command {
1148 id: Uuid::now_v7(),
1149 cmd: "original".to_string(),
1150 description: Some("desc".to_string()),
1151 ..Default::default()
1152 };
1153
1154 storage.insert_command(cmd.clone()).await.unwrap();
1155
1156 let mut updated = cmd.clone();
1157 updated.cmd = "updated".to_string();
1158 updated.description = Some("new desc".to_string());
1159
1160 let result = storage.update_command(updated.clone()).await.unwrap();
1161 assert_eq!(result.cmd, "updated");
1162 assert_eq!(result.description, Some("new desc".to_string()));
1163
1164 let mut non_existent = cmd;
1166 non_existent.id = Uuid::now_v7();
1167 match storage.update_command(non_existent).await {
1168 Err(_) => (),
1169 _ => panic!("Expected error when updating non-existent command"),
1170 }
1171
1172 let another_cmd = Command {
1174 id: Uuid::now_v7(),
1175 cmd: "another".to_string(),
1176 ..Default::default()
1177 };
1178 let mut result = storage.insert_command(another_cmd.clone()).await.unwrap();
1179 result.cmd = "updated".to_string();
1180 match storage.update_command(result).await {
1181 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1182 _ => panic!("Expected CommandAlreadyExists error when updating to existing cmd"),
1183 }
1184 }
1185
1186 #[tokio::test]
1187 async fn test_increment_command_usage() {
1188 let storage = SqliteStorage::new_in_memory().await.unwrap();
1189
1190 let command = storage
1192 .setup_command(
1193 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
1194 [("/some/path", 100)],
1195 )
1196 .await;
1197
1198 let count = storage.increment_command_usage(command.id, "/path").await.unwrap();
1200 assert_eq!(count, 1);
1201
1202 let count = storage.increment_command_usage(command.id, "/some/path").await.unwrap();
1204 assert_eq!(count, 101);
1205 }
1206
1207 #[tokio::test]
1208 async fn test_delete_command() {
1209 let storage = SqliteStorage::new_in_memory().await.unwrap();
1210
1211 let cmd = Command {
1212 id: Uuid::now_v7(),
1213 cmd: "to_delete".to_string(),
1214 ..Default::default()
1215 };
1216
1217 let cmd = storage.insert_command(cmd).await.unwrap();
1218 let res = storage.delete_command(cmd.id).await;
1219 assert!(res.is_ok());
1220
1221 match storage.delete_command(cmd.id).await {
1223 Err(_) => (),
1224 _ => panic!("Expected error when deleting non-existent command"),
1225 }
1226 }
1227
1228 async fn setup_ranking_storage() -> SqliteStorage {
1230 let storage = SqliteStorage::new_in_memory().await.unwrap();
1231 storage
1232 .setup_command(
1233 Command::new(
1234 CATEGORY_USER,
1235 SOURCE_USER,
1236 "kubectl get pod -n monitoring my-specific-pod-12345",
1237 )
1238 .with_description(Some(
1239 "Get a very specific pod by its full name in the monitoring namespace".to_string(),
1240 ))
1241 .with_tags(Some(vec!["#k8s".to_string(), "#pod".to_string()])),
1242 [("/other/path", 1)],
1243 )
1244 .await;
1245 storage
1246 .setup_command(
1247 Command::new(CATEGORY_USER, SOURCE_USER, "git status")
1248 .with_description(Some("Check the status of the git repository".to_string()))
1249 .with_tags(Some(vec!["#git".to_string()])),
1250 [(PROJ_A_PATH, 50), (PROJ_B_PATH, 50), (UNRELATED_PATH, 100)],
1251 )
1252 .await;
1253 storage
1254 .setup_command(
1255 Command::new(CATEGORY_USER, SOURCE_USER, "npm run build:prod")
1256 .with_description(Some("Build the project for production".to_string()))
1257 .with_tags(Some(vec!["#npm".to_string(), "#build".to_string()])),
1258 [(PROJ_A_API_PATH, 25)],
1259 )
1260 .await;
1261 storage
1262 .setup_command(
1263 Command::new(CATEGORY_USER, SOURCE_USER, "container-image-build.sh")
1264 .with_description(Some("A generic script to build a container image".to_string()))
1265 .with_tags(Some(vec!["#docker".to_string(), "#build".to_string()])),
1266 [(UNRELATED_PATH, 35)],
1267 )
1268 .await;
1269 storage
1270 .setup_command(
1271 Command::new(CATEGORY_USER, SOURCE_USER, "git commit -m '{{message}}'")
1272 .with_description(Some("Commit with a message".to_string()))
1273 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1274 [(PROJ_A_PATH, 10), (PROJ_B_PATH, 10)],
1275 )
1276 .await;
1277 storage
1278 .setup_command(
1279 Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")
1280 .with_alias(Some("gco".to_string()))
1281 .with_description(Some("Checkout the main branch".to_string()))
1282 .with_tags(Some(vec!["#git".to_string()])),
1283 [(PROJ_A_PATH, 30), (PROJ_B_PATH, 30)],
1284 )
1285 .await;
1286 storage
1287 .setup_command(
1288 Command::new("git", SOURCE_TLDR, "git commit -m")
1289 .with_alias(Some("gc".to_string()))
1290 .with_description(Some("Commit changes".to_string()))
1291 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1292 [(PROJ_A_PATH, 15)],
1293 )
1294 .await;
1295 storage
1296 .setup_command(
1297 Command::new("docker", SOURCE_TLDR, "docker ps -a")
1298 .with_description(Some("List all containers".to_string()))
1299 .with_tags(Some(vec!["#docker".to_string(), "#list".to_string()])),
1300 [(PROJ_A_PATH, 5), (PROJ_B_PATH, 5)],
1301 )
1302 .await;
1303 storage
1304 .setup_command(
1305 Command::new("git", SOURCE_TLDR, "git push")
1306 .with_description(Some("Push changes".to_string()))
1307 .with_tags(Some(vec!["#git".to_string(), "#push".to_string()])),
1308 [(PROJ_A_PATH, 20), (PROJ_B_PATH, 20)],
1309 )
1310 .await;
1311 storage
1312 .setup_command(
1313 Command::new(CATEGORY_USER, SOURCE_IMPORT, "ls -lha")
1314 .with_description(Some("List files".to_string()))
1315 .with_tags(Some(vec!["#unix".to_string(), "#list".to_string()])),
1316 [(PROJ_A_PATH, 100), (PROJ_B_PATH, 100), (UNRELATED_PATH, 100)],
1317 )
1318 .await;
1319
1320 storage
1321 }
1322
1323 impl SqliteStorage {
1324 async fn check_sqlite_version(&self) {
1326 let version: String = self
1327 .client
1328 .conn_mut(|conn| {
1329 conn.query_row("SELECT sqlite_version()", [], |row| row.get(0))
1330 .map_err(Into::into)
1331 })
1332 .await
1333 .unwrap();
1334 println!("Running with SQLite version: {version}");
1335 }
1336
1337 async fn setup_command(
1340 &self,
1341 command: Command,
1342 usage: impl IntoIterator<Item = (&str, i32)> + Send + 'static,
1343 ) -> Command {
1344 let command = self.insert_command(command).await.unwrap();
1345 self.client
1346 .conn_mut(move |conn| {
1347 for (path, usage_count) in usage {
1348 conn.execute(
1349 r#"
1350 INSERT INTO command_usage (command_id, path, usage_count)
1351 VALUES (?1, ?2, ?3)
1352 ON CONFLICT(command_id, path) DO UPDATE SET
1353 usage_count = excluded.usage_count"#,
1354 (&command.id, path, usage_count),
1355 )?;
1356 }
1357 Ok(command)
1358 })
1359 .await
1360 .unwrap()
1361 }
1362 }
1363}