1use std::sync::atomic::Ordering as AtomicOrdering;
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{Row, ToSql, ffi};
5use tracing::instrument;
6use uuid::Uuid;
7
8use super::SqliteStorage;
9use crate::{
10 errors::{Result, UserFacingError},
11 model::VariableCompletion,
12};
13
14impl SqliteStorage {
15 #[instrument(skip_all)]
17 pub async fn list_variable_completion_root_cmds(&self) -> Result<Vec<String>> {
18 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
19 self.client
20 .conn(move |conn| {
21 let query = if workspace_tables_loaded {
22 r"SELECT root_cmd
23 FROM (
24 SELECT root_cmd FROM variable_completion
25 UNION
26 SELECT root_cmd FROM workspace_variable_completion
27 )
28 ORDER BY root_cmd"
29 } else {
30 "SELECT root_cmd
31 FROM (SELECT DISTINCT root_cmd FROM variable_completion)
32 ORDER BY root_cmd"
33 };
34 tracing::trace!("Listing root commands completions:\n{query}");
35 Ok(conn
36 .prepare(query)?
37 .query_map([], |row| row.get(0))?
38 .collect::<Result<Vec<String>, _>>()?)
39 })
40 .await
41 }
42
43 #[instrument(skip_all)]
47 pub async fn list_variable_completions(
48 &self,
49 flat_root_cmd: Option<String>,
50 flat_variable_names: Option<Vec<String>>,
51 skip_workspace: bool,
52 ) -> Result<Vec<VariableCompletion>> {
53 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
54
55 self.client
56 .conn(move |conn| {
57 let mut conditions = Vec::new();
58 let mut params = Vec::<&dyn ToSql>::new();
59 let base_query = if !skip_workspace && workspace_tables_loaded {
60 conditions.push("rn = 1".to_string());
61 r"SELECT *
62 FROM (
63 SELECT
64 id,
65 source,
66 root_cmd,
67 flat_root_cmd,
68 variable,
69 flat_variable,
70 suggestions_provider,
71 created_at,
72 updated_at,
73 ROW_NUMBER() OVER (PARTITION BY flat_root_cmd, flat_variable ORDER BY is_workspace ASC) as rn
74 FROM (
75 SELECT *, 0 AS is_workspace FROM variable_completion
76 UNION ALL
77 SELECT *, 1 AS is_workspace FROM workspace_variable_completion
78 )
79 )"
80 } else {
81 r"SELECT
82 id,
83 source,
84 root_cmd,
85 flat_root_cmd,
86 variable,
87 flat_variable,
88 suggestions_provider,
89 created_at,
90 updated_at
91 FROM variable_completion"
92 };
93
94 if let Some(cmd) = &flat_root_cmd {
96 conditions.push("flat_root_cmd = ?".to_string());
97 params.push(cmd);
98 }
99
100 if let Some(vars) = &flat_variable_names {
102 if vars.is_empty() {
103 conditions.push(String::from("1=0"));
105 } else if vars.len() == 1 {
106 conditions.push("flat_variable = ?".to_string());
108 params.push(&vars[0]);
109 } else {
110 let placeholders = vec!["?"; vars.len()].join(",");
112 conditions.push(format!("flat_variable IN ({placeholders})"));
113 for var in vars {
114 params.push(var);
115 }
116 }
117 }
118
119 let query = if conditions.is_empty() {
120 format!("{base_query}\nORDER BY root_cmd, variable")
121 } else {
122 format!("{base_query}\nWHERE {}\nORDER BY root_cmd, variable", conditions.join(" AND "))
123 };
124
125 tracing::trace!("Listing completions:\n{query}");
126
127 Ok(conn
128 .prepare(&query)?
129 .query_map(¶ms[..], |row| VariableCompletion::try_from(row))?
130 .collect::<Result<Vec<_>, _>>()?)
131 })
132 .await
133 }
134
135 pub async fn get_completions_for(
140 &self,
141 flat_root_cmd: impl Into<String>,
142 flat_variable_names: Vec<String>,
143 ) -> Result<Vec<VariableCompletion>> {
144 if flat_variable_names.is_empty() {
146 return Ok(Vec::new());
147 }
148
149 let flat_root_cmd = flat_root_cmd.into();
150 let workspace_tables_loaded = self.workspace_tables_loaded.load(AtomicOrdering::SeqCst);
151
152 self.client
153 .conn(move |conn| {
154 let mut params: Vec<&dyn ToSql> = vec![&flat_root_cmd, &flat_root_cmd];
156
157 let placeholders = vec!["?"; flat_variable_names.len()].join(",");
158 for var in &flat_variable_names {
159 params.push(var);
160 }
161
162 let mut order_by_clause = "ORDER BY CASE flat_variable ".to_string();
164 for (index, var_name) in flat_variable_names.iter().enumerate() {
165 order_by_clause.push_str(&format!("WHEN ? THEN {index} "));
166 params.push(var_name);
167 }
168 order_by_clause.push_str("END");
169
170 let sub_query = if workspace_tables_loaded {
172 r"SELECT *, 0 AS is_workspace FROM variable_completion
173 UNION ALL
174 SELECT *, 1 AS is_workspace FROM workspace_variable_completion"
175 } else {
176 "SELECT *, 0 AS is_workspace FROM variable_completion"
177 };
178
179 let query = format!(
183 r"SELECT
184 id,
185 source,
186 root_cmd,
187 flat_root_cmd,
188 variable,
189 flat_variable,
190 suggestions_provider,
191 created_at,
192 updated_at
193 FROM (
194 SELECT
195 *,
196 ROW_NUMBER() OVER (
197 PARTITION BY flat_variable
198 ORDER BY
199 CASE WHEN flat_root_cmd = ? THEN 0 ELSE 1 END,
200 is_workspace
201 ) as rn
202 FROM (
203 {sub_query}
204 )
205 WHERE (flat_root_cmd = ? OR flat_root_cmd = '')
206 AND flat_variable IN ({placeholders})
207 )
208 WHERE rn = 1
209 {order_by_clause}"
210 );
211
212 tracing::trace!("Retrieving completions for a variable:\n{query}");
213
214 Ok(conn
215 .prepare(&query)?
216 .query_map(¶ms[..], |row| VariableCompletion::try_from(row))?
217 .collect::<Result<Vec<_>, _>>()?)
218 })
219 .await
220 }
221
222 #[instrument(skip_all)]
224 pub async fn insert_variable_completion(&self, var: VariableCompletion) -> Result<VariableCompletion> {
225 self.client
226 .conn_mut(move |conn| {
227 let query = r#"INSERT INTO variable_completion (
228 id,
229 source,
230 root_cmd,
231 flat_root_cmd,
232 variable,
233 flat_variable,
234 suggestions_provider,
235 created_at,
236 updated_at
237 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)"#;
238 tracing::trace!("Inserting a completion:\n{query}");
239 let res = conn.execute(
240 query,
241 (
242 &var.id,
243 &var.source,
244 &var.root_cmd,
245 &var.flat_root_cmd,
246 &var.variable,
247 &var.flat_variable,
248 &var.suggestions_provider,
249 &var.created_at,
250 &var.updated_at,
251 ),
252 );
253 match res {
254 Ok(_) => Ok(var),
255 Err(err) => {
256 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
257 if code == ffi::SQLITE_CONSTRAINT_UNIQUE || code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY {
258 Err(UserFacingError::CompletionAlreadyExists.into())
259 } else {
260 Err(Report::from(err).into())
261 }
262 }
263 }
264 })
265 .await
266 }
267
268 #[instrument(skip_all)]
270 pub async fn update_variable_completion(&self, var: VariableCompletion) -> Result<VariableCompletion> {
271 self.client
272 .conn_mut(move |conn| {
273 let query = r#"
274 UPDATE variable_completion
275 SET source = ?2,
276 root_cmd = ?3,
277 flat_root_cmd = ?4,
278 variable = ?5,
279 flat_variable = ?6,
280 suggestions_provider = ?7,
281 created_at = ?8,
282 updated_at = ?9
283 WHERE id = ?1
284 "#;
285 tracing::trace!("Updating a completion:\n{query}");
286 let res = conn.execute(
287 query,
288 (
289 &var.id,
290 &var.source,
291 &var.root_cmd,
292 &var.flat_root_cmd,
293 &var.variable,
294 &var.flat_variable,
295 &var.suggestions_provider,
296 &var.created_at,
297 &var.updated_at,
298 ),
299 );
300 match res {
301 Ok(0) => Err(eyre!("Variable completion not found: {}", var.id)
302 .wrap_err("Couldn't update a variable completion")
303 .into()),
304 Ok(_) => Ok(var),
305 Err(err) => {
306 let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
307 if code == ffi::SQLITE_CONSTRAINT_UNIQUE {
308 Err(UserFacingError::CompletionAlreadyExists.into())
309 } else {
310 Err(Report::from(err).into())
311 }
312 }
313 }
314 })
315 .await
316 }
317
318 #[instrument(skip_all)]
320 pub async fn delete_variable_completion(&self, completion_id: Uuid) -> Result<()> {
321 self.client
322 .conn_mut(move |conn| {
323 let query = "DELETE FROM variable_completion WHERE id = ?1";
324 tracing::trace!("Deleting a completion:\n{query}");
325 let res = conn.execute(query, (&completion_id,));
326 match res {
327 Ok(0) => Err(eyre!("Variable completion not found: {completion_id}").into()),
328 Ok(_) => Ok(()),
329 Err(err) => Err(Report::from(err).into()),
330 }
331 })
332 .await
333 }
334
335 #[instrument(skip_all)]
337 pub async fn delete_variable_completion_by_key(
338 &self,
339 flat_root_cmd: impl Into<String>,
340 flat_variable_name: impl Into<String>,
341 ) -> Result<Option<VariableCompletion>> {
342 let flat_root_cmd = flat_root_cmd.into();
343 let flat_variable_name = flat_variable_name.into();
344
345 self.client
346 .conn_mut(move |conn| {
347 let query = r"DELETE FROM variable_completion
348 WHERE flat_root_cmd = ?1 AND flat_variable = ?2
349 RETURNING
350 id,
351 source,
352 root_cmd,
353 flat_root_cmd,
354 variable,
355 flat_variable,
356 suggestions_provider,
357 created_at,
358 updated_at";
359 tracing::trace!("Deleting a completion:\n{query}");
360 let res = conn.query_row(query, (&flat_root_cmd, &flat_variable_name), |row| {
361 VariableCompletion::try_from(row)
362 });
363
364 match res {
365 Ok(completion) => Ok(Some(completion)),
366 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
367 Err(err) => Err(Report::from(err).into()),
368 }
369 })
370 .await
371 }
372}
373
374impl<'a> TryFrom<&'a Row<'a>> for VariableCompletion {
375 type Error = rusqlite::Error;
376
377 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
378 Ok(Self {
379 id: row.get(0)?,
380 source: row.get(1)?,
381 root_cmd: row.get(2)?,
382 flat_root_cmd: row.get(3)?,
383 variable: row.get(4)?,
384 flat_variable: row.get(5)?,
385 suggestions_provider: row.get(6)?,
386 created_at: row.get(7)?,
387 updated_at: row.get(8)?,
388 })
389 }
390}
391
392#[cfg(test)]
393mod tests {
394 use futures_util::stream;
395 use pretty_assertions::assert_eq;
396
397 use super::*;
398 use crate::{
399 errors::AppError,
400 model::{ImportExportItem, SOURCE_IMPORT, SOURCE_USER, VariableCompletion},
401 };
402
403 #[tokio::test]
404 async fn test_list_variable_completion_root_cmds() {
405 let storage = SqliteStorage::new_in_memory().await.unwrap();
407
408 let root_cmds = storage.list_variable_completion_root_cmds().await.unwrap();
410 assert!(
411 root_cmds.is_empty(),
412 "Should return an empty vector when the database is empty"
413 );
414
415 let var1 = VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch");
417 let var2 = VariableCompletion::new(SOURCE_USER, "git", "commit", "git log --oneline");
418 let var3 = VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps");
419 storage.insert_variable_completion(var1).await.unwrap();
420 storage.insert_variable_completion(var2).await.unwrap();
421 storage.insert_variable_completion(var3).await.unwrap();
422
423 let root_cmds = storage.list_variable_completion_root_cmds().await.unwrap();
425 let expected = vec!["docker".to_string(), "git".to_string()];
426 assert_eq!(root_cmds.len(), 2, "Should return only unique root commands");
427 assert_eq!(
428 root_cmds, expected,
429 "The returned root commands should match the expected unique values"
430 );
431
432 storage.setup_workspace_storage().await.unwrap();
434
435 let workspace_items = vec![
437 Ok(ImportExportItem::Completion(VariableCompletion::new(
438 SOURCE_IMPORT,
439 "git",
440 "tag",
441 "git tag",
442 ))),
443 Ok(ImportExportItem::Completion(VariableCompletion::new(
444 SOURCE_IMPORT,
445 "npm",
446 "install",
447 "npm i",
448 ))),
449 ];
450 let stream = stream::iter(workspace_items);
451 storage.import_items(stream, false, true).await.unwrap();
452
453 let root_cmds_with_workspace = storage.list_variable_completion_root_cmds().await.unwrap();
455 let expected_with_workspace = vec!["docker".to_string(), "git".to_string(), "npm".to_string()];
456 assert_eq!(
457 root_cmds_with_workspace, expected_with_workspace,
458 "Should include unique root cmds from workspace"
459 );
460 }
461
462 #[tokio::test]
463 async fn test_list_variable_completions() {
464 let storage = SqliteStorage::new_in_memory().await.unwrap();
466 let var1 = VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch");
467 let var2 = VariableCompletion::new(SOURCE_USER, "git", "commit", "git log --oneline");
468 let var3 = VariableCompletion::new(SOURCE_IMPORT, "docker", "container", "docker ps");
469 storage.insert_variable_completion(var1).await.unwrap();
470 storage.insert_variable_completion(var2).await.unwrap();
471 storage.insert_variable_completion(var3).await.unwrap();
472
473 let all = storage.list_variable_completions(None, None, false).await.unwrap();
475 assert_eq!(all.len(), 3);
476
477 let git_cmds = storage
479 .list_variable_completions(Some("git".into()), None, false)
480 .await
481 .unwrap();
482 assert_eq!(git_cmds.len(), 2);
483
484 let branch_vars = storage
486 .list_variable_completions(None, Some(vec!["branch".into()]), false)
487 .await
488 .unwrap();
489 assert_eq!(branch_vars.len(), 1);
490
491 let git_branch = storage
493 .list_variable_completions(Some("git".into()), Some(vec!["branch".into()]), false)
494 .await
495 .unwrap();
496 assert_eq!(git_branch.len(), 1);
497 assert_eq!(git_branch[0].flat_root_cmd, "git");
498 assert_eq!(git_branch[0].flat_variable, "branch");
499
500 let git_multi_vars = storage
502 .list_variable_completions(Some("git".into()), Some(vec!["commit".into(), "branch".into()]), false)
503 .await
504 .unwrap();
505 assert_eq!(git_multi_vars.len(), 2);
506 assert_eq!(git_multi_vars[0].variable, "branch");
507 assert_eq!(git_multi_vars[1].variable, "commit");
508
509 let none_cmd = storage
511 .list_variable_completions(Some("nonexistent".into()), None, false)
512 .await
513 .unwrap();
514 assert_eq!(none_cmd.len(), 0);
515
516 let none_var = storage
517 .list_variable_completions(Some("git".into()), Some(vec!["nonexistent".into()]), false)
518 .await
519 .unwrap();
520 assert_eq!(none_var.len(), 0);
521 }
522
523 #[tokio::test]
524 async fn test_list_variable_completions_with_workspace_precedence() {
525 let storage = SqliteStorage::new_in_memory().await.unwrap();
526 storage.setup_workspace_storage().await.unwrap();
527
528 let global_var = VariableCompletion::new(SOURCE_USER, "git", "checkout", "git branch --global");
530 storage.insert_variable_completion(global_var).await.unwrap();
531
532 let workspace_var = VariableCompletion::new(SOURCE_IMPORT, "git", "checkout", "git branch --workspace");
534 let workspace_only_var = VariableCompletion::new(SOURCE_IMPORT, "npm", "install", "npm i --workspace");
535 let stream = stream::iter(vec![
536 Ok(ImportExportItem::Completion(workspace_var)),
537 Ok(ImportExportItem::Completion(workspace_only_var)),
538 ]);
539 storage.import_items(stream, false, true).await.unwrap();
540
541 let completions = storage
543 .list_variable_completions(Some("git".into()), Some(vec!["checkout".into()]), false)
544 .await
545 .unwrap();
546 assert_eq!(completions.len(), 1);
547 assert_eq!(
548 completions[0].source, SOURCE_USER,
549 "Global completion should take precedence"
550 );
551 assert_eq!(completions[0].suggestions_provider, "git branch --global");
552
553 let completions_npm = storage
555 .list_variable_completions(Some("npm".into()), Some(vec!["install".into()]), false)
556 .await
557 .unwrap();
558 assert_eq!(completions_npm.len(), 1);
559 assert_eq!(
560 completions_npm[0].source, SOURCE_IMPORT,
561 "Should get workspace completion when no global exists"
562 );
563
564 let completions_skip_workspace = storage
566 .list_variable_completions(Some("git".into()), Some(vec!["checkout".into()]), true)
567 .await
568 .unwrap();
569 assert_eq!(completions_skip_workspace.len(), 1);
570 assert_eq!(
571 completions_skip_workspace[0].source, SOURCE_USER,
572 "Should only find global completion when skipping workspace"
573 );
574 }
575
576 #[tokio::test]
577 async fn test_get_completions_for() {
578 let storage = SqliteStorage::new_in_memory().await.unwrap();
579 storage.setup_workspace_storage().await.unwrap();
580
581 let user_completions = vec![
588 VariableCompletion::new(SOURCE_USER, "docker", "image", "docker images --user-specific"),
590 VariableCompletion::new(SOURCE_USER, "", "image", "generic images --user"),
592 VariableCompletion::new(SOURCE_USER, "", "container", "generic container --user"),
594 VariableCompletion::new(SOURCE_USER, "", "version", "generic version --user"),
596 ];
597 for completion in user_completions {
598 storage.insert_variable_completion(completion).await.unwrap();
599 }
600
601 let workspace_items = vec![
602 Ok(ImportExportItem::Completion(VariableCompletion::new(
603 SOURCE_IMPORT,
604 "docker",
605 "image",
606 "docker images --workspace-specific", ))),
608 Ok(ImportExportItem::Completion(VariableCompletion::new(
609 SOURCE_IMPORT,
610 "",
611 "image",
612 "generic images --workspace", ))),
614 Ok(ImportExportItem::Completion(VariableCompletion::new(
615 SOURCE_IMPORT,
616 "",
617 "container",
618 "generic container --workspace", ))),
620 Ok(ImportExportItem::Completion(VariableCompletion::new(
621 SOURCE_IMPORT,
622 "docker",
623 "volume",
624 "docker volume ls --workspace", ))),
626 Ok(ImportExportItem::Completion(VariableCompletion::new(
627 SOURCE_IMPORT,
628 "",
629 "network",
630 "generic network --workspace", ))),
632 ];
633 storage
634 .import_items(stream::iter(workspace_items), false, true)
635 .await
636 .unwrap();
637
638 let completions = storage
640 .get_completions_for(
641 "docker",
642 vec![
643 "image".into(),
644 "container".into(),
645 "nonexistent".into(),
646 "volume".into(),
647 "network".into(),
648 "version".into(),
649 ],
650 )
651 .await
652 .unwrap();
653
654 assert_eq!(
655 completions.len(),
656 5,
657 "Should resolve one completion for each existing variable and ignore non-existent ones"
658 );
659
660 let image = &completions[0];
662 assert_eq!(image.flat_variable, "image");
663 assert_eq!(image.flat_root_cmd, "docker");
664 assert_eq!(image.source, SOURCE_USER);
665 assert_eq!(image.suggestions_provider, "docker images --user-specific");
666
667 let container = &completions[1];
669 assert_eq!(container.flat_variable, "container");
670 assert_eq!(container.flat_root_cmd, "");
671 assert_eq!(container.source, SOURCE_USER);
672 assert_eq!(container.suggestions_provider, "generic container --user");
673
674 let volume = &completions[2];
676 assert_eq!(volume.flat_variable, "volume");
677 assert_eq!(volume.flat_root_cmd, "docker");
678 assert_eq!(volume.source, SOURCE_IMPORT);
679 assert_eq!(volume.suggestions_provider, "docker volume ls --workspace");
680
681 let network = &completions[3];
683 assert_eq!(network.flat_variable, "network");
684 assert_eq!(network.flat_root_cmd, "");
685 assert_eq!(network.source, SOURCE_IMPORT);
686 assert_eq!(network.suggestions_provider, "generic network --workspace");
687
688 let version = &completions[4];
690 assert_eq!(version.flat_variable, "version");
691 assert_eq!(version.flat_root_cmd, "");
692 assert_eq!(version.source, SOURCE_USER);
693 assert_eq!(version.suggestions_provider, "generic version --user");
694 }
695
696 #[tokio::test]
697 async fn test_insert_variable_completion() {
698 let storage = SqliteStorage::new_in_memory().await.unwrap();
699 let var = VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch");
700
701 let inserted_var = storage.insert_variable_completion(var.clone()).await.unwrap();
702 assert_eq!(inserted_var.flat_root_cmd, var.flat_root_cmd);
703
704 match storage.insert_variable_completion(var).await {
706 Err(AppError::UserFacing(UserFacingError::CompletionAlreadyExists)) => {}
707 res => panic!("Expected CompletionAlreadyExists error, got {res:?}"),
708 }
709 }
710
711 #[tokio::test]
712 async fn test_update_variable_completion() {
713 let storage = SqliteStorage::new_in_memory().await.unwrap();
714 let var = VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch");
715 let mut inserted_var = storage.insert_variable_completion(var).await.unwrap();
716
717 inserted_var.suggestions_provider = "git branch --all".to_string();
718 storage.update_variable_completion(inserted_var).await.unwrap();
719
720 let mut found = storage
721 .list_variable_completions(Some("git".into()), Some(vec!["branch".into()]), false)
722 .await
723 .unwrap();
724 assert_eq!(found.len(), 1);
725 let found = found.pop().unwrap();
726 assert_eq!(found.suggestions_provider, "git branch --all");
727 }
728
729 #[tokio::test]
730 async fn test_delete_variable_completion() {
731 let storage = SqliteStorage::new_in_memory().await.unwrap();
732 let var = VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch");
733 let inserted_var = storage.insert_variable_completion(var).await.unwrap();
734
735 storage.delete_variable_completion(inserted_var.id).await.unwrap();
736
737 let found = storage
738 .list_variable_completions(Some("git".into()), Some(vec!["branch".into()]), false)
739 .await
740 .unwrap();
741 assert!(found.is_empty());
742 }
743
744 #[tokio::test]
745 async fn test_delete_variable_completion_by_key() {
746 let storage = SqliteStorage::new_in_memory().await.unwrap();
747 let var = VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch");
748 storage.insert_variable_completion(var.clone()).await.unwrap();
749
750 let deleted = storage
752 .delete_variable_completion_by_key("git", "branch")
753 .await
754 .unwrap();
755 assert_eq!(deleted, Some(var));
756
757 let found = storage
759 .list_variable_completions(Some("git".into()), Some(vec!["branch".into()]), false)
760 .await
761 .unwrap();
762 assert!(found.is_empty());
763
764 let deleted_again = storage
766 .delete_variable_completion_by_key("git", "branch")
767 .await
768 .unwrap();
769 assert_eq!(deleted_again, None);
770 }
771}