1use std::{cmp::Ordering, pin::pin, sync::atomic::Ordering as AtomicOrdering};
2
3use chrono::{DateTime, Utc};
4use color_eyre::{Report, eyre::eyre};
5use futures_util::StreamExt;
6use regex::Regex;
7use rusqlite::{Row, fallible_iterator::FallibleIterator, ffi, types::Type};
8use sea_query::SqliteQueryBuilder;
9use sea_query_rusqlite::RusqliteBinder;
10use tokio::sync::mpsc;
11use tokio_stream::{Stream, wrappers::ReceiverStream};
12use tracing::instrument;
13use uuid::Uuid;
14
15use super::{SqliteStorage, queries::*};
16use crate::{
17 config::SearchCommandTuning,
18 errors::{AppError, Result, UserFacingError},
19 model::{CATEGORY_USER, Command, SOURCE_TLDR, SearchCommandsFilter},
20};
21
22impl SqliteStorage {
23 #[instrument(skip_all)]
26 pub async fn setup_workspace_storage(&self) -> Result<()> {
27 self.client
28 .conn_mut(|conn| {
29 let schemas: Vec<String> = conn
31 .prepare(
32 r"SELECT sql
33 FROM sqlite_master
34 WHERE (type = 'table' AND name = 'command')
35 OR (type = 'table' AND name LIKE 'command_%fts')
36 OR (type = 'trigger' AND name LIKE 'command_%_fts' AND tbl_name = 'command')",
37 )?
38 .query_map([], |row| row.get(0))?
39 .collect::<Result<Vec<String>, _>>()?;
40
41 let tx = conn.transaction()?;
42
43 for schema in schemas {
45 let temp_schema = schema
46 .replace("command", "workspace_command")
47 .replace("CREATE TABLE", "CREATE TEMP TABLE")
48 .replace("CREATE VIRTUAL TABLE ", "CREATE VIRTUAL TABLE temp.")
49 .replace("CREATE TRIGGER", "CREATE TEMP TRIGGER");
50 tx.execute(&temp_schema, [])?;
51 }
52
53 tx.commit()?;
54 Ok(())
55 })
56 .await?;
57
58 self.workspace_tables_loaded.store(true, AtomicOrdering::SeqCst);
59
60 Ok(())
61 }
62
63 #[instrument(skip_all)]
65 pub async fn is_empty(&self) -> Result<bool> {
66 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
67 self.client
68 .conn(move |conn| {
69 if workspace_tables_loaded {
70 Ok(conn.query_row(
71 "SELECT NOT EXISTS (SELECT 1 FROM command UNION ALL SELECT 1 FROM workspace_command)",
72 [],
73 |r| r.get(0),
74 )?)
75 } else {
76 Ok(conn.query_row("SELECT NOT EXISTS(SELECT 1 FROM command)", [], |r| r.get(0))?)
77 }
78 })
79 .await
80 }
81
82 #[instrument(skip_all)]
84 pub async fn find_tags(
85 &self,
86 filter: SearchCommandsFilter,
87 tag_prefix: Option<String>,
88 tuning: &SearchCommandTuning,
89 ) -> Result<Vec<(String, u64, bool)>> {
90 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
91 let query = query_find_tags(filter, tag_prefix, tuning, workspace_tables_loaded)?;
92 if tracing::enabled!(tracing::Level::TRACE) {
93 tracing::trace!("Querying tags:\n{}", query.to_string(SqliteQueryBuilder));
94 }
95 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
96 self.client
97 .conn(move |conn| {
98 conn.prepare(&stmt)?
99 .query(&*values.as_params())?
100 .and_then(|r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))
101 .collect()
102 })
103 .await
104 }
105
106 #[instrument(skip_all)]
111 pub async fn find_commands(
112 &self,
113 filter: SearchCommandsFilter,
114 working_path: impl Into<String>,
115 tuning: &SearchCommandTuning,
116 ) -> Result<(Vec<Command>, bool)> {
117 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
118 let cleaned_filter = filter.cleaned();
119
120 let mut query_alias = None;
122 if let Some(ref term) = cleaned_filter.search_term {
123 if workspace_tables_loaded {
125 query_alias = Some((
126 format!(
127 r#"SELECT *
128 FROM (
129 SELECT rowid, * FROM workspace_command
130 UNION ALL
131 SELECT rowid, * FROM command
132 ) c
133 WHERE c.alias IS NOT NULL AND c.alias = ?1
134 LIMIT {QUERY_LIMIT}"#
135 ),
136 (term.clone(),),
137 ));
138 } else {
139 query_alias = Some((
140 format!(
141 r#"SELECT c.rowid, c.*
142 FROM command c
143 WHERE c.alias IS NOT NULL AND c.alias = ?1
144 LIMIT {QUERY_LIMIT}"#
145 ),
146 (term.clone(),),
147 ));
148 }
149 }
150
151 let query = query_find_commands(cleaned_filter, working_path, tuning, workspace_tables_loaded)?;
153 let query_trace = if tracing::enabled!(tracing::Level::TRACE) {
154 query.to_string(SqliteQueryBuilder)
155 } else {
156 String::default()
157 };
158 let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
159
160 let tuning = *tuning;
162 self.client
163 .conn(move |conn| {
164 if let Some((query_alias, a_params)) = query_alias {
166 let rows = conn
168 .prepare(&query_alias)?
169 .query(a_params)?
170 .map(|r| Command::try_from(r))
171 .collect::<Vec<_>>()?;
172 if !rows.is_empty() {
174 return Ok((rows, true));
175 }
176 }
177 if tracing::enabled!(tracing::Level::TRACE) {
179 tracing::trace!("Querying commands:\n{query_trace}");
180 }
181 Ok((
182 rerank_query_results(
183 conn.prepare(&stmt)?
184 .query(&*values.as_params())?
185 .and_then(|r| QueryResultItem::try_from(r))
186 .collect::<Result<Vec<_>, _>>()?,
187 &tuning,
188 ),
189 false,
190 ))
191 })
192 .await
193 }
194
195 #[instrument(skip_all)]
202 pub async fn import_commands(
203 &self,
204 commands: impl Stream<Item = Result<Command>> + Send + 'static,
205 overwrite: bool,
206 workspace: bool,
207 ) -> Result<(u64, u64)> {
208 let (tx, mut rx) = mpsc::channel(100);
210
211 tokio::spawn(async move {
213 let mut commands = pin!(commands);
215 while let Some(command_result) = commands.next().await {
216 if tx.send(command_result).await.is_err() {
217 tracing::debug!("Import stream channel closed by receiver");
219 break;
220 }
221 }
222 });
223
224 let table = if workspace { "workspace_command" } else { "command" };
226
227 self.client
228 .conn_mut(move |conn| {
229 let mut inserted = 0;
230 let mut skipped_or_updated = 0;
231 let tx = conn.transaction()?;
232 let mut stmt = if overwrite {
233 tx.prepare(&format!(
234 r#"INSERT INTO {table} (
235 id,
236 category,
237 source,
238 alias,
239 cmd,
240 flat_cmd,
241 description,
242 flat_description,
243 tags,
244 created_at
245 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
246 ON CONFLICT (cmd) DO UPDATE SET
247 alias = COALESCE(excluded.alias, alias),
248 cmd = excluded.cmd,
249 flat_cmd = excluded.flat_cmd,
250 description = COALESCE(excluded.description, description),
251 flat_description = COALESCE(excluded.flat_description, flat_description),
252 tags = COALESCE(excluded.tags, tags),
253 updated_at = excluded.created_at
254 RETURNING updated_at;"#
255 ))?
256 } else {
257 tx.prepare(&format!(
258 r#"INSERT OR IGNORE INTO {table} (
259 id,
260 category,
261 source,
262 alias,
263 cmd,
264 flat_cmd,
265 description,
266 flat_description,
267 tags,
268 created_at
269 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
270 RETURNING updated_at;"#,
271 ))?
272 };
273
274 while let Some(command_result) = rx.blocking_recv() {
276 let command = command_result?;
277
278 let mut rows = stmt.query((
279 &command.id,
280 &command.category,
281 &command.source,
282 &command.alias,
283 &command.cmd,
284 &command.flat_cmd,
285 &command.description,
286 &command.flat_description,
287 serde_json::to_value(&command.tags)?,
288 &command.created_at,
289 ))?;
290
291 match rows.next()? {
292 None => skipped_or_updated += 1,
294 Some(r) => {
296 let updated_at = r.get::<_, Option<DateTime<Utc>>>(0)?;
297 match updated_at {
298 None => inserted += 1,
300 Some(_) => skipped_or_updated += 1,
302 }
303 }
304 }
305 }
306
307 drop(stmt);
308 tx.commit()?;
309 Ok((inserted, skipped_or_updated))
310 })
311 .await
312 }
313
314 #[instrument(skip_all)]
316 pub async fn export_user_commands(
317 &self,
318 filter: Option<Regex>,
319 ) -> impl Stream<Item = Result<Command>> + Send + 'static {
320 let (tx, rx) = mpsc::channel(100);
322
323 let client = self.client.clone();
325 tokio::spawn(async move {
326 let res = client
327 .conn_mut(move |conn| {
328 let mut q_values = vec![CATEGORY_USER.to_owned()];
330 let mut query = String::from(
331 r"SELECT
332 rowid,
333 id,
334 category,
335 source,
336 alias,
337 cmd,
338 flat_cmd,
339 description,
340 flat_description,
341 tags,
342 created_at,
343 updated_at
344 FROM command
345 WHERE category = ?1",
346 );
347 if let Some(filter) = filter {
348 q_values.push(filter.as_str().to_owned());
349 query.push_str(" AND (cmd REGEXP ?2 OR (description IS NOT NULL AND description REGEXP ?2))");
350 }
351 query.push_str("\nORDER BY cmd ASC");
352
353 let mut stmt = conn.prepare(&query)?;
355 let records_iter =
356 stmt.query_and_then(rusqlite::params_from_iter(q_values), |r| Command::try_from(r))?;
357
358 for record_result in records_iter {
360 if tx.blocking_send(record_result.map_err(AppError::from)).is_err() {
361 tracing::debug!("Async stream receiver dropped, closing db query");
362 break;
363 }
364 }
365
366 Ok(())
367 })
368 .await;
369 if let Err(err) = res {
370 panic!("Couldn't fetch commands to export: {err:?}");
371 }
372 });
373
374 ReceiverStream::new(rx)
376 }
377
378 #[instrument(skip_all)]
380 pub async fn delete_tldr_commands(&self, category: Option<String>) -> Result<u64> {
381 self.client
382 .conn_mut(move |conn| {
383 let mut query = String::from("DELETE FROM command WHERE source = ?1");
384 let mut params: Vec<String> = vec![SOURCE_TLDR.to_owned()];
385 if let Some(cat) = category {
386 query.push_str(" AND category = ?2");
387 params.push(cat);
388 }
389 let affected = conn.execute(&query, rusqlite::params_from_iter(params))?;
390 Ok(affected as u64)
391 })
392 .await
393 }
394
395 #[instrument(skip_all)]
399 pub async fn insert_command(&self, command: Command) -> Result<Command> {
400 self.client
401 .conn(move |conn| {
402 let res = conn.execute(
403 r#"INSERT INTO command (
404 id,
405 category,
406 source,
407 alias,
408 cmd,
409 flat_cmd,
410 description,
411 flat_description,
412 tags,
413 created_at,
414 updated_at
415 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)"#,
416 (
417 &command.id,
418 &command.category,
419 &command.source,
420 &command.alias,
421 &command.cmd,
422 &command.flat_cmd,
423 &command.description,
424 &command.flat_description,
425 serde_json::to_value(&command.tags)?,
426 &command.created_at,
427 &command.updated_at,
428 ),
429 );
430 match res {
431 Ok(_) => Ok(command),
432 Err(err) => {
433 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
434 if code == ffi::SQLITE_CONSTRAINT_UNIQUE || code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY {
435 Err(UserFacingError::CommandAlreadyExists.into())
436 } else {
437 Err(Report::from(err).into())
438 }
439 }
440 }
441 })
442 .await
443 }
444
445 #[instrument(skip_all)]
449 pub async fn update_command(&self, command: Command) -> Result<Command> {
450 self.client
451 .conn(move |conn| {
452 let res = conn.execute(
453 r#"UPDATE command SET
454 category = ?2,
455 source = ?3,
456 alias = ?4,
457 cmd = ?5,
458 flat_cmd = ?6,
459 description = ?7,
460 flat_description = ?8,
461 tags = ?9,
462 created_at = ?10,
463 updated_at = ?11
464 WHERE id = ?1"#,
465 (
466 &command.id,
467 &command.category,
468 &command.source,
469 &command.alias,
470 &command.cmd,
471 &command.flat_cmd,
472 &command.description,
473 &command.flat_description,
474 serde_json::to_value(&command.tags)?,
475 &command.created_at,
476 &command.updated_at,
477 ),
478 );
479 match res {
480 Ok(0) => Err(eyre!("Command not found: {}", command.id).into()),
481 Ok(_) => Ok(command),
482 Err(err) => {
483 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
484 if code == ffi::SQLITE_CONSTRAINT_UNIQUE {
485 Err(UserFacingError::CommandAlreadyExists.into())
486 } else {
487 Err(Report::from(err).into())
488 }
489 }
490 }
491 })
492 .await
493 }
494
495 #[instrument(skip_all)]
497 pub async fn increment_command_usage(
498 &self,
499 command_id: Uuid,
500 path: impl AsRef<str> + Send + 'static,
501 ) -> Result<i32> {
502 self.client
503 .conn_mut(move |conn| {
504 Ok(conn.query_row(
505 r#"
506 INSERT INTO command_usage (command_id, path, usage_count)
507 VALUES (?1, ?2, 1)
508 ON CONFLICT(command_id, path) DO UPDATE SET
509 usage_count = usage_count + 1
510 RETURNING usage_count;"#,
511 (&command_id, &path.as_ref()),
512 |r| r.get(0),
513 )?)
514 })
515 .await
516 }
517
518 #[instrument(skip_all)]
522 pub async fn delete_command(&self, command_id: Uuid) -> Result<()> {
523 self.client
524 .conn(move |conn| {
525 let res = conn.execute("DELETE FROM command WHERE id = ?1", (&command_id,));
526 match res {
527 Ok(0) => Err(eyre!("Command not found: {command_id}").into()),
528 Ok(_) => Ok(()),
529 Err(err) => Err(Report::from(err).into()),
530 }
531 })
532 .await
533 }
534}
535
536fn rerank_query_results(items: Vec<QueryResultItem>, tuning: &SearchCommandTuning) -> Vec<Command> {
545 if items.is_empty() {
547 return Vec::new();
548 }
549 if items.len() == 1 {
550 return items.into_iter().map(|item| item.command).collect();
551 }
552
553 let (template_matches, mut other_items): (Vec<_>, Vec<_>) = items
556 .into_iter()
557 .partition(|item| item.text_score >= TEMPLATE_MATCH_RANK);
558 if !template_matches.is_empty() {
559 tracing::trace!("Found {} template matches", template_matches.len());
560 }
561
562 let mut final_commands: Vec<Command> = template_matches.into_iter().map(|item| item.command).collect();
564
565 if other_items.len() <= 1 {
567 final_commands.extend(other_items.into_iter().map(|item| item.command));
568 return final_commands;
569 }
570
571 let mut min_text = 0f64;
574 let mut min_path = 0f64;
575 let mut min_usage = 0f64;
576 let mut max_text = f64::NEG_INFINITY;
577 let mut max_path = f64::NEG_INFINITY;
578 let mut max_usage = f64::NEG_INFINITY;
579 for item in &other_items {
580 min_text = min_text.min(item.text_score);
581 max_text = max_text.max(item.text_score);
582 min_path = min_path.min(item.path_score);
583 max_path = max_path.max(item.path_score);
584 min_usage = min_usage.min(item.usage_score);
585 max_usage = max_usage.max(item.usage_score);
586 }
587
588 let range_text = (max_text > min_text).then_some(max_text - min_text);
590 let range_path = (max_path > min_path).then_some(max_path - min_path);
591 let range_usage = (max_usage > min_usage).then_some(max_usage - min_usage);
592
593 other_items.sort_by(|a, b| {
595 match b.is_workspace_command.cmp(&a.is_workspace_command) {
597 Ordering::Equal => {
598 let calculate_score = |item: &QueryResultItem| -> f64 {
600 let norm_text = range_text.map_or(0.5, |range| (item.text_score - min_text) / range);
603 let norm_path = range_path.map_or(0.5, |range| (item.path_score - min_path) / range);
604 let norm_usage = range_usage.map_or(0.5, |range| (item.usage_score - min_usage) / range);
605
606 (norm_text * tuning.text.points as f64)
608 + (norm_path * tuning.path.points as f64)
609 + (norm_usage * tuning.usage.points as f64)
610 };
611
612 let final_score_a = calculate_score(a);
613 let final_score_b = calculate_score(b);
614
615 final_score_b.partial_cmp(&final_score_a).unwrap_or(Ordering::Equal)
617 }
618 other => other,
620 }
621 });
622
623 final_commands.extend(other_items.into_iter().map(|item| item.command));
625 final_commands
626}
627
628impl<'a> TryFrom<&'a Row<'a>> for Command {
629 type Error = rusqlite::Error;
630
631 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
632 Ok(Self {
633 id: row.get(1)?,
635 category: row.get(2)?,
636 source: row.get(3)?,
637 alias: row.get(4)?,
638 cmd: row.get(5)?,
639 flat_cmd: row.get(6)?,
640 description: row.get(7)?,
641 flat_description: row.get(8)?,
642 tags: serde_json::from_value(row.get::<_, serde_json::Value>(9)?)
643 .map_err(|e| rusqlite::Error::FromSqlConversionFailure(9, Type::Text, Box::new(e)))?,
644 created_at: row.get(10)?,
645 updated_at: row.get(11)?,
646 })
647 }
648}
649
650struct QueryResultItem {
652 command: Command,
654 is_workspace_command: bool,
656 usage_score: f64,
658 path_score: f64,
660 text_score: f64,
662}
663
664impl<'a> TryFrom<&'a Row<'a>> for QueryResultItem {
665 type Error = rusqlite::Error;
666
667 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
668 Ok(Self {
669 command: Command::try_from(row)?,
670 is_workspace_command: row.get(12)?,
671 usage_score: row.get(13)?,
672 path_score: row.get(14)?,
673 text_score: row.get(15)?,
674 })
675 }
676}
677
678#[cfg(test)]
679mod tests {
680 use futures_util::StreamExt;
681 use pretty_assertions::assert_eq;
682 use strum::IntoEnumIterator;
683 use tokio_stream::iter;
684 use uuid::Uuid;
685
686 use super::*;
687 use crate::{
688 errors::AppError,
689 model::{CATEGORY_USER, SOURCE_IMPORT, SOURCE_USER, SearchMode},
690 };
691
692 const PROJ_A_PATH: &str = "/home/user/project-a";
693 const PROJ_A_API_PATH: &str = "/home/user/project-a/api";
694 const PROJ_B_PATH: &str = "/home/user/project-b";
695 const UNRELATED_PATH: &str = "/var/log";
696
697 #[tokio::test]
698 async fn test_setup_workspace_storage() {
699 let storage = SqliteStorage::new_in_memory().await.unwrap();
700 storage.check_sqlite_version().await;
701 let res = storage.setup_workspace_storage().await;
702 assert!(res.is_ok(), "Expected workspace storage setup to succeed: {res:?}");
703 }
704
705 #[tokio::test]
706 async fn test_is_empty() {
707 let storage = SqliteStorage::new_in_memory().await.unwrap();
708 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
709
710 let cmd = Command {
711 id: Uuid::now_v7(),
712 cmd: "test_cmd".to_string(),
713 ..Default::default()
714 };
715 storage.insert_command(cmd).await.unwrap();
716
717 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
718 }
719
720 #[tokio::test]
721 async fn test_is_empty_with_workspace() {
722 let storage = SqliteStorage::new_in_memory().await.unwrap();
723 storage.setup_workspace_storage().await.unwrap();
724 assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
725
726 let cmd = Command {
727 id: Uuid::now_v7(),
728 cmd: "test_cmd".to_string(),
729 ..Default::default()
730 };
731 storage.insert_command(cmd).await.unwrap();
732
733 assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
734 }
735
736 #[tokio::test]
737 async fn test_find_tags_no_filters() -> Result<()> {
738 let storage = setup_ranking_storage().await;
739
740 let result = storage
741 .find_tags(SearchCommandsFilter::default(), None, &SearchCommandTuning::default())
742 .await?;
743
744 let expected = vec![
745 ("#git".to_string(), 5, false),
746 ("#build".to_string(), 2, false),
747 ("#commit".to_string(), 2, false),
748 ("#docker".to_string(), 2, false),
749 ("#list".to_string(), 2, false),
750 ("#k8s".to_string(), 1, false),
751 ("#npm".to_string(), 1, false),
752 ("#pod".to_string(), 1, false),
753 ("#push".to_string(), 1, false),
754 ("#unix".to_string(), 1, false),
755 ];
756
757 assert_eq!(result.len(), 10, "Expected 10 unique tags");
758 assert_eq!(result, expected, "Tags list or order mismatch");
759
760 Ok(())
761 }
762
763 #[tokio::test]
764 async fn test_find_tags_filter_by_tags_only() -> Result<()> {
765 let storage = setup_ranking_storage().await;
766
767 let filter1 = SearchCommandsFilter {
768 tags: Some(vec!["#git".to_string()]),
769 ..Default::default()
770 };
771 let result1 = storage
772 .find_tags(filter1, None, &SearchCommandTuning::default())
773 .await?;
774 let expected1 = vec![("#commit".to_string(), 2, false), ("#push".to_string(), 1, false)];
775 assert_eq!(result1.len(), 2,);
776 assert_eq!(result1, expected1);
777
778 let filter2 = SearchCommandsFilter {
779 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
780 ..Default::default()
781 };
782 let result2 = storage
783 .find_tags(filter2, None, &SearchCommandTuning::default())
784 .await?;
785 assert!(result2.is_empty());
786
787 let filter3 = SearchCommandsFilter {
788 tags: Some(vec!["#list".to_string()]),
789 ..Default::default()
790 };
791 let result3 = storage
792 .find_tags(filter3, None, &SearchCommandTuning::default())
793 .await?;
794 let expected3 = vec![("#docker".to_string(), 1, false), ("#unix".to_string(), 1, false)];
795 assert_eq!(result3.len(), 2);
796 assert_eq!(result3, expected3);
797
798 Ok(())
799 }
800
801 #[tokio::test]
802 async fn test_find_tags_filter_by_prefix_only() -> Result<()> {
803 let storage = setup_ranking_storage().await;
804
805 let result = storage
806 .find_tags(
807 SearchCommandsFilter::default(),
808 Some("#comm".to_string()),
809 &SearchCommandTuning::default(),
810 )
811 .await?;
812 let expected = vec![("#commit".to_string(), 2, false)];
813 assert_eq!(result.len(), 1);
814 assert_eq!(result, expected);
815
816 Ok(())
817 }
818
819 #[tokio::test]
820 async fn test_find_tags_filter_by_tags_and_prefix() -> Result<()> {
821 let storage = setup_ranking_storage().await;
822
823 let filter1 = SearchCommandsFilter {
824 tags: Some(vec!["#git".to_string()]),
825 ..Default::default()
826 };
827 let result1 = storage
828 .find_tags(filter1, Some("#comm".to_string()), &SearchCommandTuning::default())
829 .await?;
830 let expected1 = vec![("#commit".to_string(), 2, false)];
831 assert_eq!(result1.len(), 1);
832 assert_eq!(result1, expected1);
833
834 let filter2 = SearchCommandsFilter {
835 tags: Some(vec!["#git".to_string()]),
836 ..Default::default()
837 };
838 let result2 = storage
839 .find_tags(filter2, Some("#push".to_string()), &SearchCommandTuning::default())
840 .await?;
841 let expected2 = vec![("#push".to_string(), 1, true)];
842 assert_eq!(result2.len(), 1);
843 assert_eq!(result2, expected2);
844
845 Ok(())
846 }
847
848 #[tokio::test]
849 async fn test_find_commands_no_filter() {
850 let storage = setup_ranking_storage().await;
851 let filter = SearchCommandsFilter::default();
852 let (commands, _) = storage
853 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
854 .await
855 .unwrap();
856 assert_eq!(commands.len(), 10, "Expected all sample commands");
857 }
858
859 #[tokio::test]
860 async fn test_find_commands_filter_by_category() {
861 let storage = setup_ranking_storage().await;
862 let filter = SearchCommandsFilter {
863 category: Some(vec!["git".to_string()]),
864 ..Default::default()
865 };
866 let (commands, _) = storage
867 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
868 .await
869 .unwrap();
870 assert_eq!(commands.len(), 2);
871 assert!(commands.iter().all(|c| c.category == "git"));
872
873 let filter_no_match = SearchCommandsFilter {
874 category: Some(vec!["nonexistent".to_string()]),
875 ..Default::default()
876 };
877 let (commands_no_match, _) = storage
878 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
879 .await
880 .unwrap();
881 assert!(commands_no_match.is_empty());
882 }
883
884 #[tokio::test]
885 async fn test_find_commands_filter_by_source() {
886 let storage = setup_ranking_storage().await;
887 let filter = SearchCommandsFilter {
888 source: Some(SOURCE_TLDR.to_string()),
889 ..Default::default()
890 };
891 let (commands, _) = storage
892 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
893 .await
894 .unwrap();
895 assert_eq!(commands.len(), 3);
896 assert!(commands.iter().all(|c| c.source == SOURCE_TLDR));
897 }
898
899 #[tokio::test]
900 async fn test_find_commands_filter_by_tags() {
901 let storage = setup_ranking_storage().await;
902 let filter_single_tag = SearchCommandsFilter {
903 tags: Some(vec!["#git".to_string()]),
904 ..Default::default()
905 };
906 let (commands_single_tag, _) = storage
907 .find_commands(filter_single_tag, "/some/path", &SearchCommandTuning::default())
908 .await
909 .unwrap();
910 assert_eq!(commands_single_tag.len(), 5);
911
912 let filter_multiple_tags = SearchCommandsFilter {
913 tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
914 ..Default::default()
915 };
916 let (commands_multiple_tags, _) = storage
917 .find_commands(filter_multiple_tags, "/some/path", &SearchCommandTuning::default())
918 .await
919 .unwrap();
920 assert_eq!(commands_multiple_tags.len(), 1);
921
922 let filter_empty_tags = SearchCommandsFilter {
923 tags: Some(vec![]),
924 ..Default::default()
925 };
926 let (commands_empty_tags, _) = storage
927 .find_commands(filter_empty_tags, "/some/path", &SearchCommandTuning::default())
928 .await
929 .unwrap();
930 assert_eq!(commands_empty_tags.len(), 10);
931 }
932
933 #[tokio::test]
934 async fn test_find_commands_alias_precedence() {
935 let storage = setup_ranking_storage().await;
936 storage
937 .setup_command(
938 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
939 [("/some/path", 100)],
940 )
941 .await;
942
943 for mode in SearchMode::iter() {
944 let filter = SearchCommandsFilter {
945 search_term: Some("gc".to_string()),
946 search_mode: mode,
947 ..Default::default()
948 };
949 let (commands, alias_match) = storage
950 .find_commands(filter, "", &SearchCommandTuning::default())
951 .await
952 .unwrap();
953 assert!(alias_match, "Expected alias match for mode {mode:?}");
954 assert_eq!(commands.len(), 1, "Expected only alias match for mode {mode:?}");
955 assert_eq!(
956 commands[0].cmd, "git commit -m",
957 "Expected correct alias command for mode {mode:?}"
958 );
959 }
960 }
961
962 #[tokio::test]
963 async fn test_find_commands_search_mode_exact() {
964 let storage = setup_ranking_storage().await;
965 storage.setup_workspace_storage().await.unwrap();
966 let filter_token_match = SearchCommandsFilter {
967 search_term: Some("commit".to_string()),
968 search_mode: SearchMode::Exact,
969 ..Default::default()
970 };
971 let (commands_token_match, _) = storage
972 .find_commands(filter_token_match, "/some/path", &SearchCommandTuning::default())
973 .await
974 .unwrap();
975 assert_eq!(commands_token_match.len(), 2);
976 assert_eq!(commands_token_match[0].cmd, "git commit -m");
977 assert_eq!(commands_token_match[1].cmd, "git commit -m '{{message}}'");
978
979 let filter_no_match = SearchCommandsFilter {
980 search_term: Some("nonexistentterm".to_string()),
981 search_mode: SearchMode::Exact,
982 ..Default::default()
983 };
984 let (commands_no_match, _) = storage
985 .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
986 .await
987 .unwrap();
988 assert!(commands_no_match.is_empty());
989 }
990
991 #[tokio::test]
992 async fn test_find_commands_search_mode_relaxed() {
993 let storage = setup_ranking_storage().await;
994 storage.setup_workspace_storage().await.unwrap();
995 let filter = SearchCommandsFilter {
996 search_term: Some("docker list".to_string()),
997 search_mode: SearchMode::Relaxed,
998 ..Default::default()
999 };
1000 let (commands, _) = storage
1001 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1002 .await
1003 .unwrap();
1004 assert_eq!(commands.len(), 2);
1005 assert!(commands.iter().any(|c| c.cmd == "docker ps -a"));
1006 assert!(commands.iter().any(|c| c.cmd == "ls -lha"));
1007 }
1008
1009 #[tokio::test]
1010 async fn test_find_commands_search_mode_regex() {
1011 let storage = setup_ranking_storage().await;
1012 storage.setup_workspace_storage().await.unwrap();
1013 let filter = SearchCommandsFilter {
1014 search_term: Some(r"git\s.*it".to_string()),
1015 search_mode: SearchMode::Regex,
1016 ..Default::default()
1017 };
1018 let (commands, _) = storage
1019 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1020 .await
1021 .unwrap();
1022 assert_eq!(commands.len(), 2);
1023 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
1024 assert_eq!(commands[1].cmd, "git commit -m");
1025
1026 let filter_invalid = SearchCommandsFilter {
1027 search_term: Some("[[invalid_regex".to_string()),
1028 search_mode: SearchMode::Regex,
1029 ..Default::default()
1030 };
1031 assert!(matches!(
1032 storage
1033 .find_commands(filter_invalid, "/some/path", &SearchCommandTuning::default())
1034 .await,
1035 Err(AppError::UserFacing(UserFacingError::InvalidRegex))
1036 ));
1037 }
1038
1039 #[tokio::test]
1040 async fn test_find_commands_search_mode_fuzzy() {
1041 let storage = setup_ranking_storage().await;
1042 storage.setup_workspace_storage().await.unwrap();
1043 let filter = SearchCommandsFilter {
1044 search_term: Some("gtcomit".to_string()),
1045 search_mode: SearchMode::Fuzzy,
1046 ..Default::default()
1047 };
1048 let (commands, _) = storage
1049 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1050 .await
1051 .unwrap();
1052 assert_eq!(commands.len(), 2);
1053 assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
1054 assert_eq!(commands[1].cmd, "git commit -m");
1055
1056 let filter_empty_fuzzy = SearchCommandsFilter {
1057 search_term: Some("'' | ^".to_string()),
1058 search_mode: SearchMode::Fuzzy,
1059 ..Default::default()
1060 };
1061 assert!(matches!(
1062 storage
1063 .find_commands(filter_empty_fuzzy, "/some/path", &SearchCommandTuning::default())
1064 .await,
1065 Err(AppError::UserFacing(UserFacingError::InvalidFuzzy))
1066 ));
1067 }
1068
1069 #[tokio::test]
1070 async fn test_find_commands_search_mode_auto() {
1071 let storage = setup_ranking_storage().await;
1072 let default_tuning = SearchCommandTuning::default();
1073
1074 let run_search = |term: &'static str, path: &'static str| {
1076 let storage = storage.clone();
1077 async move {
1078 let filter = SearchCommandsFilter {
1079 search_term: Some(term.to_string()),
1080 search_mode: SearchMode::Auto,
1081 ..Default::default()
1082 };
1083 storage.find_commands(filter, path, &default_tuning).await.unwrap()
1084 }
1085 };
1086
1087 let (commands, _) = run_search("list containers", UNRELATED_PATH).await;
1089 assert!(!commands.is_empty(), "Expected results for 'list containers'");
1090 assert_eq!(
1091 commands[0].cmd, "docker ps -a",
1092 "Expected 'docker ps -a' to be the top result for 'list containers'"
1093 );
1094
1095 let (commands, _) = run_search("git commit", PROJ_A_PATH).await;
1097 assert!(commands.len() >= 2, "Expected at least two results for 'git commit'");
1098 assert_eq!(
1099 commands[0].cmd, "git commit -m",
1100 "Expected 'git commit -m' to be the top result for 'git commit' due to usage"
1101 );
1102 assert_eq!(
1103 commands[1].cmd, "git commit -m '{{message}}'",
1104 "Expected template command to be second for 'git commit'"
1105 );
1106
1107 let (commands, _) = run_search("git commit -m 'my new feature'", PROJ_A_PATH).await;
1109 assert!(!commands.is_empty(), "Expected results for template match");
1110 assert_eq!(
1111 commands[0].cmd, "git commit -m '{{message}}'",
1112 "Expected template command to be the top result for a matching search term"
1113 );
1114
1115 let (commands, _) = run_search("build", PROJ_A_API_PATH).await;
1117 assert!(!commands.is_empty(), "Expected results for 'build'");
1118 assert_eq!(
1119 commands[0].cmd, "npm run build:prod",
1120 "Expected 'npm run build:prod' to be top result for 'build' in its project path"
1121 );
1122
1123 let (commands, _) = run_search("gt sta", PROJ_A_PATH).await;
1125 assert!(!commands.is_empty(), "Expected results for fuzzy search 'gt sta'");
1126 assert_eq!(
1127 commands[0].cmd, "git status",
1128 "Expected 'git status' as top result for fuzzy search 'gt sta'"
1129 );
1130
1131 let (commands, _) = run_search("get pod monitoring", UNRELATED_PATH).await;
1133 assert!(!commands.is_empty(), "Expected results for 'get pod monitoring'");
1134 assert_eq!(
1135 commands[0].cmd, "kubectl get pod -n monitoring my-specific-pod-12345",
1136 "Expected specific 'kubectl' command to be found"
1137 );
1138
1139 let (commands, _) = run_search("status", PROJ_A_API_PATH).await;
1141 assert!(!commands.is_empty(), "Expected results for 'status'");
1142 assert_eq!(
1143 commands[0].cmd, "git status",
1144 "Expected 'git status' to be top due to high usage in parent path"
1145 );
1146 }
1147
1148 #[tokio::test]
1149 async fn test_find_commands_search_mode_auto_hastag_only() {
1150 let storage = setup_ranking_storage().await;
1151
1152 let filter = SearchCommandsFilter {
1155 search_term: Some("#".to_string()),
1156 search_mode: SearchMode::Auto,
1157 ..Default::default()
1158 };
1159
1160 let res = storage
1161 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1162 .await;
1163 assert!(res.is_ok(), "Expected a success response, got: {res:?}")
1164 }
1165
1166 #[tokio::test]
1167 async fn test_find_commands_including_workspace() {
1168 let storage = setup_ranking_storage().await;
1169
1170 storage.setup_workspace_storage().await.unwrap();
1171 let commands_to_import = vec![
1172 Command {
1173 id: Uuid::now_v7(),
1174 cmd: "cmd1".to_string(),
1175 ..Default::default()
1176 },
1177 Command {
1178 id: Uuid::now_v7(),
1179 cmd: "cmd2".to_string(),
1180 ..Default::default()
1181 },
1182 ];
1183 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1184 storage.import_commands(stream, false, true).await.unwrap();
1185
1186 let (commands, _) = storage
1187 .find_commands(
1188 SearchCommandsFilter::default(),
1189 "/some/path",
1190 &SearchCommandTuning::default(),
1191 )
1192 .await
1193 .unwrap();
1194 assert_eq!(commands.len(), 12, "Expected 12 commands including workspace");
1195 }
1196
1197 #[tokio::test]
1198 async fn test_find_commands_with_text_including_workspace() {
1199 let storage = setup_ranking_storage().await;
1200
1201 storage.setup_workspace_storage().await.unwrap();
1202 let commands_to_import = vec![Command {
1203 id: Uuid::now_v7(),
1204 cmd: "git checkout -b feature/{{name:kebab}}".to_string(),
1205 ..Default::default()
1206 }];
1207 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1208 storage.import_commands(stream, false, true).await.unwrap();
1209
1210 let filter = SearchCommandsFilter {
1211 search_term: Some("git".to_string()),
1212 ..Default::default()
1213 };
1214
1215 let (commands, _) = storage
1216 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1217 .await
1218 .unwrap();
1219 assert_eq!(commands.len(), 6, "Expected 6 git commands including workspace");
1220 assert!(
1221 commands
1222 .iter()
1223 .any(|c| c.cmd == "git checkout -b feature/{{name:kebab}}")
1224 );
1225 }
1226
1227 #[tokio::test]
1228 async fn test_import_commands_no_overwrite() {
1229 let storage = SqliteStorage::new_in_memory().await.unwrap();
1230
1231 let commands_to_import = vec![
1232 Command {
1233 id: Uuid::now_v7(),
1234 cmd: "cmd1".to_string(),
1235 ..Default::default()
1236 },
1237 Command {
1238 id: Uuid::now_v7(),
1239 cmd: "cmd2".to_string(),
1240 ..Default::default()
1241 },
1242 ];
1243
1244 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1245 let (inserted, skipped_or_updated) = storage.import_commands(stream, false, false).await.unwrap();
1246
1247 assert_eq!(inserted, 2, "Expected 2 commands inserted");
1248 assert_eq!(skipped_or_updated, 0, "Expected 0 commands skipped or updated");
1249
1250 let stream = iter(commands_to_import.into_iter().map(Ok));
1252 let (inserted, skipped_or_updated) = storage.import_commands(stream, false, false).await.unwrap();
1253
1254 assert_eq!(
1255 inserted, 0,
1256 "Expected 0 commands inserted on second import (no overwrite)"
1257 );
1258 assert_eq!(
1259 skipped_or_updated, 2,
1260 "Expected 2 commands skipped on second import (no overwrite)"
1261 );
1262 }
1263
1264 #[tokio::test]
1265 async fn test_import_commands_overwrite() {
1266 let storage = SqliteStorage::new_in_memory().await.unwrap();
1267
1268 let existing_cmd = Command {
1269 id: Uuid::now_v7(),
1270 cmd: "existing_cmd".to_string(),
1271 description: Some("original desc".to_string()),
1272 alias: Some("original_alias".to_string()),
1273 tags: Some(vec!["tag_a".to_string()]),
1274 ..Default::default()
1275 };
1276 storage.insert_command(existing_cmd.clone()).await.unwrap();
1277
1278 let new_cmd = Command {
1279 id: Uuid::now_v7(),
1280 cmd: "new_cmd".to_string(),
1281 ..Default::default()
1282 };
1283
1284 let commands_to_import = vec![
1286 Command {
1287 id: Uuid::now_v7(),
1288 cmd: "existing_cmd".to_string(),
1289 description: Some("updated desc".to_string()),
1290 alias: None,
1291 tags: Some(vec!["tag_b".to_string()]),
1292 ..Default::default()
1293 },
1294 new_cmd.clone(),
1295 ];
1296
1297 let stream = iter(commands_to_import.into_iter().map(Ok));
1298 let (inserted, skipped_or_updated) = storage.import_commands(stream, true, false).await.unwrap();
1299
1300 assert_eq!(inserted, 1, "Expected 1 new command inserted");
1301 assert_eq!(skipped_or_updated, 1, "Expected 1 existing command updated");
1302
1303 let filter = SearchCommandsFilter {
1305 search_term: Some("existing_cmd".to_string()),
1306 search_mode: SearchMode::Exact,
1307 ..Default::default()
1308 };
1309 let (found_commands, _) = storage
1310 .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1311 .await
1312 .unwrap();
1313 assert_eq!(found_commands.len(), 1);
1314 let updated_cmd_in_db = &found_commands[0];
1315 assert_eq!(
1316 updated_cmd_in_db.description,
1317 Some("updated desc".to_string()),
1318 "Description should be updated"
1319 );
1320 assert_eq!(
1321 updated_cmd_in_db.alias,
1322 Some("original_alias".to_string()),
1323 "Alias should NOT be updated to NULL"
1324 );
1325 assert_eq!(
1326 updated_cmd_in_db.tags,
1327 Some(vec!["tag_b".to_string()]),
1328 "Tags should be updated"
1329 );
1330 }
1331
1332 #[tokio::test]
1333 async fn test_import_workspace_commands() {
1334 let storage = SqliteStorage::new_in_memory().await.unwrap();
1335 storage.setup_workspace_storage().await.unwrap();
1336
1337 let commands_to_import = vec![
1338 Command {
1339 id: Uuid::now_v7(),
1340 cmd: "cmd1".to_string(),
1341 ..Default::default()
1342 },
1343 Command {
1344 id: Uuid::now_v7(),
1345 cmd: "cmd2".to_string(),
1346 ..Default::default()
1347 },
1348 ];
1349
1350 let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1351 let (inserted, skipped_or_updated) = storage.import_commands(stream, false, true).await.unwrap();
1352
1353 assert_eq!(inserted, 2, "Expected 2 commands inserted");
1354 assert_eq!(skipped_or_updated, 0, "Expected 0 commands skipped or updated");
1355 }
1356
1357 #[tokio::test]
1358 async fn test_export_user_commands_no_filter() {
1359 let storage = setup_ranking_storage().await;
1360 let mut exported_commands = Vec::new();
1361 let mut stream = storage.export_user_commands(None).await;
1362 while let Some(Ok(cmd)) = stream.next().await {
1363 exported_commands.push(cmd);
1364 }
1365
1366 assert_eq!(exported_commands.len(), 7, "Expected 7 user commands to be exported");
1367 }
1368
1369 #[tokio::test]
1370 async fn test_export_user_commands_with_filter() {
1371 let storage = setup_ranking_storage().await;
1372 let filter = Regex::new(r"^git").unwrap(); let mut exported_commands = Vec::new();
1374 let mut stream = storage.export_user_commands(Some(filter)).await;
1375 while let Some(Ok(cmd)) = stream.next().await {
1376 exported_commands.push(cmd);
1377 }
1378
1379 assert_eq!(exported_commands.len(), 3, "Expected 3 git commands to be exported");
1380
1381 let exported_cmd_values: Vec<String> = exported_commands.into_iter().map(|c| c.cmd).collect();
1382 assert!(exported_cmd_values.contains(&"git status".to_string()));
1383 assert!(exported_cmd_values.contains(&"git checkout main".to_string()));
1384 }
1385
1386 #[tokio::test]
1387 async fn test_delete_tldr_commands() {
1388 let storage = SqliteStorage::new_in_memory().await.unwrap();
1389
1390 let tldr_cmd1 = Command {
1392 id: Uuid::now_v7(),
1393 category: "git".to_string(),
1394 source: SOURCE_TLDR.to_string(),
1395 cmd: "git status".to_string(),
1396 ..Default::default()
1397 };
1398 let tldr_cmd2 = Command {
1399 id: Uuid::now_v7(),
1400 category: "docker".to_string(),
1401 source: SOURCE_TLDR.to_string(),
1402 cmd: "docker ps".to_string(),
1403 ..Default::default()
1404 };
1405 let user_cmd = Command {
1406 id: Uuid::now_v7(),
1407 category: "git".to_string(),
1408 source: SOURCE_USER.to_string(),
1409 cmd: "git log".to_string(),
1410 ..Default::default()
1411 };
1412
1413 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1414 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1415 storage.insert_command(user_cmd.clone()).await.unwrap();
1416
1417 let removed = storage.delete_tldr_commands(None).await.unwrap();
1419 assert_eq!(removed, 2, "Should remove both tldr commands");
1420
1421 let (remaining, _) = storage
1422 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1423 .await
1424 .unwrap();
1425 assert_eq!(remaining.len(), 1, "Only user command should remain");
1426 assert_eq!(remaining[0].cmd, user_cmd.cmd);
1427
1428 storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1430 storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1431
1432 let removed_git = storage.delete_tldr_commands(Some("git".to_string())).await.unwrap();
1434 assert_eq!(removed_git, 1, "Should remove one tldr command in 'git' category");
1435
1436 let (remaining, _) = storage
1437 .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1438 .await
1439 .unwrap();
1440 let remaining_cmds: Vec<_> = remaining.iter().map(|c| &c.cmd).collect();
1441 assert!(remaining_cmds.contains(&&tldr_cmd2.cmd));
1442 assert!(remaining_cmds.contains(&&user_cmd.cmd));
1443 assert!(!remaining_cmds.contains(&&tldr_cmd1.cmd));
1444 }
1445
1446 #[tokio::test]
1447 async fn test_insert_command() {
1448 let storage = SqliteStorage::new_in_memory().await.unwrap();
1449
1450 let mut cmd = Command {
1451 id: Uuid::now_v7(),
1452 category: "test".to_string(),
1453 cmd: "test_cmd".to_string(),
1454 description: Some("test desc".to_string()),
1455 tags: Some(vec!["tag1".to_string()]),
1456 ..Default::default()
1457 };
1458
1459 let mut inserted = storage.insert_command(cmd.clone()).await.unwrap();
1460 assert_eq!(inserted.cmd, cmd.cmd);
1461
1462 inserted.cmd = "other_cmd".to_string();
1464 match storage.insert_command(inserted).await {
1465 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1466 _ => panic!("Expected CommandAlreadyExists error on duplicate id"),
1467 }
1468
1469 cmd.id = Uuid::now_v7();
1471 match storage.insert_command(cmd).await {
1472 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1473 _ => panic!("Expected CommandAlreadyExists error on duplicate cmd"),
1474 }
1475 }
1476
1477 #[tokio::test]
1478 async fn test_update_command() {
1479 let storage = SqliteStorage::new_in_memory().await.unwrap();
1480
1481 let cmd = Command {
1482 id: Uuid::now_v7(),
1483 cmd: "original".to_string(),
1484 description: Some("desc".to_string()),
1485 ..Default::default()
1486 };
1487
1488 storage.insert_command(cmd.clone()).await.unwrap();
1489
1490 let mut updated = cmd.clone();
1491 updated.cmd = "updated".to_string();
1492 updated.description = Some("new desc".to_string());
1493
1494 let result = storage.update_command(updated.clone()).await.unwrap();
1495 assert_eq!(result.cmd, "updated");
1496 assert_eq!(result.description, Some("new desc".to_string()));
1497
1498 let mut non_existent = cmd;
1500 non_existent.id = Uuid::now_v7();
1501 match storage.update_command(non_existent).await {
1502 Err(_) => (),
1503 _ => panic!("Expected error when updating non-existent command"),
1504 }
1505
1506 let another_cmd = Command {
1508 id: Uuid::now_v7(),
1509 cmd: "another".to_string(),
1510 ..Default::default()
1511 };
1512 let mut result = storage.insert_command(another_cmd.clone()).await.unwrap();
1513 result.cmd = "updated".to_string();
1514 match storage.update_command(result).await {
1515 Err(AppError::UserFacing(UserFacingError::CommandAlreadyExists)) => (),
1516 _ => panic!("Expected CommandAlreadyExists error when updating to existing cmd"),
1517 }
1518 }
1519
1520 #[tokio::test]
1521 async fn test_increment_command_usage() {
1522 let storage = SqliteStorage::new_in_memory().await.unwrap();
1523
1524 let command = storage
1526 .setup_command(
1527 Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
1528 [("/some/path", 100)],
1529 )
1530 .await;
1531
1532 let count = storage.increment_command_usage(command.id, "/path").await.unwrap();
1534 assert_eq!(count, 1);
1535
1536 let count = storage.increment_command_usage(command.id, "/some/path").await.unwrap();
1538 assert_eq!(count, 101);
1539 }
1540
1541 #[tokio::test]
1542 async fn test_delete_command() {
1543 let storage = SqliteStorage::new_in_memory().await.unwrap();
1544
1545 let cmd = Command {
1546 id: Uuid::now_v7(),
1547 cmd: "to_delete".to_string(),
1548 ..Default::default()
1549 };
1550
1551 let cmd = storage.insert_command(cmd).await.unwrap();
1552 let res = storage.delete_command(cmd.id).await;
1553 assert!(res.is_ok());
1554
1555 match storage.delete_command(cmd.id).await {
1557 Err(_) => (),
1558 _ => panic!("Expected error when deleting non-existent command"),
1559 }
1560 }
1561
1562 async fn setup_ranking_storage() -> SqliteStorage {
1564 let storage = SqliteStorage::new_in_memory().await.unwrap();
1565 storage
1566 .setup_command(
1567 Command::new(
1568 CATEGORY_USER,
1569 SOURCE_USER,
1570 "kubectl get pod -n monitoring my-specific-pod-12345",
1571 )
1572 .with_description(Some(
1573 "Get a very specific pod by its full name in the monitoring namespace".to_string(),
1574 ))
1575 .with_tags(Some(vec!["#k8s".to_string(), "#pod".to_string()])),
1576 [("/other/path", 1)],
1577 )
1578 .await;
1579 storage
1580 .setup_command(
1581 Command::new(CATEGORY_USER, SOURCE_USER, "git status")
1582 .with_description(Some("Check the status of the git repository".to_string()))
1583 .with_tags(Some(vec!["#git".to_string()])),
1584 [(PROJ_A_PATH, 50), (PROJ_B_PATH, 50), (UNRELATED_PATH, 100)],
1585 )
1586 .await;
1587 storage
1588 .setup_command(
1589 Command::new(CATEGORY_USER, SOURCE_USER, "npm run build:prod")
1590 .with_description(Some("Build the project for production".to_string()))
1591 .with_tags(Some(vec!["#npm".to_string(), "#build".to_string()])),
1592 [(PROJ_A_API_PATH, 25)],
1593 )
1594 .await;
1595 storage
1596 .setup_command(
1597 Command::new(CATEGORY_USER, SOURCE_USER, "container-image-build.sh")
1598 .with_description(Some("A generic script to build a container image".to_string()))
1599 .with_tags(Some(vec!["#docker".to_string(), "#build".to_string()])),
1600 [(UNRELATED_PATH, 35)],
1601 )
1602 .await;
1603 storage
1604 .setup_command(
1605 Command::new(CATEGORY_USER, SOURCE_USER, "git commit -m '{{message}}'")
1606 .with_description(Some("Commit with a message".to_string()))
1607 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1608 [(PROJ_A_PATH, 10), (PROJ_B_PATH, 10)],
1609 )
1610 .await;
1611 storage
1612 .setup_command(
1613 Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")
1614 .with_alias(Some("gco".to_string()))
1615 .with_description(Some("Checkout the main branch".to_string()))
1616 .with_tags(Some(vec!["#git".to_string()])),
1617 [(PROJ_A_PATH, 30), (PROJ_B_PATH, 30)],
1618 )
1619 .await;
1620 storage
1621 .setup_command(
1622 Command::new("git", SOURCE_TLDR, "git commit -m")
1623 .with_alias(Some("gc".to_string()))
1624 .with_description(Some("Commit changes".to_string()))
1625 .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1626 [(PROJ_A_PATH, 15)],
1627 )
1628 .await;
1629 storage
1630 .setup_command(
1631 Command::new("docker", SOURCE_TLDR, "docker ps -a")
1632 .with_description(Some("List all containers".to_string()))
1633 .with_tags(Some(vec!["#docker".to_string(), "#list".to_string()])),
1634 [(PROJ_A_PATH, 5), (PROJ_B_PATH, 5)],
1635 )
1636 .await;
1637 storage
1638 .setup_command(
1639 Command::new("git", SOURCE_TLDR, "git push")
1640 .with_description(Some("Push changes".to_string()))
1641 .with_tags(Some(vec!["#git".to_string(), "#push".to_string()])),
1642 [(PROJ_A_PATH, 20), (PROJ_B_PATH, 20)],
1643 )
1644 .await;
1645 storage
1646 .setup_command(
1647 Command::new(CATEGORY_USER, SOURCE_IMPORT, "ls -lha")
1648 .with_description(Some("List files".to_string()))
1649 .with_tags(Some(vec!["#unix".to_string(), "#list".to_string()])),
1650 [(PROJ_A_PATH, 100), (PROJ_B_PATH, 100), (UNRELATED_PATH, 100)],
1651 )
1652 .await;
1653
1654 storage
1655 }
1656
1657 impl SqliteStorage {
1658 async fn check_sqlite_version(&self) {
1660 let version: String = self
1661 .client
1662 .conn_mut(|conn| {
1663 conn.query_row("SELECT sqlite_version()", [], |row| row.get(0))
1664 .map_err(Into::into)
1665 })
1666 .await
1667 .unwrap();
1668 println!("Running with SQLite version: {version}");
1669 }
1670
1671 async fn setup_command(
1674 &self,
1675 command: Command,
1676 usage: impl IntoIterator<Item = (&str, i32)> + Send + 'static,
1677 ) -> Command {
1678 let command = self.insert_command(command).await.unwrap();
1679 self.client
1680 .conn_mut(move |conn| {
1681 for (path, usage_count) in usage {
1682 conn.execute(
1683 r#"
1684 INSERT INTO command_usage (command_id, path, usage_count)
1685 VALUES (?1, ?2, ?3)
1686 ON CONFLICT(command_id, path) DO UPDATE SET
1687 usage_count = excluded.usage_count"#,
1688 (&command.id, path, usage_count),
1689 )?;
1690 }
1691 Ok(command)
1692 })
1693 .await
1694 .unwrap()
1695 }
1696 }
1697}