1use std::pin::pin;
2
3use chrono::{DateTime, Utc};
4use futures_util::StreamExt;
5use regex::Regex;
6use tokio::sync::mpsc;
7use tokio_stream::{Stream, wrappers::ReceiverStream};
8use tracing::instrument;
9
10use super::SqliteStorage;
11use crate::{
12 errors::{AppError, Result},
13 model::{CATEGORY_USER, Command, ImportExportItem, ImportStats, VariableCompletion},
14};
15
16impl SqliteStorage {
17 #[instrument(skip_all)]
22 pub async fn import_items(
23 &self,
24 items: impl Stream<Item = Result<ImportExportItem>> + Send + 'static,
25 overwrite: bool,
26 workspace: bool,
27 ) -> Result<ImportStats> {
28 let (tx, mut rx) = mpsc::channel(100);
30
31 tokio::spawn(async move {
33 let mut items = pin!(items);
35 while let Some(item_res) = items.next().await {
36 if tx.send(item_res).await.is_err() {
37 tracing::debug!("Import stream channel closed by receiver");
39 break;
40 }
41 }
42 });
43
44 let commands_table = if workspace { "workspace_command" } else { "command" };
46 let completions_table = if workspace {
47 "workspace_variable_completion"
48 } else {
49 "variable_completion"
50 };
51
52 self.client
53 .conn_mut(move |conn| {
54 let mut stats = ImportStats::default();
55 let tx = conn.transaction()?;
56
57 let mut cmd_stmt = if overwrite {
58 tx.prepare(&format!(
59 r#"INSERT INTO {commands_table} (
60 id,
61 category,
62 source,
63 alias,
64 cmd,
65 flat_cmd,
66 description,
67 flat_description,
68 tags,
69 created_at
70 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
71 ON CONFLICT (cmd) DO UPDATE SET
72 alias = COALESCE(excluded.alias, alias),
73 cmd = excluded.cmd,
74 flat_cmd = excluded.flat_cmd,
75 description = COALESCE(excluded.description, description),
76 flat_description = COALESCE(excluded.flat_description, flat_description),
77 tags = COALESCE(excluded.tags, tags),
78 updated_at = excluded.created_at
79 RETURNING updated_at;"#
80 ))?
81 } else {
82 tx.prepare(&format!(
83 r#"INSERT OR IGNORE INTO {commands_table} (
84 id,
85 category,
86 source,
87 alias,
88 cmd,
89 flat_cmd,
90 description,
91 flat_description,
92 tags,
93 created_at
94 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
95 RETURNING updated_at;"#,
96 ))?
97 };
98
99 let mut cmp_stmt = if overwrite {
100 tx.prepare(&format!(
101 r#"INSERT INTO {completions_table} (
102 id,
103 source,
104 root_cmd,
105 flat_root_cmd,
106 variable,
107 flat_variable,
108 suggestions_provider,
109 created_at
110 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
111 ON CONFLICT (flat_root_cmd, flat_variable) DO UPDATE SET
112 source = excluded.source,
113 root_cmd = excluded.root_cmd,
114 flat_root_cmd = excluded.flat_root_cmd,
115 variable = excluded.variable,
116 flat_variable = excluded.flat_variable,
117 suggestions_provider = excluded.suggestions_provider,
118 updated_at = excluded.created_at
119 RETURNING updated_at;"#
120 ))?
121 } else {
122 tx.prepare(&format!(
123 r#"INSERT OR IGNORE INTO {completions_table} (
124 id,
125 source,
126 root_cmd,
127 flat_root_cmd,
128 variable,
129 flat_variable,
130 suggestions_provider,
131 created_at
132 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
133 RETURNING updated_at;"#,
134 ))?
135 };
136
137 while let Some(item_result) = rx.blocking_recv() {
139 match item_result? {
140 ImportExportItem::Command(command) => {
141 tracing::trace!("Importing a {commands_table}: {}", command.cmd);
142 let mut rows = cmd_stmt.query((
143 &command.id,
144 &command.category,
145 &command.source,
146 &command.alias,
147 &command.cmd,
148 &command.flat_cmd,
149 &command.description,
150 &command.flat_description,
151 serde_json::to_value(&command.tags)?,
152 &command.created_at,
153 ))?;
154 match rows.next()? {
155 None => stats.commands_skipped += 1,
157 Some(r) => {
159 let updated_at = r.get::<_, Option<DateTime<Utc>>>(0)?;
160 match updated_at {
161 None => stats.commands_imported += 1,
163 Some(_) => stats.commands_updated += 1,
165 }
166 }
167 }
168 }
169 ImportExportItem::Completion(completion) => {
170 tracing::trace!("Importing a {completions_table}: {completion}");
171 let mut rows = cmp_stmt.query((
172 &completion.id,
173 &completion.source,
174 &completion.root_cmd,
175 &completion.flat_root_cmd,
176 &completion.variable,
177 &completion.flat_variable,
178 &completion.suggestions_provider,
179 &completion.created_at,
180 ))?;
181 match rows.next()? {
182 None => stats.completions_skipped += 1,
184 Some(r) => {
186 let updated_at = r.get::<_, Option<DateTime<Utc>>>(0)?;
187 match updated_at {
188 None => stats.completions_imported += 1,
190 Some(_) => stats.completions_updated += 1,
192 }
193 }
194 }
195 }
196 }
197 }
198
199 drop(cmd_stmt);
200 drop(cmp_stmt);
201 tx.commit()?;
202 Ok(stats)
203 })
204 .await
205 }
206
207 #[instrument(skip_all)]
209 pub async fn export_user_commands(
210 &self,
211 filter: Option<Regex>,
212 ) -> impl Stream<Item = Result<Command>> + Send + 'static {
213 let (tx, rx) = mpsc::channel(100);
215
216 let client = self.client.clone();
218 tokio::spawn(async move {
219 let res = client
220 .conn_mut(move |conn| {
221 let mut q_values = vec![CATEGORY_USER.to_owned()];
223 let mut query = String::from(
224 r"SELECT
225 rowid,
226 id,
227 category,
228 source,
229 alias,
230 cmd,
231 flat_cmd,
232 description,
233 flat_description,
234 tags,
235 created_at,
236 updated_at
237 FROM command
238 WHERE category = ?1",
239 );
240 if let Some(filter) = filter {
241 q_values.push(filter.as_str().to_owned());
242 query.push_str(" AND (cmd REGEXP ?2 OR (description IS NOT NULL AND description REGEXP ?2))");
243 }
244 query.push_str("\nORDER BY cmd ASC");
245
246 tracing::trace!("Exporting commands: {query}");
247
248 let mut stmt = conn.prepare(&query)?;
250 let records_iter =
251 stmt.query_and_then(rusqlite::params_from_iter(q_values), |r| Command::try_from(r))?;
252
253 for record_result in records_iter {
255 if tx.blocking_send(record_result.map_err(AppError::from)).is_err() {
256 tracing::debug!("Async stream receiver dropped, closing db query");
257 break;
258 }
259 }
260
261 Ok(())
262 })
263 .await;
264 if let Err(err) = res {
265 panic!("Couldn't fetch commands to export: {err:?}");
266 }
267 });
268
269 ReceiverStream::new(rx)
271 }
272
273 #[instrument(skip_all)]
280 pub async fn export_user_variable_completions(
281 &self,
282 flat_root_cmd_and_var: impl IntoIterator<Item = (String, String)>,
283 ) -> Result<Vec<VariableCompletion>> {
284 let flat_keys = flat_root_cmd_and_var.into_iter().collect::<Vec<_>>();
286
287 if flat_keys.is_empty() {
288 return Ok(Vec::new());
289 }
290
291 self.client
292 .conn(move |conn| {
293 let values_placeholders = vec!["(?, ?)"; flat_keys.len()].join(", ");
294 let query = format!(
295 r#"WITH input_keys(flat_root_cmd, flat_variable) AS (VALUES {values_placeholders})
296 SELECT
297 t.id,
298 t.source,
299 t.root_cmd,
300 t.flat_root_cmd,
301 t.variable,
302 t.flat_variable,
303 t.suggestions_provider,
304 t.created_at,
305 t.updated_at
306 FROM (
307 SELECT
308 vc.*,
309 ROW_NUMBER() OVER (
310 PARTITION BY ik.flat_root_cmd, ik.flat_variable
311 ORDER BY
312 CASE WHEN vc.flat_root_cmd = ik.flat_root_cmd THEN 0 ELSE 1 END
313 ) as rn
314 FROM variable_completion vc
315 JOIN input_keys ik ON vc.flat_variable = ik.flat_variable
316 WHERE vc.flat_root_cmd = ik.flat_root_cmd
317 OR vc.flat_root_cmd = ''
318 ) AS t
319 WHERE t.rn = 1
320 ORDER BY t.root_cmd, t.variable"#
321 );
322 tracing::trace!("Exporting completions: {query}");
323
324 Ok(conn
325 .prepare(&query)?
326 .query_map(
327 rusqlite::params_from_iter(flat_keys.into_iter().flat_map(|(cmd, var)| vec![cmd, var])),
328 |row| VariableCompletion::try_from(row),
329 )?
330 .collect::<Result<Vec<_>, _>>()?)
331 })
332 .await
333 }
334}
335
#[cfg(test)]
mod tests {
    use tokio_stream::iter;

    use super::*;
    use crate::model::{SOURCE_TLDR, SOURCE_USER};

    #[tokio::test]
    async fn test_import_items_commands() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        let items_to_import = vec![
            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cmd1")),
            ImportExportItem::Command(
                Command::new(CATEGORY_USER, SOURCE_USER, "cmd2").with_description(Some("original desc".to_string())),
            ),
        ];

        // First import: everything is new
        let stream = iter(items_to_import.clone().into_iter().map(Ok));
        let stats = storage.import_items(stream, false, false).await.unwrap();
        assert_eq!(stats.commands_imported, 2, "Expected 2 new commands to be imported");
        assert_eq!(stats.commands_skipped, 0);

        // Re-import without overwrite: everything is skipped
        let stream = iter(items_to_import.clone().into_iter().map(Ok));
        let stats = storage.import_items(stream, false, false).await.unwrap();
        assert_eq!(stats.commands_imported, 0, "Expected 0 commands to be imported");
        assert_eq!(stats.commands_skipped, 2, "Expected 2 commands to be skipped");

        // Re-import with overwrite: the existing command is updated, never skipped
        let items_to_update = vec![ImportExportItem::Command(
            Command::new(CATEGORY_USER, SOURCE_USER, "cmd2").with_description(Some("updated desc".to_string())),
        )];
        let stream = iter(items_to_update.into_iter().map(Ok));
        let stats = storage.import_items(stream, true, false).await.unwrap();
        assert_eq!(stats.commands_imported, 0, "Expected 0 new commands to be imported");
        assert_eq!(stats.commands_updated, 1, "Expected 1 command to be updated");
        assert_eq!(stats.commands_skipped, 0, "Overwrite mode should never skip commands");
    }

    #[tokio::test]
    async fn test_import_items_completions() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        let items_to_import = vec![
            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch")),
            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps")),
        ];

        let stream = iter(items_to_import.clone().into_iter().map(Ok));
        let stats = storage.import_items(stream, false, false).await.unwrap();
        assert_eq!(stats.completions_imported, 2);
        assert_eq!(stats.completions_skipped, 0);

        let stream = iter(items_to_import.clone().into_iter().map(Ok));
        let stats = storage.import_items(stream, false, false).await.unwrap();
        assert_eq!(stats.completions_imported, 0);
        assert_eq!(stats.completions_skipped, 2);

        let items_to_update = vec![ImportExportItem::Completion(VariableCompletion::new(
            SOURCE_USER,
            "git",
            "branch",
            "git branch -a",
        ))];
        let stream = iter(items_to_update.into_iter().map(Ok));
        let stats = storage.import_items(stream, true, false).await.unwrap();
        assert_eq!(stats.completions_imported, 0);
        assert_eq!(stats.completions_updated, 1);
        assert_eq!(stats.completions_skipped, 0, "Overwrite mode should never skip completions");
    }

    #[tokio::test]
    async fn test_import_workspace_items() {
        let (_, stats) = setup_storage(true, true, true).await;

        assert_eq!(
            stats.commands_imported, 8,
            "Expected 8 commands inserted into workspace"
        );
        assert_eq!(
            stats.completions_imported, 3,
            "Expected 3 completions inserted into workspace"
        );
        assert_eq!(stats.commands_skipped, 0, "Expected 0 commands skipped in workspace");
        assert_eq!(
            stats.completions_skipped, 0,
            "Expected 0 completions skipped in workspace"
        );
    }

    #[tokio::test]
    async fn test_import_items_mixed_no_overwrite() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        let items_to_import = vec![
            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cmd1")),
            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch")),
            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cmd2")),
            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps")),
        ];

        let stream = iter(items_to_import.clone().into_iter().map(Ok));
        let stats = storage.import_items(stream, false, false).await.unwrap();
        assert_eq!(stats.commands_imported, 2);
        assert_eq!(stats.completions_imported, 2);
        assert_eq!(stats.commands_skipped, 0);
        assert_eq!(stats.completions_skipped, 0);

        let stream = iter(items_to_import.into_iter().map(Ok));
        let stats = storage.import_items(stream, false, false).await.unwrap();
        assert_eq!(stats.commands_imported, 0);
        assert_eq!(stats.completions_imported, 0);
        assert_eq!(stats.commands_skipped, 2);
        assert_eq!(stats.completions_skipped, 2);
    }

    #[tokio::test]
    async fn test_import_items_mixed_with_overwrite() {
        let (storage, _) = setup_storage(true, true, false).await;

        let items_to_import = vec![
            // Existing command -> updated
            ImportExportItem::Command(
                Command::new(CATEGORY_USER, SOURCE_USER, "git status")
                    .with_description(Some("new description".to_string())),
            ),
            // New command -> imported
            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "new command")),
            // Existing completion -> updated
            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch -a")),
            // New completion -> imported
            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "npm", "script", "npm run")),
        ];

        let stream = iter(items_to_import.into_iter().map(Ok));
        let stats = storage.import_items(stream, true, false).await.unwrap();

        assert_eq!(stats.commands_updated, 1, "Expected 1 command to be updated");
        assert_eq!(stats.commands_imported, 1, "Expected 1 new command to be imported");
        assert_eq!(stats.completions_updated, 1, "Expected 1 completion to be updated");
        assert_eq!(
            stats.completions_imported, 1,
            "Expected 1 new completion to be imported"
        );
    }

    #[tokio::test]
    async fn test_export_user_commands_no_filter() {
        let (storage, _) = setup_storage(true, false, false).await;
        let mut exported_commands = Vec::new();
        let mut stream = storage.export_user_commands(None).await;
        // Unwrap every item so an error yielded by the stream fails the test loudly
        while let Some(cmd) = stream.next().await {
            exported_commands.push(cmd.expect("export stream yielded an error"));
        }

        assert_eq!(exported_commands.len(), 7, "Expected 7 user commands to be exported");
    }

    #[tokio::test]
    async fn test_export_user_commands_with_filter() {
        let (storage, _) = setup_storage(true, false, false).await;
        let filter = Regex::new(r"^git").unwrap();
        let mut exported_commands = Vec::new();
        let mut stream = storage.export_user_commands(Some(filter)).await;
        while let Some(cmd) = stream.next().await {
            exported_commands.push(cmd.expect("export stream yielded an error"));
        }

        assert_eq!(exported_commands.len(), 3, "Expected 3 git commands to be exported");

        let exported_cmd_values: Vec<String> = exported_commands.into_iter().map(|c| c.cmd).collect();
        assert!(exported_cmd_values.contains(&"git status".to_string()));
        assert!(exported_cmd_values.contains(&"git checkout main".to_string()));
        assert!(exported_cmd_values.contains(&"git pull".to_string()));
    }

    #[tokio::test]
    async fn test_export_user_variable_completions() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let completions_to_insert = vec![
            // Specific completion that must win over the generic one for "branch"
            VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch --specific"),
            // Generic completions (empty root cmd)
            VariableCompletion::new(SOURCE_USER, "", "branch", "git branch --generic"),
            VariableCompletion::new(SOURCE_USER, "", "commit", "git log --oneline --generic"),
            // Unrelated specific completion
            VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps"),
        ];
        for c in completions_to_insert {
            storage.insert_variable_completion(c).await.unwrap();
        }

        let keys_to_export = vec![
            // Specific match exists
            ("git".to_string(), "branch".to_string()),
            // Only the generic fallback exists
            ("git".to_string(), "commit".to_string()),
            // Specific match exists
            ("docker".to_string(), "container".to_string()),
            // Nothing matches
            ("docker".to_string(), "nonexistent".to_string()),
        ];

        let found = storage.export_user_variable_completions(keys_to_export).await.unwrap();
        assert_eq!(found.len(), 3, "Should export 3 completions based on precedence rules");

        // Results are ordered by root_cmd, then variable
        let commit = &found[0];
        assert_eq!(
            commit.flat_root_cmd, "",
            "Should have fallen back to the empty root cmd for commit"
        );
        assert_eq!(commit.flat_variable, "commit");
        assert_eq!(commit.suggestions_provider, "git log --oneline --generic");

        let container = &found[1];
        assert_eq!(container.flat_root_cmd, "docker");
        assert_eq!(container.flat_variable, "container");
        assert_eq!(container.suggestions_provider, "docker ps");

        let branch = &found[2];
        assert_eq!(
            branch.flat_root_cmd, "git",
            "Should have picked the specific root cmd for branch"
        );
        assert_eq!(branch.flat_variable, "branch");
        assert_eq!(branch.suggestions_provider, "git branch --specific");

        let found_empty = storage.export_user_variable_completions([]).await.unwrap();
        assert!(found_empty.is_empty(), "Should return an empty vec for empty keys");
    }

    /// Builds an in-memory storage, optionally pre-populated with 8 commands (7 user + 1 tldr)
    /// and 3 completions, imported without overwrite into the global or workspace tables.
    async fn setup_storage(
        with_commands: bool,
        with_completions: bool,
        workspace: bool,
    ) -> (SqliteStorage, ImportStats) {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        if workspace {
            storage.setup_workspace_storage().await.unwrap();
        }

        let mut items_to_import = Vec::new();
        if with_commands {
            items_to_import.extend(vec![
                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "git status")),
                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")),
                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "git pull")),
                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "docker ps")),
                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "docker-compose up")),
                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "npm install")),
                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cargo build")),
                // Non-user category: must not be exported by `export_user_commands`
                ImportExportItem::Command(Command::new("common", SOURCE_TLDR, "ls -la")),
            ]);
        }
        if with_completions {
            items_to_import.extend(vec![
                ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch")),
                ImportExportItem::Completion(VariableCompletion::new(
                    SOURCE_USER,
                    "git",
                    "commit",
                    "git log --oneline",
                )),
                ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps")),
            ]);
        }

        let stats = if !items_to_import.is_empty() {
            let stream = iter(items_to_import.into_iter().map(Ok));
            storage.import_items(stream, false, workspace).await.unwrap()
        } else {
            ImportStats::default()
        };

        (storage, stats)
    }
}
621}