1use std::{cmp::Ordering, pin::pin};
2
3use chrono::{DateTime, Utc};
4use color_eyre::{
5 Report, Result,
6 eyre::{Context, eyre},
7};
8use futures_util::StreamExt;
9use regex::Regex;
10use rusqlite::{Row, fallible_iterator::FallibleIterator, ffi, types::Type};
11use sea_query::SqliteQueryBuilder;
12use sea_query_rusqlite::RusqliteBinder;
13use tokio::sync::mpsc;
14use tokio_stream::{Stream, wrappers::ReceiverStream};
15use tracing::instrument;
16use uuid::Uuid;
17
18use super::{SqliteStorage, queries::*};
19use crate::{
20 config::SearchCommandTuning,
21 errors::{InsertError, SearchError, UpdateError},
22 model::{CATEGORY_USER, Command, SOURCE_TLDR, SearchCommandsFilter},
23};
24
25impl SqliteStorage {
26 #[instrument(skip_all)]
29 pub async fn setup_workspace_storage(&self) -> Result<()> {
30 self.client
31 .conn_mut::<_, _, Report>(|conn| {
32 let schemas: Vec<String> = conn
34 .prepare(
35 r"SELECT sql
36 FROM sqlite_master
37 WHERE (type = 'table' AND name = 'command')
38 OR (type = 'table' AND name LIKE 'command_%fts')
39 OR (type = 'trigger' AND name LIKE 'command_%_fts' AND tbl_name = 'command')",
40 )?
41 .query_map([], |row| row.get(0))?
42 .collect::<Result<Vec<String>, _>>()?;
43
44 let tx = conn.transaction()?;
45
46 for schema in schemas {
48 let temp_schema = schema
49 .replace("command", "workspace_command")
50 .replace("CREATE TABLE", "CREATE TEMP TABLE")
51 .replace("CREATE VIRTUAL TABLE ", "CREATE VIRTUAL TABLE temp.")
52 .replace("CREATE TRIGGER", "CREATE TEMP TRIGGER");
53 tx.execute(&temp_schema, [])?;
54 }
55
56 tx.commit()?;
57 Ok(())
58 })
59 .await
60 .wrap_err("Failed to create temporary workspace storage from schema")?;
61
62 self.workspace_tables_loaded
63 .store(true, std::sync::atomic::Ordering::SeqCst);
64
65 Ok(())
66 }
67
68 #[instrument(skip_all)]
70 pub async fn is_empty(&self) -> Result<bool> {
71 let workspace_tables_loaded = self.workspace_tables_loaded.load(std::sync::atomic::Ordering::SeqCst);
72 self.client
73 .conn::<_, _, Report>(move |conn| {
74 if workspace_tables_loaded {
75 Ok(conn.query_row(
76 "SELECT NOT EXISTS (SELECT 1 FROM command UNION ALL SELECT 1 FROM workspace_command)",
77 [],
78 |r| r.get(0),
79 )?)
80 } else {
81 Ok(conn.query_row("SELECT NOT EXISTS(SELECT 1 FROM command)", [], |r| r.get(0))?)
82 }
83 })
84 .await
85 .wrap_err("Couldn't check if storage is empty")
86 }
87
88 #[instrument(skip_all)]
90 pub async fn find_tags(
91 &self,
92 filter: SearchCommandsFilter,
93 tag_prefix: Option<String>,
94 tuning: &SearchCommandTuning,
95 ) -> Result<Vec<(String, u64, bool)>, SearchError> {
96 let workspace_tables_loaded = self.workspace_tables_loaded.load(std::sync::atomic::Ordering::SeqCst);
97 let query = query_find_tags(filter, tag_prefix, tuning, workspace_tables_loaded)?;
98 if tracing::enabled!(tracing::Level::TRACE) {
99 tracing::trace!("Querying tags:\n{}", query.to_string(SqliteQueryBuilder));
100 }
101 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
102 Ok(self
103 .client
104 .conn::<_, _, Report>(move |conn| {
105 conn.prepare(&stmt)?
106 .query(&*values.as_params())?
107 .and_then(|r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))
108 .collect()
109 })
110 .await
111 .wrap_err("Couldn't find tags")?)
112 }
113
114 #[instrument(skip_all)]
119 pub async fn find_commands(
120 &self,
121 filter: SearchCommandsFilter,
122 working_path: impl Into<String>,
123 tuning: &SearchCommandTuning,
124 ) -> Result<(Vec<Command>, bool), SearchError> {
125 let workspace_tables_loaded = self.workspace_tables_loaded.load(std::sync::atomic::Ordering::SeqCst);
126 let cleaned_filter = filter.cleaned();
127
128 let mut query_alias = None;
130 if let Some(ref term) = cleaned_filter.search_term {
131 query_alias = Some((
133 format!(
134 r#"SELECT c.rowid, c.*
135 FROM command c
136 WHERE c.alias IS NOT NULL AND c.alias = ?1
137 ORDER BY c.cmd ASC
138 LIMIT {QUERY_LIMIT}"#
139 ),
140 (term.clone(),),
141 ));
142 }
143
144 let query = query_find_commands(cleaned_filter, working_path, tuning, workspace_tables_loaded)?;
146 let query_trace = if tracing::enabled!(tracing::Level::TRACE) {
147 query.to_string(SqliteQueryBuilder)
148 } else {
149 String::default()
150 };
151 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
152
153 let tuning = *tuning;
155 Ok(self
156 .client
157 .conn::<_, _, Report>(move |conn| {
158 if let Some((query_alias, a_params)) = query_alias {
160 let rows = conn
162 .prepare(&query_alias)?
163 .query(a_params)?
164 .map(|r| Command::try_from(r))
165 .collect::<Vec<_>>()?;
166 if !rows.is_empty() {
168 return Ok((rows, true));
169 }
170 }
171 if tracing::enabled!(tracing::Level::TRACE) {
173 tracing::trace!("Querying commands:\n{query_trace}");
174 }
175 Ok((
176 rerank_query_results(
177 conn.prepare(&stmt)?
178 .query(&*values.as_params())?
179 .and_then(|r| QueryResultItem::try_from(r))
180 .collect::<Result<Vec<_>, _>>()?,
181 &tuning,
182 ),
183 false,
184 ))
185 })
186 .await
187 .wrap_err("Couldn't search commands")?)
188 }
189
190 #[instrument(skip_all)]
197 pub async fn import_commands(
198 &self,
199 commands: impl Stream<Item = Result<Command>> + Send + 'static,
200 filter: Option<Regex>,
201 overwrite: bool,
202 workspace: bool,
203 ) -> Result<(u64, u64)> {
204 let (tx, mut rx) = mpsc::channel(100);
206
207 tokio::spawn(async move {
209 let mut commands = pin!(commands);
211 while let Some(command_result) = commands.next().await {
212 if tx.send(command_result).await.is_err() {
213 tracing::debug!("Import stream channel closed by receiver");
215 break;
216 }
217 }
218 });
219
220 let table = if workspace { "workspace_command" } else { "command" };
222
223 self.client
224 .conn_mut::<_, _, Report>(move |conn| {
225 let mut inserted = 0;
226 let mut skipped_or_updated = 0;
227 let tx = conn.transaction()?;
228 let mut stmt = if overwrite {
229 tx.prepare(&format!(
230 r#"INSERT INTO {table} (
231 id,
232 category,
233 source,
234 alias,
235 cmd,
236 flat_cmd,
237 description,
238 flat_description,
239 tags,
240 created_at
241 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
242 ON CONFLICT (cmd) DO UPDATE SET
243 alias = COALESCE(excluded.alias, alias),
244 cmd = excluded.cmd,
245 flat_cmd = excluded.flat_cmd,
246 description = COALESCE(excluded.description, description),
247 flat_description = COALESCE(excluded.flat_description, flat_description),
248 tags = COALESCE(excluded.tags, tags),
249 updated_at = excluded.created_at
250 RETURNING updated_at;"#
251 ))?
252 } else {
253 tx.prepare(&format!(
254 r#"INSERT OR IGNORE INTO {table} (
255 id,
256 category,
257 source,
258 alias,
259 cmd,
260 flat_cmd,
261 description,
262 flat_description,
263 tags,
264 created_at
265 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
266 RETURNING updated_at;"#,
267 ))?
268 };
269
270 while let Some(command_result) = rx.blocking_recv() {
272 let command = command_result?;
273
274 if let Some(ref filter) = filter {
276 let matches_filter = filter.is_match(&command.cmd)
278 || command.description.as_ref().is_some_and(|d| filter.is_match(d));
279 if !matches_filter {
280 continue;
281 }
282 }
283
284 let mut rows = stmt.query((
285 &command.id,
286 &command.category,
287 &command.source,
288 &command.alias,
289 &command.cmd,
290 &command.flat_cmd,
291 &command.description,
292 &command.flat_description,
293 serde_json::to_value(&command.tags)?,
294 &command.created_at,
295 ))?;
296
297 match rows.next()? {
298 None => skipped_or_updated += 1,
300 Some(r) => {
302 let updated_at = r.get::<_, Option<DateTime<Utc>>>(0)?;
303 match updated_at {
304 None => inserted += 1,
306 Some(_) => skipped_or_updated += 1,
308 }
309 }
310 }
311 }
312
313 drop(stmt);
314 tx.commit()?;
315 Ok((inserted, skipped_or_updated))
316 })
317 .await
318 .wrap_err("Couldn't import commands")
319 }
320
321 #[instrument(skip_all)]
323 pub async fn export_user_commands(
324 &self,
325 filter: Option<Regex>,
326 ) -> impl Stream<Item = Result<Command>> + Send + 'static {
327 let (tx, rx) = mpsc::channel(100);
329
330 let client = self.client.clone();
332 tokio::spawn(async move {
333 let res: Result<(), Report> = client
334 .conn_mut(move |conn| {
335 let mut q_values = vec![CATEGORY_USER.to_owned()];
337 let mut query = String::from(
338 r"SELECT
339 rowid,
340 id,
341 category,
342 source,
343 alias,
344 cmd,
345 flat_cmd,
346 description,
347 flat_description,
348 tags,
349 created_at,
350 updated_at
351 FROM command
352 WHERE category = ?1",
353 );
354 if let Some(filter) = filter {
355 q_values.push(filter.as_str().to_owned());
356 query.push_str(" AND (cmd REGEXP ?2 OR (description IS NOT NULL AND description REGEXP ?2))");
357 }
358 query.push_str("\nORDER BY cmd ASC");
359
360 let mut stmt = conn.prepare(&query)?;
362 let records_iter =
363 stmt.query_and_then(rusqlite::params_from_iter(q_values), |r| Command::try_from(r))?;
364
365 for record_result in records_iter {
367 if tx
368 .blocking_send(record_result.wrap_err("Error fetching command"))
369 .is_err()
370 {
371 tracing::debug!("Async stream receiver dropped, closing db query");
372 break;
373 }
374 }
375
376 Ok(())
377 })
378 .await;
379 if let Err(e) = res {
380 panic!("Couldn't fetch commands to export: {e:?}");
381 }
382 });
383
384 ReceiverStream::new(rx)
386 }
387
388 #[instrument(skip_all)]
390 pub async fn delete_tldr_commands(&self, category: Option<String>) -> Result<u64> {
391 self.client
392 .conn_mut::<_, _, Report>(move |conn| {
393 let mut query = String::from("DELETE FROM command WHERE source = ?1");
394 let mut params: Vec<String> = vec![SOURCE_TLDR.to_owned()];
395 if let Some(cat) = category {
396 query.push_str(" AND category = ?2");
397 params.push(cat);
398 }
399 let affected = conn.execute(&query, rusqlite::params_from_iter(params))?;
400 Ok(affected as u64)
401 })
402 .await
403 .wrap_err("Couldn't remove tldr commands")
404 }
405
406 #[instrument(skip_all)]
410 pub async fn insert_command(&self, command: Command) -> Result<Command, InsertError> {
411 self.client
412 .conn(move |conn| {
413 let res = conn.execute(
414 r#"INSERT INTO command (
415 id,
416 category,
417 source,
418 alias,
419 cmd,
420 flat_cmd,
421 description,
422 flat_description,
423 tags,
424 created_at,
425 updated_at
426 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)"#,
427 (
428 &command.id,
429 &command.category,
430 &command.source,
431 &command.alias,
432 &command.cmd,
433 &command.flat_cmd,
434 &command.description,
435 &command.flat_description,
436 serde_json::to_value(&command.tags).wrap_err("Couldn't insert a command")?,
437 &command.created_at,
438 &command.updated_at,
439 ),
440 );
441 match res {
442 Ok(_) => Ok(command),
443 Err(err) => {
444 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
445 if code == ffi::SQLITE_CONSTRAINT_UNIQUE || code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY {
446 Err(InsertError::AlreadyExists)
447 } else {
448 Err(Report::from(err).wrap_err("Couldn't insert a command").into())
449 }
450 }
451 }
452 })
453 .await
454 }
455
456 #[instrument(skip_all)]
460 pub async fn update_command(&self, command: Command) -> Result<Command, UpdateError> {
461 self.client
462 .conn(move |conn| {
463 let res = conn.execute(
464 r#"UPDATE command SET
465 category = ?2,
466 source = ?3,
467 alias = ?4,
468 cmd = ?5,
469 flat_cmd = ?6,
470 description = ?7,
471 flat_description = ?8,
472 tags = ?9,
473 created_at = ?10,
474 updated_at = ?11
475 WHERE id = ?1"#,
476 (
477 &command.id,
478 &command.category,
479 &command.source,
480 &command.alias,
481 &command.cmd,
482 &command.flat_cmd,
483 &command.description,
484 &command.flat_description,
485 serde_json::to_value(&command.tags).wrap_err("Couldn't update a command")?,
486 &command.created_at,
487 &command.updated_at,
488 ),
489 );
490 match res {
491 Ok(0) => Err(eyre!("Command not found: {}", command.id)
492 .wrap_err("Couldn't update a command")
493 .into()),
494 Ok(_) => Ok(command),
495 Err(err) => {
496 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
497 if code == ffi::SQLITE_CONSTRAINT_UNIQUE {
498 Err(UpdateError::AlreadyExists)
499 } else {
500 Err(Report::from(err).wrap_err("Couldn't update a command").into())
501 }
502 }
503 }
504 })
505 .await
506 }
507
508 #[instrument(skip_all)]
510 pub async fn increment_command_usage(
511 &self,
512 command_id: Uuid,
513 path: impl AsRef<str> + Send + 'static,
514 ) -> Result<i32, UpdateError> {
515 self.client
516 .conn_mut(move |conn| {
517 let res = conn.query_row(
518 r#"
519 INSERT INTO command_usage (command_id, path, usage_count)
520 VALUES (?1, ?2, 1)
521 ON CONFLICT(command_id, path) DO UPDATE SET
522 usage_count = usage_count + 1
523 RETURNING usage_count;"#,
524 (&command_id, &path.as_ref()),
525 |r| r.get(0),
526 );
527 match res {
528 Ok(u) => Ok(u),
529 Err(err) => Err(Report::from(err).wrap_err("Couldn't update a command usage").into()),
530 }
531 })
532 .await
533 }
534
535 #[instrument(skip_all)]
539 pub async fn delete_command(&self, command_id: Uuid) -> Result<()> {
540 self.client
541 .conn(move |conn| {
542 let res = conn.execute("DELETE FROM command WHERE id = ?1", (&command_id,));
543 match res {
544 Ok(0) => Err(eyre!("Command not found: {command_id}").wrap_err("Couldn't delete a command")),
545 Ok(_) => Ok(()),
546 Err(err) => Err(Report::from(err).wrap_err("Couldn't delete a command")),
547 }
548 })
549 .await
550 }
551}
552
553fn rerank_query_results(items: Vec<QueryResultItem>, tuning: &SearchCommandTuning) -> Vec<Command> {
562 if items.is_empty() {
564 return Vec::new();
565 }
566 if items.len() == 1 {
567 return items.into_iter().map(|item| item.command).collect();
568 }
569
570 let (template_matches, mut other_items): (Vec<_>, Vec<_>) = items
573 .into_iter()
574 .partition(|item| item.text_score >= TEMPLATE_MATCH_RANK);
575 if !template_matches.is_empty() {
576 tracing::trace!("Found {} template matches", template_matches.len());
577 }
578
579 let mut final_commands: Vec<Command> = template_matches.into_iter().map(|item| item.command).collect();
581
582 if other_items.len() <= 1 {
584 final_commands.extend(other_items.into_iter().map(|item| item.command));
585 return final_commands;
586 }
587
588 let mut min_text = f64::INFINITY;
590 let mut max_text = f64::NEG_INFINITY;
591 let mut min_path = f64::INFINITY;
592 let mut max_path = f64::NEG_INFINITY;
593 let mut min_usage = f64::INFINITY;
594 let mut max_usage = f64::NEG_INFINITY;
595 for item in &other_items {
596 min_text = min_text.min(item.text_score);
597 max_text = max_text.max(item.text_score);
598 min_path = min_path.min(item.path_score);
599 max_path = max_path.max(item.path_score);
600 min_usage = min_usage.min(item.usage_score);
601 max_usage = max_usage.max(item.usage_score);
602 }
603
604 let range_text = (max_text > min_text).then_some(max_text - min_text);
606 let range_path = (max_path > min_path).then_some(max_path - min_path);
607 let range_usage = (max_usage > min_usage).then_some(max_usage - min_usage);
608
609 other_items.sort_by(|a, b| {
611 match b.is_workspace_command.cmp(&a.is_workspace_command) {
613 Ordering::Equal => {
614 let calculate_score = |item: &QueryResultItem| -> f64 {
616 let norm_text = range_text.map_or(0.5, |range| (item.text_score - min_text) / range);
619 let norm_path = range_path.map_or(0.5, |range| (item.path_score - min_path) / range);
620 let norm_usage = range_usage.map_or(0.5, |range| (item.usage_score - min_usage) / range);
621
622 (norm_text * tuning.text.points as f64)
624 + (norm_path * tuning.path.points as f64)
625 + (norm_usage * tuning.usage.points as f64)
626 };
627
628 let final_score_a = calculate_score(a);
629 let final_score_b = calculate_score(b);
630
631 final_score_b.partial_cmp(&final_score_a).unwrap_or(Ordering::Equal)
633 }
634 other => other,
636 }
637 });
638
639 final_commands.extend(other_items.into_iter().map(|item| item.command));
641 final_commands
642}
643
644impl<'a> TryFrom<&'a Row<'a>> for Command {
645 type Error = rusqlite::Error;
646
647 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
648 Ok(Self {
649 id: row.get(1)?,
651 category: row.get(2)?,
652 source: row.get(3)?,
653 alias: row.get(4)?,
654 cmd: row.get(5)?,
655 flat_cmd: row.get(6)?,
656 description: row.get(7)?,
657 flat_description: row.get(8)?,
658 tags: serde_json::from_value(row.get::<_, serde_json::Value>(9)?)
659 .map_err(|e| rusqlite::Error::FromSqlConversionFailure(9, Type::Text, Box::new(e)))?,
660 created_at: row.get(10)?,
661 updated_at: row.get(11)?,
662 })
663 }
664}
665
/// A single row of the commands search query: the command itself plus the values used to rank it.
struct QueryResultItem {
    /// The command parsed from the row
    command: Command,
    /// Flag read from column 12; items with this flag set are sorted before the rest when reranking
    is_workspace_command: bool,
    /// Usage score read from column 13, normalized across items when reranking
    usage_score: f64,
    /// Path score read from column 14, normalized across items when reranking
    path_score: f64,
    /// Text score read from column 15; values at or above `TEMPLATE_MATCH_RANK` are treated as
    /// template matches when reranking
    text_score: f64,
}
679
680impl<'a> TryFrom<&'a Row<'a>> for QueryResultItem {
681 type Error = rusqlite::Error;
682
683 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
684 Ok(Self {
685 command: Command::try_from(row)?,
686 is_workspace_command: row.get(12)?,
687 usage_score: row.get(13)?,
688 path_score: row.get(14)?,
689 text_score: row.get(15)?,
690 })
691 }
692}
693
694#[cfg(test)]
695mod tests {
696 use futures_util::StreamExt;
697 use pretty_assertions::assert_eq;
698 use strum::IntoEnumIterator;
699 use tokio_stream::iter;
700 use uuid::Uuid;
701
702 use super::*;
703 use crate::model::{CATEGORY_USER, SOURCE_IMPORT, SOURCE_USER, SearchMode};
704
705 const PROJ_A_PATH: &str = "/home/user/project-a";
706 const PROJ_A_API_PATH: &str = "/home/user/project-a/api";
707 const PROJ_B_PATH: &str = "/home/user/project-b";
708 const UNRELATED_PATH: &str = "/var/log";
709
710 #[tokio::test]
711 async fn test_setup_workspace_storage() {
712 let storage = SqliteStorage::new_in_memory().await.unwrap();
713 storage.check_sqlite_version().await;
714 let res = storage.setup_workspace_storage().await;
715 assert!(res.is_ok(), "Expected workspace storage setup to succeed: {res:?}");
716 }
717
718 #[tokio::test]
719 async fn test_is_empty() {
720 let storage = SqliteStorage::new_in_memory().await.unwrap();
721 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
722
723 let cmd = Command {
724 id: Uuid::now_v7(),
725 cmd: "test_cmd".to_string(),
726 ..Default::default()
727 };
728 storage.insert_command(cmd).await.unwrap();
729
730 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
731 }
732
733 #[tokio::test]
734 async fn test_is_empty_with_workspace() {
735 let storage = SqliteStorage::new_in_memory().await.unwrap();
736 storage.setup_workspace_storage().await.unwrap();
737 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
738
739 let cmd = Command {
740 id: Uuid::now_v7(),
741 cmd: "test_cmd".to_string(),
742 ..Default::default()
743 };
744 storage.insert_command(cmd).await.unwrap();
745
746 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
747 }
748
749 #[tokio::test]
750 async fn test_find_tags_no_filters() -> Result<(), SearchError> {
751 let storage = setup_ranking_storage().await;
752
753 let result = storage
754 .find_tags(SearchCommandsFilter::default(), None, &SearchCommandTuning::default())
755 .await?;
756
757 let expected = vec![
758 ("#git".to_string(), 5, false),
759 ("#build".to_string(), 2, false),
760 ("#commit".to_string(), 2, false),
761 ("#docker".to_string(), 2, false),
762 ("#list".to_string(), 2, false),
763 ("#k8s".to_string(), 1, false),
764 ("#npm".to_string(), 1, false),
765 ("#pod".to_string(), 1, false),
766 ("#push".to_string(), 1, false),
767 ("#unix".to_string(), 1, false),
768 ];
769
770 assert_eq!(result.len(), 10, "Expected 10 unique tags");
771 assert_eq!(result, expected, "Tags list or order mismatch");
772
773 Ok(())
774 }
775
776 #[tokio::test]
777 async fn test_find_tags_filter_by_tags_only() -> Result<(), SearchError> {
778 let storage = setup_ranking_storage().await;
779
780 let filter1 = SearchCommandsFilter {
781 tags: Some(vec!["#git".to_string()]),
782 ..Default::default()
783 };
784 let result1 = storage
785 .find_tags(filter1, None, &SearchCommandTuning::default())
786 .await?;
787 let expected1 = vec![("#commit".to_string(), 2, false), ("#push".to_string(), 1, false)];
788 assert_eq!(result1.len(), 2,);
789 assert_eq!(result1, expected1);
790
791 let filter2 = SearchCommandsFilter {
792 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
793 ..Default::default()
794 };
795 let result2 = storage
796 .find_tags(filter2, None, &SearchCommandTuning::default())
797 .await?;
798 assert!(result2.is_empty());
799
800 let filter3 = SearchCommandsFilter {
801 tags: Some(vec!["#list".to_string()]),
802 ..Default::default()
803 };
804 let result3 = storage
805 .find_tags(filter3, None, &SearchCommandTuning::default())
806 .await?;
807 let expected3 = vec![("#docker".to_string(), 1, false), ("#unix".to_string(), 1, false)];
808 assert_eq!(result3.len(), 2);
809 assert_eq!(result3, expected3);
810
811 Ok(())
812 }
813
814 #[tokio::test]
815 async fn test_find_tags_filter_by_prefix_only() -> Result<(), SearchError> {
816 let storage = setup_ranking_storage().await;
817
818 let result = storage
819 .find_tags(
820 SearchCommandsFilter::default(),
821 Some("#comm".to_string()),
822 &SearchCommandTuning::default(),
823 )
824 .await?;
825 let expected = vec![("#commit".to_string(), 2, false)];
826 assert_eq!(result.len(), 1);
827 assert_eq!(result, expected);
828
829 Ok(())
830 }
831
832 #[tokio::test]
833 async fn test_find_tags_filter_by_tags_and_prefix() -> Result<(), SearchError> {
834 let storage = setup_ranking_storage().await;
835
836 let filter1 = SearchCommandsFilter {
837 tags: Some(vec!["#git".to_string()]),
838 ..Default::default()
839 };
840 let result1 = storage
841 .find_tags(filter1, Some("#comm".to_string()), &SearchCommandTuning::default())
842 .await?;
843 let expected1 = vec![("#commit".to_string(), 2, false)];
844 assert_eq!(result1.len(), 1);
845 assert_eq!(result1, expected1);
846
847 let filter2 = SearchCommandsFilter {
848 tags: Some(vec!["#git".to_string()]),
849 ..Default::default()
850 };
851 let result2 = storage
852 .find_tags(filter2, Some("#push".to_string()), &SearchCommandTuning::default())
853 .await?;
854 let expected2 = vec![("#push".to_string(), 1, true)];
855 assert_eq!(result2.len(), 1);
856 assert_eq!(result2, expected2);
857
858 Ok(())
859 }
860
861 #[tokio::test]
862 async fn test_find_commands_no_filter() {
863 let storage = setup_ranking_storage().await;
864 let filter = SearchCommandsFilter::default();
865 let (commands, _) = storage
866 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
867 .await
868 .unwrap();
869 assert_eq!(commands.len(), 10, "Expected all sample commands");
870 }
871
872 #[tokio::test]
873 async fn test_find_commands_filter_by_category() {
874 let storage = setup_ranking_storage().await;
875 let filter = SearchCommandsFilter {
876 category: Some(vec!["git".to_string()]),
877 ..Default::default()
878 };
879 let (commands, _) = storage
880 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
881 .await
882 .unwrap();
883 assert_eq!(commands.len(), 2);
884 assert!(commands.iter().all(|c| c.category == "git"));
885
886 let filter_no_match = SearchCommandsFilter {
887 category: Some(vec!["nonexistent".to_string()]),
888 ..Default::default()
889 };
890 let (commands_no_match, _) = storage
891 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
892 .await
893 .unwrap();
894 assert!(commands_no_match.is_empty());
895 }
896
897 #[tokio::test]
898 async fn test_find_commands_filter_by_source() {
899 let storage = setup_ranking_storage().await;
900 let filter = SearchCommandsFilter {
901 source: Some(SOURCE_TLDR.to_string()),
902 ..Default::default()
903 };
904 let (commands, _) = storage
905 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
906 .await
907 .unwrap();
908 assert_eq!(commands.len(), 3);
909 assert!(commands.iter().all(|c| c.source == SOURCE_TLDR));
910 }
911
912 #[tokio::test]
913 async fn test_find_commands_filter_by_tags() {
914 let storage = setup_ranking_storage().await;
915 let filter_single_tag = SearchCommandsFilter {
916 tags: Some(vec!["#git".to_string()]),
917 ..Default::default()
918 };
919 let (commands_single_tag, _) = storage
920 .find_commands(filter_single_tag, "/some/path", &SearchCommandTuning::default())
921 .await
922 .unwrap();
923 assert_eq!(commands_single_tag.len(), 5);
924
925 let filter_multiple_tags = SearchCommandsFilter {
926 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
927 ..Default::default()
928 };
929 let (commands_multiple_tags, _) = storage
930 .find_commands(filter_multiple_tags, "/some/path", &SearchCommandTuning::default())
931 .await
932 .unwrap();
933 assert_eq!(commands_multiple_tags.len(), 1);
934
935 let filter_empty_tags = SearchCommandsFilter {
936 tags: Some(vec![]),
937 ..Default::default()
938 };
939 let (commands_empty_tags, _) = storage
940 .find_commands(filter_empty_tags, "/some/path", &SearchCommandTuning::default())
941 .await
942 .unwrap();
943 assert_eq!(commands_empty_tags.len(), 10);
944 }
945
946 #[tokio::test]
947 async fn test_find_commands_alias_precedence() {
948 let storage = setup_ranking_storage().await;
949 storage
950 .setup_command(
951 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
952 [("/some/path", 100)],
953 )
954 .await;
955
956 for mode in SearchMode::iter() {
957 let filter = SearchCommandsFilter {
958 search_term: Some("gc".to_string()),
959 search_mode: mode,
960 ..Default::default()
961 };
962 let (commands, alias_match) = storage
963 .find_commands(filter, "", &SearchCommandTuning::default())
964 .await
965 .unwrap();
966 assert!(alias_match, "Expected alias match for mode {mode:?}");
967 assert_eq!(commands.len(), 1, "Expected only alias match for mode {mode:?}");
968 assert_eq!(
969 commands[0].cmd, "git commit -m",
970 "Expected correct alias command for mode {mode:?}"
971 );
972 }
973 }
974
975 #[tokio::test]
976 async fn test_find_commands_search_mode_exact() {
977 let storage = setup_ranking_storage().await;
978 let filter_token_match = SearchCommandsFilter {
979 search_term: Some("commit".to_string()),
980 search_mode: SearchMode::Exact,
981 ..Default::default()
982 };
983 let (commands_token_match, _) = storage
984 .find_commands(filter_token_match, "/some/path", &SearchCommandTuning::default())
985 .await
986 .unwrap();
987 assert_eq!(commands_token_match.len(), 2);
988 assert_eq!(commands_token_match[0].cmd, "git commit -m");
989 assert_eq!(commands_token_match[1].cmd, "git commit -m '{{message}}'");
990
991 let filter_no_match = SearchCommandsFilter {
992 search_term: Some("nonexistentterm".to_string()),
993 search_mode: SearchMode::Exact,
994 ..Default::default()
995 };
996 let (commands_no_match, _) = storage
997 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
998 .await
999 .unwrap();
1000 assert!(commands_no_match.is_empty());
1001 }
1002
1003 #[tokio::test]
1004 async fn test_find_commands_search_mode_relaxed() {
1005 let storage = setup_ranking_storage().await;
1006 let filter = SearchCommandsFilter {
1007 search_term: Some("docker list".to_string()),
1008 search_mode: SearchMode::Relaxed,
1009 ..Default::default()
1010 };
1011 let (commands, _) = storage
1012 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1013 .await
1014 .unwrap();
1015 assert_eq!(commands.len(), 2);
1016 assert!(commands.iter().any(|c| c.cmd == "docker ps -a"));
1017 assert!(commands.iter().any(|c| c.cmd == "ls -lha"));
1018 }
1019
1020 #[tokio::test]
1021 async fn test_find_commands_search_mode_regex() {
1022 let storage = setup_ranking_storage().await;
1023 let filter = SearchCommandsFilter {
1024 search_term: Some(r"git\s.*it".to_string()),
1025 search_mode: SearchMode::Regex,
1026 ..Default::default()
1027 };
1028 let (commands, _) = storage
1029 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1030 .await
1031 .unwrap();
1032 assert_eq!(commands.len(), 2);
1033 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
1034 assert_eq!(commands[1].cmd, "git commit -m");
1035
1036 let filter_invalid = SearchCommandsFilter {
1037 search_term: Some("[[invalid_regex".to_string()),
1038 search_mode: SearchMode::Regex,
1039 ..Default::default()
1040 };
1041 assert!(matches!(
1042 storage
1043 .find_commands(filter_invalid, "/some/path", &SearchCommandTuning::default())
1044 .await,
1045 Err(SearchError::InvalidRegex(_))
1046 ));
1047 }
1048
1049 #[tokio::test]
1050 async fn test_find_commands_search_mode_fuzzy() {
1051 let storage = setup_ranking_storage().await;
1052 let filter = SearchCommandsFilter {
1053 search_term: Some("gtcomit".to_string()),
1054 search_mode: SearchMode::Fuzzy,
1055 ..Default::default()
1056 };
1057 let (commands, _) = storage
1058 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1059 .await
1060 .unwrap();
1061 assert_eq!(commands.len(), 2);
1062 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
1063 assert_eq!(commands[1].cmd, "git commit -m");
1064
1065 let filter_empty_fuzzy = SearchCommandsFilter {
1066 search_term: Some("'' | ^".to_string()),
1067 search_mode: SearchMode::Fuzzy,
1068 ..Default::default()
1069 };
1070 assert!(matches!(
1071 storage
1072 .find_commands(filter_empty_fuzzy, "/some/path", &SearchCommandTuning::default())
1073 .await,
1074 Err(SearchError::InvalidFuzzy)
1075 ));
1076 }
1077
1078 #[tokio::test]
1079 async fn test_find_commands_search_mode_auto() {
1080 let storage = setup_ranking_storage().await;
1081 let default_tuning = SearchCommandTuning::default();
1082
1083 let run_search = |term: &'static str, path: &'static str| {
1085 let storage = storage.clone();
1086 async move {
1087 let filter = SearchCommandsFilter {
1088 search_term: Some(term.to_string()),
1089 search_mode: SearchMode::Auto,
1090 ..Default::default()
1091 };
1092 storage.find_commands(filter, path, &default_tuning).await.unwrap()
1093 }
1094 };
1095
1096 let (commands, _) = run_search("list containers", UNRELATED_PATH).await;
1098 assert!(!commands.is_empty(), "Expected results for 'list containers'");
1099 assert_eq!(
1100 commands[0].cmd, "docker ps -a",
1101 "Expected 'docker ps -a' to be the top result for 'list containers'"
1102 );
1103
1104 let (commands, _) = run_search("git commit", PROJ_A_PATH).await;
1106 assert!(commands.len() >= 2, "Expected at least two results for 'git commit'");
1107 assert_eq!(
1108 commands[0].cmd, "git commit -m",
1109 "Expected 'git commit -m' to be the top result for 'git commit' due to usage"
1110 );
1111 assert_eq!(
1112 commands[1].cmd, "git commit -m '{{message}}'",
1113 "Expected template command to be second for 'git commit'"
1114 );
1115
1116 let (commands, _) = run_search("git commit -m 'my new feature'", PROJ_A_PATH).await;
1118 assert!(!commands.is_empty(), "Expected results for template match");
1119 assert_eq!(
1120 commands[0].cmd, "git commit -m '{{message}}'",
1121 "Expected template command to be the top result for a matching search term"
1122 );
1123
1124 let (commands, _) = run_search("build", PROJ_A_API_PATH).await;
1126 assert!(!commands.is_empty(), "Expected results for 'build'");
1127 assert_eq!(
1128 commands[0].cmd, "npm run build:prod",
1129 "Expected 'npm run build:prod' to be top result for 'build' in its project path"
1130 );
1131
1132 let (commands, _) = run_search("gt sta", PROJ_A_PATH).await;
1134 assert!(!commands.is_empty(), "Expected results for fuzzy search 'gt sta'");
1135 assert_eq!(
1136 commands[0].cmd, "git status",
1137 "Expected 'git status' as top result for fuzzy search 'gt sta'"
1138 );
1139
1140 let (commands, _) = run_search("get pod monitoring", UNRELATED_PATH).await;
1142 assert!(!commands.is_empty(), "Expected results for 'get pod monitoring'");
1143 assert_eq!(
1144 commands[0].cmd, "kubectl get pod -n monitoring my-specific-pod-12345",
1145 "Expected specific 'kubectl' command to be found"
1146 );
1147
1148 let (commands, _) = run_search("status", PROJ_A_API_PATH).await;
1150 assert!(!commands.is_empty(), "Expected results for 'status'");
1151 assert_eq!(
1152 commands[0].cmd, "git status",
1153 "Expected 'git status' to be top due to high usage in parent path"
1154 );
1155 }
1156
1157 #[tokio::test]
1158 async fn test_find_commands_search_mode_auto_hastag_only() {
1159 let storage = setup_ranking_storage().await;
1160
1161 let filter = SearchCommandsFilter {
1164 search_term: Some("#".to_string()),
1165 search_mode: SearchMode::Auto,
1166 ..Default::default()
1167 };
1168
1169 let res = storage
1170 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1171 .await;
1172 assert!(res.is_ok(), "Expected a success response, got: {res:?}")
1173 }
1174
1175 #[tokio::test]
1176 async fn test_find_commands_including_workspace() {
1177 let storage = setup_ranking_storage().await;
1178
1179 storage.setup_workspace_storage().await.unwrap();
1180 let commands_to_import = vec![
1181 Command {
1182 id: Uuid::now_v7(),
1183 cmd: "cmd1".to_string(),
1184 ..Default::default()
1185 },
1186 Command {
1187 id: Uuid::now_v7(),
1188 cmd: "cmd2".to_string(),
1189 ..Default::default()
1190 },
1191 ];
1192 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1193 storage.import_commands(stream, None, false, true).await.unwrap();
1194
1195 let (commands, _) = storage
1196 .find_commands(
1197 SearchCommandsFilter::default(),
1198 "/some/path",
1199 &SearchCommandTuning::default(),
1200 )
1201 .await
1202 .unwrap();
1203 assert_eq!(commands.len(), 12, "Expected 12 commands including workspace");
1204 }
1205
1206 #[tokio::test]
1207 async fn test_find_commands_with_text_including_workspace() {
1208 let storage = setup_ranking_storage().await;
1209
1210 storage.setup_workspace_storage().await.unwrap();
1211 let commands_to_import = vec![Command {
1212 id: Uuid::now_v7(),
1213 cmd: "git checkout -b feature/{{name:kebab}}".to_string(),
1214 ..Default::default()
1215 }];
1216 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1217 storage.import_commands(stream, None, false, true).await.unwrap();
1218
1219 let filter = SearchCommandsFilter {
1220 search_term: Some("git".to_string()),
1221 ..Default::default()
1222 };
1223
1224 let (commands, _) = storage
1225 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1226 .await
1227 .unwrap();
1228 assert_eq!(commands.len(), 6, "Expected 6 git commands including workspace");
1229 assert!(
1230 commands
1231 .iter()
1232 .any(|c| c.cmd == "git checkout -b feature/{{name:kebab}}")
1233 );
1234 }
1235
1236 #[tokio::test]
1237 async fn test_import_commands_no_overwrite() {
1238 let storage = SqliteStorage::new_in_memory().await.unwrap();
1239
1240 let commands_to_import = vec![
1241 Command {
1242 id: Uuid::now_v7(),
1243 cmd: "cmd1".to_string(),
1244 ..Default::default()
1245 },
1246 Command {
1247 id: Uuid::now_v7(),
1248 cmd: "cmd2".to_string(),
1249 ..Default::default()
1250 },
1251 ];
1252
1253 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1254 let (inserted, skipped_or_updated) = storage.import_commands(stream, None, false, false).await.unwrap();
1255
1256 assert_eq!(inserted, 2, "Expected 2 commands inserted");
1257 assert_eq!(skipped_or_updated, 0, "Expected 0 commands skipped or updated");
1258
1259 let stream = iter(commands_to_import.into_iter().map(Ok));
1261 let (inserted, skipped_or_updated) = storage.import_commands(stream, None, false, false).await.unwrap();
1262
1263 assert_eq!(
1264 inserted, 0,
1265 "Expected 0 commands inserted on second import (no overwrite)"
1266 );
1267 assert_eq!(
1268 skipped_or_updated, 2,
1269 "Expected 2 commands skipped on second import (no overwrite)"
1270 );
1271 }
1272
1273 #[tokio::test]
1274 async fn test_import_commands_overwrite() {
1275 let storage = SqliteStorage::new_in_memory().await.unwrap();
1276
1277 let existing_cmd = Command {
1278 id: Uuid::now_v7(),
1279 cmd: "existing_cmd".to_string(),
1280 description: Some("original desc".to_string()),
1281 alias: Some("original_alias".to_string()),
1282 tags: Some(vec!["tag_a".to_string()]),
1283 ..Default::default()
1284 };
1285 storage.insert_command(existing_cmd.clone()).await.unwrap();
1286
1287 let new_cmd = Command {
1288 id: Uuid::now_v7(),
1289 cmd: "new_cmd".to_string(),
1290 ..Default::default()
1291 };
1292
1293 let commands_to_import = vec![
1295 Command {
1296 id: Uuid::now_v7(),
1297 cmd: "existing_cmd".to_string(),
1298 description: Some("updated desc".to_string()),
1299 alias: None,
1300 tags: Some(vec!["tag_b".to_string()]),
1301 ..Default::default()
1302 },
1303 new_cmd.clone(),
1304 ];
1305
1306 let stream = iter(commands_to_import.into_iter().map(Ok));
1307 let (inserted, skipped_or_updated) = storage.import_commands(stream, None, true, false).await.unwrap();
1308
1309 assert_eq!(inserted, 1, "Expected 1 new command inserted");
1310 assert_eq!(skipped_or_updated, 1, "Expected 1 existing command updated");
1311
1312 let filter = SearchCommandsFilter {
1314 search_term: Some("existing_cmd".to_string()),
1315 search_mode: SearchMode::Exact,
1316 ..Default::default()
1317 };
1318 let (found_commands, _) = storage
1319 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1320 .await
1321 .unwrap();
1322 assert_eq!(found_commands.len(), 1);
1323 let updated_cmd_in_db = &found_commands[0];
1324 assert_eq!(
1325 updated_cmd_in_db.description,
1326 Some("updated desc".to_string()),
1327 "Description should be updated"
1328 );
1329 assert_eq!(
1330 updated_cmd_in_db.alias,
1331 Some("original_alias".to_string()),
1332 "Alias should NOT be updated to NULL"
1333 );
1334 assert_eq!(
1335 updated_cmd_in_db.tags,
1336 Some(vec!["tag_b".to_string()]),
1337 "Tags should be updated"
1338 );
1339 }
1340
1341 #[tokio::test]
1342 async fn test_import_commands_with_filter() {
1343 let storage = SqliteStorage::new_in_memory().await.unwrap();
1344
1345 let commands_to_import = vec![
1346 Command {
1347 id: Uuid::now_v7(),
1348 cmd: "git commit".to_string(),
1349 ..Default::default()
1350 },
1351 Command {
1352 id: Uuid::now_v7(),
1353 cmd: "docker ps".to_string(),
1354 ..Default::default()
1355 },
1356 Command {
1357 id: Uuid::now_v7(),
1358 cmd: "git push".to_string(),
1359 ..Default::default()
1360 },
1361 ];
1362
1363 let filter = Some(Regex::new("^git").unwrap());
1364 let stream = iter(commands_to_import.into_iter().map(Ok));
1365 let (inserted, _) = storage.import_commands(stream, filter, false, false).await.unwrap();
1366
1367 assert_eq!(inserted, 2, "Expected 2 commands to be inserted");
1368
1369 let (all_commands, _) = storage
1370 .find_commands(
1371 SearchCommandsFilter::default(),
1372 "/some/path",
1373 &SearchCommandTuning::default(),
1374 )
1375 .await
1376 .unwrap();
1377 assert!(all_commands.iter().all(|c| c.cmd.starts_with("git")));
1378 assert!(!all_commands.iter().any(|c| c.cmd.starts_with("docker")));
1379 }
1380
1381 #[tokio::test]
1382 async fn test_import_workspace_commands() {
1383 let storage = SqliteStorage::new_in_memory().await.unwrap();
1384 storage.setup_workspace_storage().await.unwrap();
1385
1386 let commands_to_import = vec![
1387 Command {
1388 id: Uuid::now_v7(),
1389 cmd: "cmd1".to_string(),
1390 ..Default::default()
1391 },
1392 Command {
1393 id: Uuid::now_v7(),
1394 cmd: "cmd2".to_string(),
1395 ..Default::default()
1396 },
1397 ];
1398
1399 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1400 let (inserted, skipped_or_updated) = storage.import_commands(stream, None, false, true).await.unwrap();
1401
1402 assert_eq!(inserted, 2, "Expected 2 commands inserted");
1403 assert_eq!(skipped_or_updated, 0, "Expected 0 commands skipped or updated");
1404 }
1405
1406 #[tokio::test]
1407 async fn test_export_user_commands_no_filter() {
1408 let storage = setup_ranking_storage().await;
1409 let mut exported_commands = Vec::new();
1410 let mut stream = storage.export_user_commands(None).await;
1411 while let Some(Ok(cmd)) = stream.next().await {
1412 exported_commands.push(cmd);
1413 }
1414
1415 assert_eq!(exported_commands.len(), 7, "Expected 7 user commands to be exported");
1416 }
1417
1418 #[tokio::test]
1419 async fn test_export_user_commands_with_filter() {
1420 let storage = setup_ranking_storage().await;
1421 let filter = Regex::new(r"^git").unwrap(); let mut exported_commands = Vec::new();
1423 let mut stream = storage.export_user_commands(Some(filter)).await;
1424 while let Some(Ok(cmd)) = stream.next().await {
1425 exported_commands.push(cmd);
1426 }
1427
1428 assert_eq!(exported_commands.len(), 3, "Expected 3 git commands to be exported");
1429
1430 let exported_cmd_values: Vec<String> = exported_commands.into_iter().map(|c| c.cmd).collect();
1431 assert!(exported_cmd_values.contains(&"git status".to_string()));
1432 assert!(exported_cmd_values.contains(&"git checkout main".to_string()));
1433 }
1434
1435 #[tokio::test]
1436 async fn test_delete_tldr_commands() {
1437 let storage = SqliteStorage::new_in_memory().await.unwrap();
1438
1439 let tldr_cmd1 = Command {
1441 id: Uuid::now_v7(),
1442 category: "git".to_string(),
1443 source: SOURCE_TLDR.to_string(),
1444 cmd: "git status".to_string(),
1445 ..Default::default()
1446 };
1447 let tldr_cmd2 = Command {
1448 id: Uuid::now_v7(),
1449 category: "docker".to_string(),
1450 source: SOURCE_TLDR.to_string(),
1451 cmd: "docker ps".to_string(),
1452 ..Default::default()
1453 };
1454 let user_cmd = Command {
1455 id: Uuid::now_v7(),
1456 category: "git".to_string(),
1457 source: SOURCE_USER.to_string(),
1458 cmd: "git log".to_string(),
1459 ..Default::default()
1460 };
1461
1462 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1463 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1464 storage.insert_command(user_cmd.clone()).await.unwrap();
1465
1466 let removed = storage.delete_tldr_commands(None).await.unwrap();
1468 assert_eq!(removed, 2, "Should remove both tldr commands");
1469
1470 let (remaining, _) = storage
1471 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1472 .await
1473 .unwrap();
1474 assert_eq!(remaining.len(), 1, "Only user command should remain");
1475 assert_eq!(remaining[0].cmd, user_cmd.cmd);
1476
1477 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1479 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1480
1481 let removed_git = storage.delete_tldr_commands(Some("git".to_string())).await.unwrap();
1483 assert_eq!(removed_git, 1, "Should remove one tldr command in 'git' category");
1484
1485 let (remaining, _) = storage
1486 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1487 .await
1488 .unwrap();
1489 let remaining_cmds: Vec<_> = remaining.iter().map(|c| &c.cmd).collect();
1490 assert!(remaining_cmds.contains(&&tldr_cmd2.cmd));
1491 assert!(remaining_cmds.contains(&&user_cmd.cmd));
1492 assert!(!remaining_cmds.contains(&&tldr_cmd1.cmd));
1493 }
1494
1495 #[tokio::test]
1496 async fn test_insert_command() {
1497 let storage = SqliteStorage::new_in_memory().await.unwrap();
1498
1499 let mut cmd = Command {
1500 id: Uuid::now_v7(),
1501 category: "test".to_string(),
1502 cmd: "test_cmd".to_string(),
1503 description: Some("test desc".to_string()),
1504 tags: Some(vec!["tag1".to_string()]),
1505 ..Default::default()
1506 };
1507
1508 let mut inserted = storage.insert_command(cmd.clone()).await.unwrap();
1509 assert_eq!(inserted.cmd, cmd.cmd);
1510
1511 inserted.cmd = "other_cmd".to_string();
1513 match storage.insert_command(inserted).await {
1514 Err(InsertError::AlreadyExists) => (),
1515 _ => panic!("Expected AlreadyExists error on duplicate id"),
1516 }
1517
1518 cmd.id = Uuid::now_v7();
1520 match storage.insert_command(cmd).await {
1521 Err(InsertError::AlreadyExists) => (),
1522 _ => panic!("Expected AlreadyExists error on duplicate cmd"),
1523 }
1524 }
1525
1526 #[tokio::test]
1527 async fn test_update_command() {
1528 let storage = SqliteStorage::new_in_memory().await.unwrap();
1529
1530 let cmd = Command {
1531 id: Uuid::now_v7(),
1532 cmd: "original".to_string(),
1533 description: Some("desc".to_string()),
1534 ..Default::default()
1535 };
1536
1537 storage.insert_command(cmd.clone()).await.unwrap();
1538
1539 let mut updated = cmd.clone();
1540 updated.cmd = "updated".to_string();
1541 updated.description = Some("new desc".to_string());
1542
1543 let result = storage.update_command(updated.clone()).await.unwrap();
1544 assert_eq!(result.cmd, "updated");
1545 assert_eq!(result.description, Some("new desc".to_string()));
1546
1547 let mut non_existent = cmd;
1549 non_existent.id = Uuid::now_v7();
1550 match storage.update_command(non_existent).await {
1551 Err(_) => (),
1552 _ => panic!("Expected error when updating non-existent command"),
1553 }
1554
1555 let another_cmd = Command {
1557 id: Uuid::now_v7(),
1558 cmd: "another".to_string(),
1559 ..Default::default()
1560 };
1561 let mut result = storage.insert_command(another_cmd.clone()).await.unwrap();
1562 result.cmd = "updated".to_string();
1563 match storage.update_command(result).await {
1564 Err(UpdateError::AlreadyExists) => (),
1565 _ => panic!("Expected AlreadyExists error when updating to existing cmd"),
1566 }
1567 }
1568
1569 #[tokio::test]
1570 async fn test_increment_command_usage() {
1571 let storage = SqliteStorage::new_in_memory().await.unwrap();
1572
1573 let command = storage
1575 .setup_command(
1576 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
1577 [("/some/path", 100)],
1578 )
1579 .await;
1580
1581 let count = storage.increment_command_usage(command.id, "/path").await.unwrap();
1583 assert_eq!(count, 1);
1584
1585 let count = storage.increment_command_usage(command.id, "/some/path").await.unwrap();
1587 assert_eq!(count, 101);
1588 }
1589
1590 #[tokio::test]
1591 async fn test_delete_command() {
1592 let storage = SqliteStorage::new_in_memory().await.unwrap();
1593
1594 let cmd = Command {
1595 id: Uuid::now_v7(),
1596 cmd: "to_delete".to_string(),
1597 ..Default::default()
1598 };
1599
1600 let cmd = storage.insert_command(cmd).await.unwrap();
1601 let res = storage.delete_command(cmd.id).await;
1602 assert!(res.is_ok());
1603
1604 match storage.delete_command(cmd.id).await {
1606 Err(_) => (),
1607 _ => panic!("Expected error when deleting non-existent command"),
1608 }
1609 }
1610
1611 async fn setup_ranking_storage() -> SqliteStorage {
1613 let storage = SqliteStorage::new_in_memory().await.unwrap();
1614 storage
1615 .setup_command(
1616 Command::new(
1617 CATEGORY_USER,
1618 SOURCE_USER,
1619 "kubectl get pod -n monitoring my-specific-pod-12345",
1620 )
1621 .with_description(Some(
1622 "Get a very specific pod by its full name in the monitoring namespace".to_string(),
1623 ))
1624 .with_tags(Some(vec!["#k8s".to_string(), "#pod".to_string()])),
1625 [("/other/path", 1)],
1626 )
1627 .await;
1628 storage
1629 .setup_command(
1630 Command::new(CATEGORY_USER, SOURCE_USER, "git status")
1631 .with_description(Some("Check the status of the git repository".to_string()))
1632 .with_tags(Some(vec!["#git".to_string()])),
1633 [(PROJ_A_PATH, 50), (PROJ_B_PATH, 50), (UNRELATED_PATH, 100)],
1634 )
1635 .await;
1636 storage
1637 .setup_command(
1638 Command::new(CATEGORY_USER, SOURCE_USER, "npm run build:prod")
1639 .with_description(Some("Build the project for production".to_string()))
1640 .with_tags(Some(vec!["#npm".to_string(), "#build".to_string()])),
1641 [(PROJ_A_API_PATH, 25)],
1642 )
1643 .await;
1644 storage
1645 .setup_command(
1646 Command::new(CATEGORY_USER, SOURCE_USER, "container-image-build.sh")
1647 .with_description(Some("A generic script to build a container image".to_string()))
1648 .with_tags(Some(vec!["#docker".to_string(), "#build".to_string()])),
1649 [(UNRELATED_PATH, 35)],
1650 )
1651 .await;
1652 storage
1653 .setup_command(
1654 Command::new(CATEGORY_USER, SOURCE_USER, "git commit -m '{{message}}'")
1655 .with_description(Some("Commit with a message".to_string()))
1656 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1657 [(PROJ_A_PATH, 10), (PROJ_B_PATH, 10)],
1658 )
1659 .await;
1660 storage
1661 .setup_command(
1662 Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")
1663 .with_alias(Some("gco".to_string()))
1664 .with_description(Some("Checkout the main branch".to_string()))
1665 .with_tags(Some(vec!["#git".to_string()])),
1666 [(PROJ_A_PATH, 30), (PROJ_B_PATH, 30)],
1667 )
1668 .await;
1669 storage
1670 .setup_command(
1671 Command::new("git", SOURCE_TLDR, "git commit -m")
1672 .with_alias(Some("gc".to_string()))
1673 .with_description(Some("Commit changes".to_string()))
1674 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1675 [(PROJ_A_PATH, 15)],
1676 )
1677 .await;
1678 storage
1679 .setup_command(
1680 Command::new("docker", SOURCE_TLDR, "docker ps -a")
1681 .with_description(Some("List all containers".to_string()))
1682 .with_tags(Some(vec!["#docker".to_string(), "#list".to_string()])),
1683 [(PROJ_A_PATH, 5), (PROJ_B_PATH, 5)],
1684 )
1685 .await;
1686 storage
1687 .setup_command(
1688 Command::new("git", SOURCE_TLDR, "git push")
1689 .with_description(Some("Push changes".to_string()))
1690 .with_tags(Some(vec!["#git".to_string(), "#push".to_string()])),
1691 [(PROJ_A_PATH, 20), (PROJ_B_PATH, 20)],
1692 )
1693 .await;
1694 storage
1695 .setup_command(
1696 Command::new(CATEGORY_USER, SOURCE_IMPORT, "ls -lha")
1697 .with_description(Some("List files".to_string()))
1698 .with_tags(Some(vec!["#unix".to_string(), "#list".to_string()])),
1699 [(PROJ_A_PATH, 100), (PROJ_B_PATH, 100), (UNRELATED_PATH, 100)],
1700 )
1701 .await;
1702
1703 storage
1704 }
1705
1706 impl SqliteStorage {
1707 async fn check_sqlite_version(&self) {
1709 let version: String = self
1710 .client
1711 .conn_mut::<_, _, Report>(|conn| {
1712 conn.query_row("SELECT sqlite_version()", [], |row| row.get(0))
1713 .map_err(Into::into)
1714 })
1715 .await
1716 .unwrap();
1717 println!("Running with SQLite version: {version}");
1718 }
1719
1720 async fn setup_command(
1723 &self,
1724 command: Command,
1725 usage: impl IntoIterator<Item = (&str, i32)> + Send + 'static,
1726 ) -> Command {
1727 let command = self.insert_command(command).await.unwrap();
1728 self.client
1729 .conn_mut::<_, _, Report>(move |conn| {
1730 for (path, usage_count) in usage {
1731 conn.execute(
1732 r#"
1733 INSERT INTO command_usage (command_id, path, usage_count)
1734 VALUES (?1, ?2, ?3)
1735 ON CONFLICT(command_id, path) DO UPDATE SET
1736 usage_count = excluded.usage_count"#,
1737 (&command.id, path, usage_count),
1738 )?;
1739 }
1740 Ok(command)
1741 })
1742 .await
1743 .unwrap()
1744 }
1745 }
1746}