intelli_shell/storage/
import_export.rs

1use std::pin::pin;
2
3use chrono::{DateTime, Utc};
4use futures_util::StreamExt;
5use regex::Regex;
6use tokio::sync::mpsc;
7use tokio_stream::{Stream, wrappers::ReceiverStream};
8use tracing::instrument;
9
10use super::SqliteStorage;
11use crate::{
12    errors::{AppError, Result},
13    model::{CATEGORY_USER, Command, ImportExportItem, ImportStats, VariableCompletion},
14};
15
16impl SqliteStorage {
17    /// Imports a collection of commands and completions into the database.
18    ///
19    /// This function allows for bulk insertion or updating of items from a stream.
20    /// The behavior for existing items depends on the `overwrite` flag.
21    #[instrument(skip_all)]
22    pub async fn import_items(
23        &self,
24        items: impl Stream<Item = Result<ImportExportItem>> + Send + 'static,
25        overwrite: bool,
26        workspace: bool,
27    ) -> Result<ImportStats> {
28        // Create a channel to bridge the async stream with the sync database operations
29        let (tx, mut rx) = mpsc::channel(100);
30
31        // Spawn a producer task to read from the async stream and send to the channel
32        tokio::spawn(async move {
33            // Pin the stream to be able to iterate over it
34            let mut items = pin!(items);
35            while let Some(item_res) = items.next().await {
36                if tx.send(item_res).await.is_err() {
37                    // Receiver has been dropped, so we can stop
38                    tracing::debug!("Import stream channel closed by receiver");
39                    break;
40                }
41            }
42        });
43
44        // Determine which tables to import into based on the `workspace` flag
45        let commands_table = if workspace { "workspace_command" } else { "command" };
46        let completions_table = if workspace {
47            "workspace_variable_completion"
48        } else {
49            "variable_completion"
50        };
51
52        self.client
53            .conn_mut(move |conn| {
54                let mut stats = ImportStats::default();
55                let tx = conn.transaction()?;
56
57                let mut cmd_stmt = if overwrite {
58                    tx.prepare(&format!(
59                        r#"INSERT INTO {commands_table} (
60                            id,
61                            category,
62                            source,
63                            alias,
64                            cmd,
65                            flat_cmd,
66                            description,
67                            flat_description,
68                            tags,
69                            created_at
70                        ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
71                        ON CONFLICT (cmd) DO UPDATE SET
72                            alias = COALESCE(excluded.alias, alias),
73                            cmd = excluded.cmd,
74                            flat_cmd = excluded.flat_cmd,
75                            description = COALESCE(excluded.description, description),
76                            flat_description = COALESCE(excluded.flat_description, flat_description),
77                            tags = COALESCE(excluded.tags, tags),
78                            updated_at = excluded.created_at
79                        RETURNING updated_at;"#
80                    ))?
81                } else {
82                    tx.prepare(&format!(
83                        r#"INSERT OR IGNORE INTO {commands_table} (
84                            id,
85                            category,
86                            source,
87                            alias,
88                            cmd,
89                            flat_cmd,
90                            description,
91                            flat_description,
92                            tags,
93                            created_at
94                        ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
95                        RETURNING updated_at;"#,
96                    ))?
97                };
98
99                let mut cmp_stmt = if overwrite {
100                    tx.prepare(&format!(
101                        r#"INSERT INTO {completions_table} (
102                            id,
103                            source,
104                            root_cmd,
105                            flat_root_cmd,
106                            variable,
107                            flat_variable,
108                            suggestions_provider,
109                            created_at
110                        ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
111                        ON CONFLICT (flat_root_cmd, flat_variable) DO UPDATE SET
112                            source = excluded.source,
113                            root_cmd = excluded.root_cmd,
114                            flat_root_cmd = excluded.flat_root_cmd,
115                            variable = excluded.variable,
116                            flat_variable = excluded.flat_variable,
117                            suggestions_provider = excluded.suggestions_provider,
118                            updated_at = excluded.created_at
119                        RETURNING updated_at;"#
120                    ))?
121                } else {
122                    tx.prepare(&format!(
123                        r#"INSERT OR IGNORE INTO {completions_table} (
124                            id,
125                            source,
126                            root_cmd,
127                            flat_root_cmd,
128                            variable,
129                            flat_variable,
130                            suggestions_provider,
131                            created_at
132                        ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
133                        RETURNING updated_at;"#,
134                    ))?
135                };
136
137                // Process items from the channel
138                while let Some(item_result) = rx.blocking_recv() {
139                    match item_result? {
140                        ImportExportItem::Command(command) => {
141                            tracing::trace!("Importing a {commands_table}: {}", command.cmd);
142                            let mut rows = cmd_stmt.query((
143                                &command.id,
144                                &command.category,
145                                &command.source,
146                                &command.alias,
147                                &command.cmd,
148                                &command.flat_cmd,
149                                &command.description,
150                                &command.flat_description,
151                                serde_json::to_value(&command.tags)?,
152                                &command.created_at,
153                            ))?;
154                            match rows.next()? {
155                                // No row returned, this happens only when overwrite = false, meaning it was skipped
156                                None => stats.commands_skipped += 1,
157                                // When a row is returned (can happen on both paths)
158                                Some(r) => {
159                                    let updated_at = r.get::<_, Option<DateTime<Utc>>>(0)?;
160                                    match updated_at {
161                                        // If there's no update date, it's a new insert
162                                        None => stats.commands_imported += 1,
163                                        // If it has a value, it was updated
164                                        Some(_) => stats.commands_updated += 1,
165                                    }
166                                }
167                            }
168                        }
169                        ImportExportItem::Completion(completion) => {
170                            tracing::trace!("Importing a {completions_table}: {completion}");
171                            let mut rows = cmp_stmt.query((
172                                &completion.id,
173                                &completion.source,
174                                &completion.root_cmd,
175                                &completion.flat_root_cmd,
176                                &completion.variable,
177                                &completion.flat_variable,
178                                &completion.suggestions_provider,
179                                &completion.created_at,
180                            ))?;
181                            match rows.next()? {
182                                // No row returned, this happens only when overwrite = false, meaning it was skipped
183                                None => stats.completions_skipped += 1,
184                                // When a row is returned (can happen on both paths)
185                                Some(r) => {
186                                    let updated_at = r.get::<_, Option<DateTime<Utc>>>(0)?;
187                                    match updated_at {
188                                        // If there's no update date, it's a new insert
189                                        None => stats.completions_imported += 1,
190                                        // If it has a value, it was updated
191                                        Some(_) => stats.completions_updated += 1,
192                                    }
193                                }
194                            }
195                        }
196                    }
197                }
198
199                drop(cmd_stmt);
200                drop(cmp_stmt);
201                tx.commit()?;
202                Ok(stats)
203            })
204            .await
205    }
206
207    /// Export user commands
208    #[instrument(skip_all)]
209    pub async fn export_user_commands(
210        &self,
211        filter: Option<Regex>,
212    ) -> impl Stream<Item = Result<Command>> + Send + 'static {
213        // Create a channel to stream results from the database with a small buffer to provide backpressure
214        let (tx, rx) = mpsc::channel(100);
215
216        // Spawn a new task to run the query and send results back through the channel
217        let client = self.client.clone();
218        tokio::spawn(async move {
219            let res = client
220                .conn_mut(move |conn| {
221                    // Prepare the query
222                    let mut q_values = vec![CATEGORY_USER.to_owned()];
223                    let mut query = String::from(
224                        r"SELECT
225                            rowid,
226                            id,
227                            category,
228                            source,
229                            alias,
230                            cmd,
231                            flat_cmd,
232                            description,
233                            flat_description,
234                            tags,
235                            created_at,
236                            updated_at
237                        FROM command
238                        WHERE category = ?1",
239                    );
240                    if let Some(filter) = filter {
241                        q_values.push(filter.as_str().to_owned());
242                        query.push_str(" AND (cmd REGEXP ?2 OR (description IS NOT NULL AND description REGEXP ?2))");
243                    }
244                    query.push_str("\nORDER BY cmd ASC");
245
246                    tracing::trace!("Exporting commands: {query}");
247
248                    // Create an iterator over the rows
249                    let mut stmt = conn.prepare(&query)?;
250                    let records_iter =
251                        stmt.query_and_then(rusqlite::params_from_iter(q_values), |r| Command::try_from(r))?;
252
253                    // Iterate and send each record back through the channel
254                    for record_result in records_iter {
255                        if tx.blocking_send(record_result.map_err(AppError::from)).is_err() {
256                            tracing::debug!("Async stream receiver dropped, closing db query");
257                            break;
258                        }
259                    }
260
261                    Ok(())
262                })
263                .await;
264            if let Err(err) = res {
265                panic!("Couldn't fetch commands to export: {err:?}");
266            }
267        });
268
269        // Return the receiver stream
270        ReceiverStream::new(rx)
271    }
272
273    /// Exports user variable completions for a given set of (flat_root_cmd, flat_variable_name) pairs.
274    ///
275    /// For each pair, it resolves the best match by first looking for a completion with the
276    /// specific `flat_root_cmd`, falling back to one with an empty one if not found.
277    ///
278    /// **Note**: This method does not consider workspace-specific completions, only user tables.
279    #[instrument(skip_all)]
280    pub async fn export_user_variable_completions(
281        &self,
282        flat_root_cmd_and_var: impl IntoIterator<Item = (String, String)>,
283    ) -> Result<Vec<VariableCompletion>> {
284        // Flatten the incoming (command, variable) key pairs
285        let flat_keys = flat_root_cmd_and_var.into_iter().collect::<Vec<_>>();
286
287        if flat_keys.is_empty() {
288            return Ok(Vec::new());
289        }
290
291        self.client
292            .conn(move |conn| {
293                let values_placeholders = vec!["(?, ?)"; flat_keys.len()].join(", ");
294                let query = format!(
295                    r#"WITH input_keys(flat_root_cmd, flat_variable) AS (VALUES {values_placeholders})
296                    SELECT
297                        t.id,
298                        t.source,
299                        t.root_cmd,
300                        t.flat_root_cmd,
301                        t.variable,
302                        t.flat_variable,
303                        t.suggestions_provider,
304                        t.created_at,
305                        t.updated_at
306                    FROM (
307                        SELECT
308                            vc.*,
309                            ROW_NUMBER() OVER (
310                                PARTITION BY ik.flat_root_cmd, ik.flat_variable
311                                ORDER BY
312                                    CASE WHEN vc.flat_root_cmd = ik.flat_root_cmd THEN 0 ELSE 1 END
313                            ) as rn
314                        FROM variable_completion vc
315                        JOIN input_keys ik ON vc.flat_variable = ik.flat_variable
316                        WHERE vc.flat_root_cmd = ik.flat_root_cmd 
317                            OR vc.flat_root_cmd = ''
318                    ) AS t
319                    WHERE t.rn = 1
320                    ORDER BY t.root_cmd, t.variable"#
321                );
322                tracing::trace!("Exporting completions: {query}");
323
324                Ok(conn
325                    .prepare(&query)?
326                    .query_map(
327                        rusqlite::params_from_iter(flat_keys.into_iter().flat_map(|(cmd, var)| vec![cmd, var])),
328                        |row| VariableCompletion::try_from(row),
329                    )?
330                    .collect::<Result<Vec<_>, _>>()?)
331            })
332            .await
333    }
334}
335
336#[cfg(test)]
337mod tests {
338    use tokio_stream::iter;
339
340    use super::*;
341    use crate::model::{SOURCE_TLDR, SOURCE_USER};
342
343    #[tokio::test]
344    async fn test_import_items_commands() {
345        let storage = SqliteStorage::new_in_memory().await.unwrap();
346
347        let items_to_import = vec![
348            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cmd1")),
349            ImportExportItem::Command(
350                Command::new(CATEGORY_USER, SOURCE_USER, "cmd2").with_description(Some("original desc".to_string())),
351            ),
352        ];
353
354        // First import (new items)
355        let stream = iter(items_to_import.clone().into_iter().map(Ok));
356        let stats = storage.import_items(stream, false, false).await.unwrap();
357        assert_eq!(stats.commands_imported, 2, "Expected 2 new commands to be imported");
358        assert_eq!(stats.commands_skipped, 0);
359
360        // Second import (no overwrite)
361        let stream = iter(items_to_import.clone().into_iter().map(Ok));
362        let stats = storage.import_items(stream, false, false).await.unwrap();
363        assert_eq!(stats.commands_imported, 0, "Expected 0 commands to be imported");
364        assert_eq!(stats.commands_skipped, 2, "Expected 2 commands to be skipped");
365
366        // Third import (with overwrite)
367        let items_to_update = vec![ImportExportItem::Command(
368            Command::new(CATEGORY_USER, SOURCE_USER, "cmd2").with_description(Some("updated desc".to_string())),
369        )];
370        let stream = iter(items_to_update.into_iter().map(Ok));
371        let stats = storage.import_items(stream, true, false).await.unwrap();
372        assert_eq!(stats.commands_imported, 0, "Expected 0 new commands to be imported");
373        assert_eq!(stats.commands_updated, 1, "Expected 1 command to be updated");
374    }
375
376    #[tokio::test]
377    async fn test_import_items_completions() {
378        let storage = SqliteStorage::new_in_memory().await.unwrap();
379
380        let items_to_import = vec![
381            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch")),
382            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps")),
383        ];
384
385        // First import (new items)
386        let stream = iter(items_to_import.clone().into_iter().map(Ok));
387        let stats = storage.import_items(stream, false, false).await.unwrap();
388        assert_eq!(stats.completions_imported, 2);
389        assert_eq!(stats.completions_skipped, 0);
390
391        // Second import (no overwrite)
392        let stream = iter(items_to_import.clone().into_iter().map(Ok));
393        let stats = storage.import_items(stream, false, false).await.unwrap();
394        assert_eq!(stats.completions_imported, 0);
395        assert_eq!(stats.completions_skipped, 2);
396
397        // Third import (with overwrite)
398        let items_to_update = vec![ImportExportItem::Completion(VariableCompletion::new(
399            SOURCE_USER,
400            "git",
401            "branch",
402            "git branch -a",
403        ))];
404        let stream = iter(items_to_update.into_iter().map(Ok));
405        let stats = storage.import_items(stream, true, false).await.unwrap();
406        assert_eq!(stats.completions_imported, 0);
407        assert_eq!(stats.completions_updated, 1);
408    }
409
410    #[tokio::test]
411    async fn test_import_workspace_items() {
412        let (_, stats) = setup_storage(true, true, true).await;
413
414        assert_eq!(
415            stats.commands_imported, 8,
416            "Expected 8 commands inserted into workspace"
417        );
418        assert_eq!(
419            stats.completions_imported, 3,
420            "Expected 3 completions inserted into workspace"
421        );
422        assert_eq!(stats.commands_skipped, 0, "Expected 0 commands skipped in workspace");
423        assert_eq!(
424            stats.completions_skipped, 0,
425            "Expected 0 completions skipped in workspace"
426        );
427    }
428
429    #[tokio::test]
430    async fn test_import_items_mixed_no_overwrite() {
431        let storage = SqliteStorage::new_in_memory().await.unwrap();
432
433        let items_to_import = vec![
434            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cmd1")),
435            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch")),
436            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cmd2")),
437            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps")),
438        ];
439
440        // First import (new items)
441        let stream = iter(items_to_import.clone().into_iter().map(Ok));
442        let stats = storage.import_items(stream, false, false).await.unwrap();
443        assert_eq!(stats.commands_imported, 2);
444        assert_eq!(stats.completions_imported, 2);
445        assert_eq!(stats.commands_skipped, 0);
446        assert_eq!(stats.completions_skipped, 0);
447
448        // Second import (no overwrite)
449        let stream = iter(items_to_import.into_iter().map(Ok));
450        let stats = storage.import_items(stream, false, false).await.unwrap();
451        assert_eq!(stats.commands_imported, 0);
452        assert_eq!(stats.completions_imported, 0);
453        assert_eq!(stats.commands_skipped, 2);
454        assert_eq!(stats.completions_skipped, 2);
455    }
456
457    #[tokio::test]
458    async fn test_import_items_mixed_with_overwrite() {
459        let (storage, _) = setup_storage(true, true, false).await;
460
461        let items_to_import = vec![
462            // Update an existing command
463            ImportExportItem::Command(
464                Command::new(CATEGORY_USER, SOURCE_USER, "git status")
465                    .with_description(Some("new description".to_string())),
466            ),
467            // Add a new command
468            ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "new command")),
469            // Update an existing completion
470            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch -a")),
471            // Add a new completion
472            ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "npm", "script", "npm run")),
473        ];
474
475        let stream = iter(items_to_import.into_iter().map(Ok));
476        let stats = storage.import_items(stream, true, false).await.unwrap();
477
478        assert_eq!(stats.commands_updated, 1, "Expected 1 command to be updated");
479        assert_eq!(stats.commands_imported, 1, "Expected 1 new command to be imported");
480        assert_eq!(stats.completions_updated, 1, "Expected 1 completion to be updated");
481        assert_eq!(
482            stats.completions_imported, 1,
483            "Expected 1 new completion to be imported"
484        );
485    }
486
487    #[tokio::test]
488    async fn test_export_user_commands_no_filter() {
489        let (storage, _) = setup_storage(true, false, false).await;
490        let mut exported_commands = Vec::new();
491        let mut stream = storage.export_user_commands(None).await;
492        while let Some(Ok(cmd)) = stream.next().await {
493            exported_commands.push(cmd);
494        }
495
496        assert_eq!(exported_commands.len(), 7, "Expected 7 user commands to be exported");
497    }
498
499    #[tokio::test]
500    async fn test_export_user_commands_with_filter() {
501        let (storage, _) = setup_storage(true, false, false).await;
502        let filter = Regex::new(r"^git").unwrap();
503        let mut exported_commands = Vec::new();
504        let mut stream = storage.export_user_commands(Some(filter)).await;
505        while let Some(Ok(cmd)) = stream.next().await {
506            exported_commands.push(cmd);
507        }
508
509        assert_eq!(exported_commands.len(), 3, "Expected 3 git commands to be exported");
510
511        let exported_cmd_values: Vec<String> = exported_commands.into_iter().map(|c| c.cmd).collect();
512        assert!(exported_cmd_values.contains(&"git status".to_string()));
513        assert!(exported_cmd_values.contains(&"git checkout main".to_string()));
514        assert!(exported_cmd_values.contains(&"git pull".to_string()));
515    }
516
517    #[tokio::test]
518    async fn test_export_user_variable_completions() {
519        let storage = SqliteStorage::new_in_memory().await.unwrap();
520        let completions_to_insert = vec![
521            // A specific and a generic completion exist for "branch"
522            VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch --specific"),
523            VariableCompletion::new(SOURCE_USER, "", "branch", "git branch --generic"),
524            // Only a generic completion exists for "commit"
525            VariableCompletion::new(SOURCE_USER, "", "commit", "git log --oneline --generic"),
526            // Only a specific completion exists for "container"
527            VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps"),
528        ];
529        for c in completions_to_insert {
530            storage.insert_variable_completion(c).await.unwrap();
531        }
532
533        // Define keys to export, covering all resolution cases
534        let keys_to_export = vec![
535            ("git".to_string(), "branch".to_string()), // Should resolve to the specific version
536            ("git".to_string(), "commit".to_string()), // Should fall back to the generic version
537            ("docker".to_string(), "container".to_string()), // Should resolve to its specific version
538            ("docker".to_string(), "nonexistent".to_string()), // Should find nothing
539        ];
540
541        // Export completions
542        let found = storage.export_user_variable_completions(keys_to_export).await.unwrap();
543        assert_eq!(found.len(), 3, "Should export 3 completions based on precedence rules");
544
545        // Assert 'commit' fell back to the generic completion
546        let commit = &found[0];
547        assert_eq!(
548            commit.flat_root_cmd, "",
549            "Should have fallen back to the empty root cmd for commit"
550        );
551        assert_eq!(commit.flat_variable, "commit");
552        assert_eq!(commit.suggestions_provider, "git log --oneline --generic");
553
554        // Assert 'container' was resolved to its specific completion
555        let container = &found[1];
556        assert_eq!(container.flat_root_cmd, "docker");
557        assert_eq!(container.flat_variable, "container");
558        assert_eq!(container.suggestions_provider, "docker ps");
559
560        // Assert 'branch' was resolved to the specific completion
561        let branch = &found[2];
562        assert_eq!(
563            branch.flat_root_cmd, "git",
564            "Should have picked the specific root cmd for branch"
565        );
566        assert_eq!(branch.flat_variable, "branch");
567        assert_eq!(branch.suggestions_provider, "git branch --specific");
568
569        // Test the edge case of exporting with an empty list of keys
570        let found_empty = storage.export_user_variable_completions([]).await.unwrap();
571        assert!(found_empty.is_empty(), "Should return an empty vec for empty keys");
572    }
573
574    /// Helper function to set up storage with predefined test data
575    async fn setup_storage(
576        with_commands: bool,
577        with_completions: bool,
578        workspace: bool,
579    ) -> (SqliteStorage, ImportStats) {
580        let storage = SqliteStorage::new_in_memory().await.unwrap();
581        if workspace {
582            storage.setup_workspace_storage().await.unwrap();
583        }
584
585        let mut items_to_import = Vec::new();
586        if with_commands {
587            items_to_import.extend(vec![
588                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "git status")),
589                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")),
590                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "git pull")),
591                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "docker ps")),
592                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "docker-compose up")),
593                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "npm install")),
594                ImportExportItem::Command(Command::new(CATEGORY_USER, SOURCE_USER, "cargo build")),
595                // A non-user command that should not be exported by user-only functions
596                ImportExportItem::Command(Command::new("common", SOURCE_TLDR, "ls -la")),
597            ]);
598        }
599        if with_completions {
600            items_to_import.extend(vec![
601                ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "git", "branch", "git branch")),
602                ImportExportItem::Completion(VariableCompletion::new(
603                    SOURCE_USER,
604                    "git",
605                    "commit",
606                    "git log --oneline",
607                )),
608                ImportExportItem::Completion(VariableCompletion::new(SOURCE_USER, "docker", "container", "docker ps")),
609            ]);
610        }
611
612        let stats = if !items_to_import.is_empty() {
613            let stream = iter(items_to_import.into_iter().map(Ok));
614            storage.import_items(stream, false, workspace).await.unwrap()
615        } else {
616            ImportStats::default()
617        };
618
619        (storage, stats)
620    }
621}