intelli_shell/storage/
command.rs

1use std::{cmp::Ordering, pin::pin};
2
3use chrono::{DateTime, Utc};
4use color_eyre::{
5    Report, Result,
6    eyre::{Context, eyre},
7};
8use futures_util::StreamExt;
9use regex::Regex;
10use rusqlite::{Row, fallible_iterator::FallibleIterator, ffi, types::Type};
11use sea_query::SqliteQueryBuilder;
12use sea_query_rusqlite::RusqliteBinder;
13use tokio::sync::mpsc;
14use tokio_stream::{Stream, wrappers::ReceiverStream};
15use tracing::instrument;
16use uuid::Uuid;
17
18use super::{SqliteStorage, queries::*};
19use crate::{
20    config::SearchCommandTuning,
21    errors::{InsertError, SearchError, UpdateError},
22    model::{CATEGORY_USER, Command, SOURCE_TLDR, SearchCommandsFilter},
23};
24
25impl SqliteStorage {
26    /// Creates temporary tables for workspace-specific commands for the current session by reflecting the schema of the
27    /// main `command` table.
28    #[instrument(skip_all)]
29    pub async fn setup_workspace_storage(&self) -> Result<()> {
30        self.client
31            .conn_mut::<_, _, Report>(|conn| {
32                // Fetch the schema for the main tables and triggers
33                let schemas: Vec<String> = conn
34                    .prepare(
35                        r"SELECT sql 
36                        FROM sqlite_master 
37                        WHERE (type = 'table' AND name = 'command') 
38                            OR (type = 'table' AND name LIKE 'command_%fts')
39                            OR (type = 'trigger' AND name LIKE 'command_%_fts' AND tbl_name = 'command')",
40                    )?
41                    .query_map([], |row| row.get(0))?
42                    .collect::<Result<Vec<String>, _>>()?;
43
44                let tx = conn.transaction()?;
45
46                // Modify and execute each schema statement to create temporary versions
47                for schema in schemas {
48                    let temp_schema = schema
49                        .replace("command", "workspace_command")
50                        .replace("CREATE TABLE", "CREATE TEMP TABLE")
51                        .replace("CREATE VIRTUAL TABLE ", "CREATE VIRTUAL TABLE temp.")
52                        .replace("CREATE TRIGGER", "CREATE TEMP TRIGGER");
53                    tx.execute(&temp_schema, [])?;
54                }
55
56                tx.commit()?;
57                Ok(())
58            })
59            .await
60            .wrap_err("Failed to create temporary workspace storage from schema")?;
61
62        self.workspace_tables_loaded
63            .store(true, std::sync::atomic::Ordering::SeqCst);
64
65        Ok(())
66    }
67
68    /// Determines if the storage is empty, i.e., if there are no commands in the database
69    #[instrument(skip_all)]
70    pub async fn is_empty(&self) -> Result<bool> {
71        let workspace_tables_loaded = self.workspace_tables_loaded.load(std::sync::atomic::Ordering::SeqCst);
72        self.client
73            .conn::<_, _, Report>(move |conn| {
74                if workspace_tables_loaded {
75                    Ok(conn.query_row(
76                        "SELECT NOT EXISTS (SELECT 1 FROM command UNION ALL SELECT 1 FROM workspace_command)",
77                        [],
78                        |r| r.get(0),
79                    )?)
80                } else {
81                    Ok(conn.query_row("SELECT NOT EXISTS(SELECT 1 FROM command)", [], |r| r.get(0))?)
82                }
83            })
84            .await
85            .wrap_err("Couldn't check if storage is empty")
86    }
87
88    /// Retrieves all tags from the database along with their usage statistics and if it's an exact match for the prefix
89    #[instrument(skip_all)]
90    pub async fn find_tags(
91        &self,
92        filter: SearchCommandsFilter,
93        tag_prefix: Option<String>,
94        tuning: &SearchCommandTuning,
95    ) -> Result<Vec<(String, u64, bool)>, SearchError> {
96        let workspace_tables_loaded = self.workspace_tables_loaded.load(std::sync::atomic::Ordering::SeqCst);
97        let query = query_find_tags(filter, tag_prefix, tuning, workspace_tables_loaded)?;
98        if tracing::enabled!(tracing::Level::TRACE) {
99            tracing::trace!("Querying tags:\n{}", query.to_string(SqliteQueryBuilder));
100        }
101        let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
102        Ok(self
103            .client
104            .conn::<_, _, Report>(move |conn| {
105                conn.prepare(&stmt)?
106                    .query(&*values.as_params())?
107                    .and_then(|r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)))
108                    .collect()
109            })
110            .await
111            .wrap_err("Couldn't find tags")?)
112    }
113
114    /// Finds and retrieves commands from the database.
115    ///
116    /// When a search term is present, if there's a command which alias exactly match the term, that'll be the only one
117    /// returned.
118    #[instrument(skip_all)]
119    pub async fn find_commands(
120        &self,
121        filter: SearchCommandsFilter,
122        working_path: impl Into<String>,
123        tuning: &SearchCommandTuning,
124    ) -> Result<(Vec<Command>, bool), SearchError> {
125        let workspace_tables_loaded = self.workspace_tables_loaded.load(std::sync::atomic::Ordering::SeqCst);
126        let cleaned_filter = filter.cleaned();
127
128        // When there's a search term
129        let mut query_alias = None;
130        if let Some(ref term) = cleaned_filter.search_term {
131            // Prepare the query for the alias as well
132            query_alias = Some((
133                format!(
134                    r#"SELECT c.rowid, c.* 
135                    FROM command c 
136                    WHERE c.alias IS NOT NULL AND c.alias = ?1 
137                    ORDER BY c.cmd ASC
138                    LIMIT {QUERY_LIMIT}"#
139                ),
140                (term.clone(),),
141            ));
142        }
143
144        // Build the commands query when no alias is matched
145        let query = query_find_commands(cleaned_filter, working_path, tuning, workspace_tables_loaded)?;
146        let query_trace = if tracing::enabled!(tracing::Level::TRACE) {
147            query.to_string(SqliteQueryBuilder)
148        } else {
149            String::default()
150        };
151        let (stmt, values) = query.build_rusqlite(SqliteQueryBuilder);
152
153        // Execute the queries
154        let tuning = *tuning;
155        Ok(self
156            .client
157            .conn::<_, _, Report>(move |conn| {
158                // If there's a query to find the command by alias
159                if let Some((query_alias, a_params)) = query_alias {
160                    // Run the query
161                    let rows = conn
162                        .prepare(&query_alias)?
163                        .query(a_params)?
164                        .map(|r| Command::try_from(r))
165                        .collect::<Vec<_>>()?;
166                    // Return the rows if there's a match
167                    if !rows.is_empty() {
168                        return Ok((rows, true));
169                    }
170                }
171                // Otherwise, run the regular search query and re-rank results
172                if tracing::enabled!(tracing::Level::TRACE) {
173                    tracing::trace!("Querying commands:\n{query_trace}");
174                }
175                Ok((
176                    rerank_query_results(
177                        conn.prepare(&stmt)?
178                            .query(&*values.as_params())?
179                            .and_then(|r| QueryResultItem::try_from(r))
180                            .collect::<Result<Vec<_>, _>>()?,
181                        &tuning,
182                    ),
183                    false,
184                ))
185            })
186            .await
187            .wrap_err("Couldn't search commands")?)
188    }
189
    /// Imports a collection of commands into the database.
    ///
    /// This function allows for bulk insertion or updating of commands from a stream.
    /// The behavior for existing commands depends on the `overwrite` flag:
    /// - When `true`, a command whose `cmd` already exists is updated, keeping the stored alias, description and tags
    ///   whenever the imported ones are null
    /// - When `false`, conflicting commands are ignored
    ///
    /// Commands are written to `workspace_command` when `workspace` is set and to `command` otherwise. When a `filter`
    /// is given, only the commands whose `cmd` or `description` match it are imported, the rest are silently
    /// discarded and not counted.
    ///
    /// The whole import runs in a single transaction that is committed only after the stream is exhausted; the first
    /// error yielded by the stream or raised by the database aborts the import before the commit.
    ///
    /// Returns the number of new commands inserted and skipped/updated.
    #[instrument(skip_all)]
    pub async fn import_commands(
        &self,
        commands: impl Stream<Item = Result<Command>> + Send + 'static,
        filter: Option<Regex>,
        overwrite: bool,
        workspace: bool,
    ) -> Result<(u64, u64)> {
        // Create a channel to bridge the async stream with the sync database operations
        // The bounded capacity keeps the producer from running far ahead of the database writes
        let (tx, mut rx) = mpsc::channel(100);

        // Spawn a producer task to read from the async stream and send to the channel
        tokio::spawn(async move {
            // Pin the stream to be able to iterate over it
            let mut commands = pin!(commands);
            while let Some(command_result) = commands.next().await {
                if tx.send(command_result).await.is_err() {
                    // Receiver has been dropped, so we can stop
                    tracing::debug!("Import stream channel closed by receiver");
                    break;
                }
            }
        });

        // Determine which table to import into based on the `workspace` flag
        let table = if workspace { "workspace_command" } else { "command" };

        self.client
            .conn_mut::<_, _, Report>(move |conn| {
                let mut inserted = 0;
                let mut skipped_or_updated = 0;
                let tx = conn.transaction()?;
                // Both statements return `updated_at`, which is what tells apart an insert from an update below
                let mut stmt = if overwrite {
                    // On conflict, `updated_at` is set to the imported `created_at`, so it's never null after an update
                    tx.prepare(&format!(
                        r#"INSERT INTO {table} (
                                id,
                                category,
                                source,
                                alias,
                                cmd,
                                flat_cmd,
                                description,
                                flat_description,
                                tags,
                                created_at
                            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
                            ON CONFLICT (cmd) DO UPDATE SET
                                alias = COALESCE(excluded.alias, alias),
                                cmd = excluded.cmd,
                                flat_cmd = excluded.flat_cmd,
                                description = COALESCE(excluded.description, description),
                                flat_description = COALESCE(excluded.flat_description, flat_description),
                                tags = COALESCE(excluded.tags, tags),
                                updated_at = excluded.created_at
                            RETURNING updated_at;"#
                    ))?
                } else {
                    // `updated_at` is not part of the inserted columns, so a fresh row returns it as null
                    tx.prepare(&format!(
                        r#"INSERT OR IGNORE INTO {table} (
                                id,
                                category,
                                source,
                                alias,
                                cmd,
                                flat_cmd,
                                description,
                                flat_description,
                                tags,
                                created_at
                            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
                            RETURNING updated_at;"#,
                    ))?
                };

                // Process commands from the channel
                // The receive is a blocking call, this closure is synchronous code
                while let Some(command_result) = rx.blocking_recv() {
                    // An error item from the stream aborts the whole import
                    let command = command_result?;

                    // If there's a filter for imported commands
                    if let Some(ref filter) = filter {
                        // Skip the command when it doesn't pass the filter
                        let matches_filter = filter.is_match(&command.cmd)
                            || command.description.as_ref().is_some_and(|d| filter.is_match(d));
                        if !matches_filter {
                            continue;
                        }
                    }

                    let mut rows = stmt.query((
                        &command.id,
                        &command.category,
                        &command.source,
                        &command.alias,
                        &command.cmd,
                        &command.flat_cmd,
                        &command.description,
                        &command.flat_description,
                        serde_json::to_value(&command.tags)?,
                        &command.created_at,
                    ))?;

                    match rows.next()? {
                        // No row returned, this happens only when overwrite = false, meaning it was skipped
                        None => skipped_or_updated += 1,
                        // When a row is returned (can happen on both paths)
                        Some(r) => {
                            let updated_at = r.get::<_, Option<DateTime<Utc>>>(0)?;
                            match updated_at {
                                // If there's no update date, it's a new insert
                                None => inserted += 1,
                                // If it has a value, it was updated
                                Some(_) => skipped_or_updated += 1,
                            }
                        }
                    }
                }

                // The statement borrows the transaction, so it has to be released before committing
                drop(stmt);
                tx.commit()?;
                Ok((inserted, skipped_or_updated))
            })
            .await
            .wrap_err("Couldn't import commands")
    }
320
321    /// Export user commands
322    #[instrument(skip_all)]
323    pub async fn export_user_commands(
324        &self,
325        filter: Option<Regex>,
326    ) -> impl Stream<Item = Result<Command>> + Send + 'static {
327        // Create a channel to stream results from the database with a small buffer to provide backpressure
328        let (tx, rx) = mpsc::channel(100);
329
330        // Spawn a new task to run the query and send results back through the channel
331        let client = self.client.clone();
332        tokio::spawn(async move {
333            let res: Result<(), Report> = client
334                .conn_mut(move |conn| {
335                    // Prepare the query
336                    let mut q_values = vec![CATEGORY_USER.to_owned()];
337                    let mut query = String::from(
338                        r"SELECT
339                            rowid,
340                            id,
341                            category,
342                            source,
343                            alias,
344                            cmd,
345                            flat_cmd,
346                            description,
347                            flat_description,
348                            tags,
349                            created_at,
350                            updated_at
351                        FROM command
352                        WHERE category = ?1",
353                    );
354                    if let Some(filter) = filter {
355                        q_values.push(filter.as_str().to_owned());
356                        query.push_str(" AND (cmd REGEXP ?2 OR (description IS NOT NULL AND description REGEXP ?2))");
357                    }
358                    query.push_str("\nORDER BY cmd ASC");
359
360                    // Create an iterator over the rows
361                    let mut stmt = conn.prepare(&query)?;
362                    let records_iter =
363                        stmt.query_and_then(rusqlite::params_from_iter(q_values), |r| Command::try_from(r))?;
364
365                    // Iterate and send each record back through the channel
366                    for record_result in records_iter {
367                        if tx
368                            .blocking_send(record_result.wrap_err("Error fetching command"))
369                            .is_err()
370                        {
371                            tracing::debug!("Async stream receiver dropped, closing db query");
372                            break;
373                        }
374                    }
375
376                    Ok(())
377                })
378                .await;
379            if let Err(e) = res {
380                panic!("Couldn't fetch commands to export: {e:?}");
381            }
382        });
383
384        // Return the receiver stream
385        ReceiverStream::new(rx)
386    }
387
388    /// Removes tldr commands
389    #[instrument(skip_all)]
390    pub async fn delete_tldr_commands(&self, category: Option<String>) -> Result<u64> {
391        self.client
392            .conn_mut::<_, _, Report>(move |conn| {
393                let mut query = String::from("DELETE FROM command WHERE source = ?1");
394                let mut params: Vec<String> = vec![SOURCE_TLDR.to_owned()];
395                if let Some(cat) = category {
396                    query.push_str(" AND category = ?2");
397                    params.push(cat);
398                }
399                let affected = conn.execute(&query, rusqlite::params_from_iter(params))?;
400                Ok(affected as u64)
401            })
402            .await
403            .wrap_err("Couldn't remove tldr commands")
404    }
405
406    /// Inserts a new command into the database.
407    ///
408    /// If a command with the same `id` or `cmd` already exists in the database, an error will be returned.
409    #[instrument(skip_all)]
410    pub async fn insert_command(&self, command: Command) -> Result<Command, InsertError> {
411        self.client
412            .conn(move |conn| {
413                let res = conn.execute(
414                    r#"INSERT INTO command (
415                        id,
416                        category,
417                        source,
418                        alias,
419                        cmd,
420                        flat_cmd,
421                        description,
422                        flat_description,
423                        tags,
424                        created_at,
425                        updated_at
426                    ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)"#,
427                    (
428                        &command.id,
429                        &command.category,
430                        &command.source,
431                        &command.alias,
432                        &command.cmd,
433                        &command.flat_cmd,
434                        &command.description,
435                        &command.flat_description,
436                        serde_json::to_value(&command.tags).wrap_err("Couldn't insert a command")?,
437                        &command.created_at,
438                        &command.updated_at,
439                    ),
440                );
441                match res {
442                    Ok(_) => Ok(command),
443                    Err(err) => {
444                        let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
445                        if code == ffi::SQLITE_CONSTRAINT_UNIQUE || code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY {
446                            Err(InsertError::AlreadyExists)
447                        } else {
448                            Err(Report::from(err).wrap_err("Couldn't insert a command").into())
449                        }
450                    }
451                }
452            })
453            .await
454    }
455
456    /// Updates an existing command in the database.
457    ///
458    /// If the command to be updated does not exist, an error will be returned.
459    #[instrument(skip_all)]
460    pub async fn update_command(&self, command: Command) -> Result<Command, UpdateError> {
461        self.client
462            .conn(move |conn| {
463                let res = conn.execute(
464                    r#"UPDATE command SET 
465                        category = ?2,
466                        source = ?3,
467                        alias = ?4,
468                        cmd = ?5,
469                        flat_cmd = ?6,
470                        description = ?7,
471                        flat_description = ?8,
472                        tags = ?9,
473                        created_at = ?10,
474                        updated_at = ?11
475                    WHERE id = ?1"#,
476                    (
477                        &command.id,
478                        &command.category,
479                        &command.source,
480                        &command.alias,
481                        &command.cmd,
482                        &command.flat_cmd,
483                        &command.description,
484                        &command.flat_description,
485                        serde_json::to_value(&command.tags).wrap_err("Couldn't update a command")?,
486                        &command.created_at,
487                        &command.updated_at,
488                    ),
489                );
490                match res {
491                    Ok(0) => Err(eyre!("Command not found: {}", command.id)
492                        .wrap_err("Couldn't update a command")
493                        .into()),
494                    Ok(_) => Ok(command),
495                    Err(err) => {
496                        let code = err.sqlite_error().map(|e| e.extended_code).unwrap_or_default();
497                        if code == ffi::SQLITE_CONSTRAINT_UNIQUE {
498                            Err(UpdateError::AlreadyExists)
499                        } else {
500                            Err(Report::from(err).wrap_err("Couldn't update a command").into())
501                        }
502                    }
503                }
504            })
505            .await
506    }
507
508    /// Increments the usage of a command
509    #[instrument(skip_all)]
510    pub async fn increment_command_usage(
511        &self,
512        command_id: Uuid,
513        path: impl AsRef<str> + Send + 'static,
514    ) -> Result<i32, UpdateError> {
515        self.client
516            .conn_mut(move |conn| {
517                let res = conn.query_row(
518                    r#"
519                    INSERT INTO command_usage (command_id, path, usage_count)
520                    VALUES (?1, ?2, 1)
521                    ON CONFLICT(command_id, path) DO UPDATE SET
522                        usage_count = usage_count + 1
523                    RETURNING usage_count;"#,
524                    (&command_id, &path.as_ref()),
525                    |r| r.get(0),
526                );
527                match res {
528                    Ok(u) => Ok(u),
529                    Err(err) => Err(Report::from(err).wrap_err("Couldn't update a command usage").into()),
530                }
531            })
532            .await
533    }
534
535    /// Deletes an existing command from the database.
536    ///
537    /// If the command to be deleted does not exist, an error will be returned.
538    #[instrument(skip_all)]
539    pub async fn delete_command(&self, command_id: Uuid) -> Result<()> {
540        self.client
541            .conn(move |conn| {
542                let res = conn.execute("DELETE FROM command WHERE id = ?1", (&command_id,));
543                match res {
544                    Ok(0) => Err(eyre!("Command not found: {command_id}").wrap_err("Couldn't delete a command")),
545                    Ok(_) => Ok(()),
546                    Err(err) => Err(Report::from(err).wrap_err("Couldn't delete a command")),
547                }
548            })
549            .await
550    }
551}
552
553/// Re-ranks a vector of [`QueryResultItem`] based on a combined score and command type.
554///
555/// The ranking priority is as follows:
556/// 1. Template matches (highest priority)
557/// 2. Workspace-specific commands
558/// 3. Other commands
559///
560/// Within categories 2 and 3, items are sorted based on a combined score of normalized text, path, and usage scores.
561fn rerank_query_results(items: Vec<QueryResultItem>, tuning: &SearchCommandTuning) -> Vec<Command> {
562    // Handle empty or single-item input
563    if items.is_empty() {
564        return Vec::new();
565    }
566    if items.len() == 1 {
567        return items.into_iter().map(|item| item.command).collect();
568    }
569
570    // 1. Partition results into template matches and all others
571    // Template matches have a fixed high rank and are handled separately to ensure they are always first
572    let (template_matches, mut other_items): (Vec<_>, Vec<_>) = items
573        .into_iter()
574        .partition(|item| item.text_score >= TEMPLATE_MATCH_RANK);
575    if !template_matches.is_empty() {
576        tracing::trace!("Found {} template matches", template_matches.len());
577    }
578
579    // Convert template matches to Command structs
580    let mut final_commands: Vec<Command> = template_matches.into_iter().map(|item| item.command).collect();
581
582    // If there are no other items, or only one, no complex normalization is needed
583    if other_items.len() <= 1 {
584        final_commands.extend(other_items.into_iter().map(|item| item.command));
585        return final_commands;
586    }
587
588    // Find min / max for all three scores
589    let mut min_text = f64::INFINITY;
590    let mut max_text = f64::NEG_INFINITY;
591    let mut min_path = f64::INFINITY;
592    let mut max_path = f64::NEG_INFINITY;
593    let mut min_usage = f64::INFINITY;
594    let mut max_usage = f64::NEG_INFINITY;
595    for item in &other_items {
596        min_text = min_text.min(item.text_score);
597        max_text = max_text.max(item.text_score);
598        min_path = min_path.min(item.path_score);
599        max_path = max_path.max(item.path_score);
600        min_usage = min_usage.min(item.usage_score);
601        max_usage = max_usage.max(item.usage_score);
602    }
603
604    // Calculate score ranges for normalization
605    let range_text = (max_text > min_text).then_some(max_text - min_text);
606    let range_path = (max_path > min_path).then_some(max_path - min_path);
607    let range_usage = (max_usage > min_usage).then_some(max_usage - min_usage);
608
609    // Sort items based on the combined normalized score
610    other_items.sort_by(|a, b| {
611        // Primary sort key: Workspace commands first (descending order for bool)
612        match b.is_workspace_command.cmp(&a.is_workspace_command) {
613            Ordering::Equal => {
614                // Secondary sort key: Calculated score. Only compute if primary keys are equal.
615                let calculate_score = |item: &QueryResultItem| -> f64 {
616                    // Normalize each score to a 0.0 ~ 1.0 range
617                    // If the range is 0, the score is neutral (0.5)
618                    let norm_text = range_text.map_or(0.5, |range| (item.text_score - min_text) / range);
619                    let norm_path = range_path.map_or(0.5, |range| (item.path_score - min_path) / range);
620                    let norm_usage = range_usage.map_or(0.5, |range| (item.usage_score - min_usage) / range);
621
622                    // Apply points from tuning configuration
623                    (norm_text * tuning.text.points as f64)
624                        + (norm_path * tuning.path.points as f64)
625                        + (norm_usage * tuning.usage.points as f64)
626                };
627
628                let final_score_a = calculate_score(a);
629                let final_score_b = calculate_score(b);
630
631                // Sort by final_score in descending order (higher score is better)
632                final_score_b.partial_cmp(&final_score_a).unwrap_or(Ordering::Equal)
633            }
634            // If items are in different categories (workspace vs. other), use the primary ordering
635            other => other,
636        }
637    });
638
639    // Append the sorted "other" items to the high-priority template commands
640    final_commands.extend(other_items.into_iter().map(|item| item.command));
641    final_commands
642}
643
644impl<'a> TryFrom<&'a Row<'a>> for Command {
645    type Error = rusqlite::Error;
646
647    fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
648        Ok(Self {
649            // rowid is skipped
650            id: row.get(1)?,
651            category: row.get(2)?,
652            source: row.get(3)?,
653            alias: row.get(4)?,
654            cmd: row.get(5)?,
655            flat_cmd: row.get(6)?,
656            description: row.get(7)?,
657            flat_description: row.get(8)?,
658            tags: serde_json::from_value(row.get::<_, serde_json::Value>(9)?)
659                .map_err(|e| rusqlite::Error::FromSqlConversionFailure(9, Type::Text, Box::new(e)))?,
660            created_at: row.get(10)?,
661            updated_at: row.get(11)?,
662        })
663    }
664}
665
/// Struct representing a command query result item when using FTS ranking.
///
/// It's read from a row where the columns of the [`Command`] are followed by the workspace flag and the three raw
/// scores, in the order `usage`, `path` and `text` (columns 12 to 15). The scores are not normalized here, that's done
/// by [`rerank_query_results`] relative to the rest of the result items.
struct QueryResultItem {
    /// The command associated with this result item
    command: Command,
    /// Whether this command is included in the workspace commands file
    is_workspace_command: bool,
    /// Score for the command global usage
    usage_score: f64,
    /// Score for the command path usage relevance
    path_score: f64,
    /// Score for the text relevance; values reaching `TEMPLATE_MATCH_RANK` are treated as template matches
    text_score: f64,
}
679
680impl<'a> TryFrom<&'a Row<'a>> for QueryResultItem {
681    type Error = rusqlite::Error;
682
683    fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
684        Ok(Self {
685            command: Command::try_from(row)?,
686            is_workspace_command: row.get(12)?,
687            usage_score: row.get(13)?,
688            path_score: row.get(14)?,
689            text_score: row.get(15)?,
690        })
691    }
692}
693
694#[cfg(test)]
695mod tests {
696    use futures_util::StreamExt;
697    use pretty_assertions::assert_eq;
698    use strum::IntoEnumIterator;
699    use tokio_stream::iter;
700    use uuid::Uuid;
701
702    use super::*;
703    use crate::model::{CATEGORY_USER, SOURCE_IMPORT, SOURCE_USER, SearchMode};
704
705    const PROJ_A_PATH: &str = "/home/user/project-a";
706    const PROJ_A_API_PATH: &str = "/home/user/project-a/api";
707    const PROJ_B_PATH: &str = "/home/user/project-b";
708    const UNRELATED_PATH: &str = "/var/log";
709
    /// The workspace storage can be set up on a fresh in-memory database.
    #[tokio::test]
    async fn test_setup_workspace_storage() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        // NOTE(review): `check_sqlite_version` is defined outside this file, presumably it reports the SQLite
        // version to ease troubleshooting when the setup fails - confirm
        storage.check_sqlite_version().await;
        let res = storage.setup_workspace_storage().await;
        assert!(res.is_ok(), "Expected workspace storage setup to succeed: {res:?}");
    }
717
718    #[tokio::test]
719    async fn test_is_empty() {
720        let storage = SqliteStorage::new_in_memory().await.unwrap();
721        assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
722
723        let cmd = Command {
724            id: Uuid::now_v7(),
725            cmd: "test_cmd".to_string(),
726            ..Default::default()
727        };
728        storage.insert_command(cmd).await.unwrap();
729
730        assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
731    }
732
733    #[tokio::test]
734    async fn test_is_empty_with_workspace() {
735        let storage = SqliteStorage::new_in_memory().await.unwrap();
736        storage.setup_workspace_storage().await.unwrap();
737        assert!(storage.is_empty().await.unwrap(), "Expected empty storage initially");
738
739        let cmd = Command {
740            id: Uuid::now_v7(),
741            cmd: "test_cmd".to_string(),
742            ..Default::default()
743        };
744        storage.insert_command(cmd).await.unwrap();
745
746        assert!(!storage.is_empty().await.unwrap(), "Expected non-empty after insert");
747    }
748
749    #[tokio::test]
750    async fn test_find_tags_no_filters() -> Result<(), SearchError> {
751        let storage = setup_ranking_storage().await;
752
753        let result = storage
754            .find_tags(SearchCommandsFilter::default(), None, &SearchCommandTuning::default())
755            .await?;
756
757        let expected = vec![
758            ("#git".to_string(), 5, false),
759            ("#build".to_string(), 2, false),
760            ("#commit".to_string(), 2, false),
761            ("#docker".to_string(), 2, false),
762            ("#list".to_string(), 2, false),
763            ("#k8s".to_string(), 1, false),
764            ("#npm".to_string(), 1, false),
765            ("#pod".to_string(), 1, false),
766            ("#push".to_string(), 1, false),
767            ("#unix".to_string(), 1, false),
768        ];
769
770        assert_eq!(result.len(), 10, "Expected 10 unique tags");
771        assert_eq!(result, expected, "Tags list or order mismatch");
772
773        Ok(())
774    }
775
776    #[tokio::test]
777    async fn test_find_tags_filter_by_tags_only() -> Result<(), SearchError> {
778        let storage = setup_ranking_storage().await;
779
780        let filter1 = SearchCommandsFilter {
781            tags: Some(vec!["#git".to_string()]),
782            ..Default::default()
783        };
784        let result1 = storage
785            .find_tags(filter1, None, &SearchCommandTuning::default())
786            .await?;
787        let expected1 = vec![("#commit".to_string(), 2, false), ("#push".to_string(), 1, false)];
788        assert_eq!(result1.len(), 2,);
789        assert_eq!(result1, expected1);
790
791        let filter2 = SearchCommandsFilter {
792            tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
793            ..Default::default()
794        };
795        let result2 = storage
796            .find_tags(filter2, None, &SearchCommandTuning::default())
797            .await?;
798        assert!(result2.is_empty());
799
800        let filter3 = SearchCommandsFilter {
801            tags: Some(vec!["#list".to_string()]),
802            ..Default::default()
803        };
804        let result3 = storage
805            .find_tags(filter3, None, &SearchCommandTuning::default())
806            .await?;
807        let expected3 = vec![("#docker".to_string(), 1, false), ("#unix".to_string(), 1, false)];
808        assert_eq!(result3.len(), 2);
809        assert_eq!(result3, expected3);
810
811        Ok(())
812    }
813
814    #[tokio::test]
815    async fn test_find_tags_filter_by_prefix_only() -> Result<(), SearchError> {
816        let storage = setup_ranking_storage().await;
817
818        let result = storage
819            .find_tags(
820                SearchCommandsFilter::default(),
821                Some("#comm".to_string()),
822                &SearchCommandTuning::default(),
823            )
824            .await?;
825        let expected = vec![("#commit".to_string(), 2, false)];
826        assert_eq!(result.len(), 1);
827        assert_eq!(result, expected);
828
829        Ok(())
830    }
831
832    #[tokio::test]
833    async fn test_find_tags_filter_by_tags_and_prefix() -> Result<(), SearchError> {
834        let storage = setup_ranking_storage().await;
835
836        let filter1 = SearchCommandsFilter {
837            tags: Some(vec!["#git".to_string()]),
838            ..Default::default()
839        };
840        let result1 = storage
841            .find_tags(filter1, Some("#comm".to_string()), &SearchCommandTuning::default())
842            .await?;
843        let expected1 = vec![("#commit".to_string(), 2, false)];
844        assert_eq!(result1.len(), 1);
845        assert_eq!(result1, expected1);
846
847        let filter2 = SearchCommandsFilter {
848            tags: Some(vec!["#git".to_string()]),
849            ..Default::default()
850        };
851        let result2 = storage
852            .find_tags(filter2, Some("#push".to_string()), &SearchCommandTuning::default())
853            .await?;
854        let expected2 = vec![("#push".to_string(), 1, true)];
855        assert_eq!(result2.len(), 1);
856        assert_eq!(result2, expected2);
857
858        Ok(())
859    }
860
861    #[tokio::test]
862    async fn test_find_commands_no_filter() {
863        let storage = setup_ranking_storage().await;
864        let filter = SearchCommandsFilter::default();
865        let (commands, _) = storage
866            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
867            .await
868            .unwrap();
869        assert_eq!(commands.len(), 10, "Expected all sample commands");
870    }
871
872    #[tokio::test]
873    async fn test_find_commands_filter_by_category() {
874        let storage = setup_ranking_storage().await;
875        let filter = SearchCommandsFilter {
876            category: Some(vec!["git".to_string()]),
877            ..Default::default()
878        };
879        let (commands, _) = storage
880            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
881            .await
882            .unwrap();
883        assert_eq!(commands.len(), 2);
884        assert!(commands.iter().all(|c| c.category == "git"));
885
886        let filter_no_match = SearchCommandsFilter {
887            category: Some(vec!["nonexistent".to_string()]),
888            ..Default::default()
889        };
890        let (commands_no_match, _) = storage
891            .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
892            .await
893            .unwrap();
894        assert!(commands_no_match.is_empty());
895    }
896
897    #[tokio::test]
898    async fn test_find_commands_filter_by_source() {
899        let storage = setup_ranking_storage().await;
900        let filter = SearchCommandsFilter {
901            source: Some(SOURCE_TLDR.to_string()),
902            ..Default::default()
903        };
904        let (commands, _) = storage
905            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
906            .await
907            .unwrap();
908        assert_eq!(commands.len(), 3);
909        assert!(commands.iter().all(|c| c.source == SOURCE_TLDR));
910    }
911
912    #[tokio::test]
913    async fn test_find_commands_filter_by_tags() {
914        let storage = setup_ranking_storage().await;
915        let filter_single_tag = SearchCommandsFilter {
916            tags: Some(vec!["#git".to_string()]),
917            ..Default::default()
918        };
919        let (commands_single_tag, _) = storage
920            .find_commands(filter_single_tag, "/some/path", &SearchCommandTuning::default())
921            .await
922            .unwrap();
923        assert_eq!(commands_single_tag.len(), 5);
924
925        let filter_multiple_tags = SearchCommandsFilter {
926            tags: Some(vec!["#docker".to_string(), "#list".to_string()]),
927            ..Default::default()
928        };
929        let (commands_multiple_tags, _) = storage
930            .find_commands(filter_multiple_tags, "/some/path", &SearchCommandTuning::default())
931            .await
932            .unwrap();
933        assert_eq!(commands_multiple_tags.len(), 1);
934
935        let filter_empty_tags = SearchCommandsFilter {
936            tags: Some(vec![]),
937            ..Default::default()
938        };
939        let (commands_empty_tags, _) = storage
940            .find_commands(filter_empty_tags, "/some/path", &SearchCommandTuning::default())
941            .await
942            .unwrap();
943        assert_eq!(commands_empty_tags.len(), 10);
944    }
945
946    #[tokio::test]
947    async fn test_find_commands_alias_precedence() {
948        let storage = setup_ranking_storage().await;
949        storage
950            .setup_command(
951                Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
952                [("/some/path", 100)],
953            )
954            .await;
955
956        for mode in SearchMode::iter() {
957            let filter = SearchCommandsFilter {
958                search_term: Some("gc".to_string()),
959                search_mode: mode,
960                ..Default::default()
961            };
962            let (commands, alias_match) = storage
963                .find_commands(filter, "", &SearchCommandTuning::default())
964                .await
965                .unwrap();
966            assert!(alias_match, "Expected alias match for mode {mode:?}");
967            assert_eq!(commands.len(), 1, "Expected only alias match for mode {mode:?}");
968            assert_eq!(
969                commands[0].cmd, "git commit -m",
970                "Expected correct alias command for mode {mode:?}"
971            );
972        }
973    }
974
975    #[tokio::test]
976    async fn test_find_commands_search_mode_exact() {
977        let storage = setup_ranking_storage().await;
978        let filter_token_match = SearchCommandsFilter {
979            search_term: Some("commit".to_string()),
980            search_mode: SearchMode::Exact,
981            ..Default::default()
982        };
983        let (commands_token_match, _) = storage
984            .find_commands(filter_token_match, "/some/path", &SearchCommandTuning::default())
985            .await
986            .unwrap();
987        assert_eq!(commands_token_match.len(), 2);
988        assert_eq!(commands_token_match[0].cmd, "git commit -m");
989        assert_eq!(commands_token_match[1].cmd, "git commit -m '{{message}}'");
990
991        let filter_no_match = SearchCommandsFilter {
992            search_term: Some("nonexistentterm".to_string()),
993            search_mode: SearchMode::Exact,
994            ..Default::default()
995        };
996        let (commands_no_match, _) = storage
997            .find_commands(filter_no_match, "/some/path", &SearchCommandTuning::default())
998            .await
999            .unwrap();
1000        assert!(commands_no_match.is_empty());
1001    }
1002
1003    #[tokio::test]
1004    async fn test_find_commands_search_mode_relaxed() {
1005        let storage = setup_ranking_storage().await;
1006        let filter = SearchCommandsFilter {
1007            search_term: Some("docker list".to_string()),
1008            search_mode: SearchMode::Relaxed,
1009            ..Default::default()
1010        };
1011        let (commands, _) = storage
1012            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1013            .await
1014            .unwrap();
1015        assert_eq!(commands.len(), 2);
1016        assert!(commands.iter().any(|c| c.cmd == "docker ps -a"));
1017        assert!(commands.iter().any(|c| c.cmd == "ls -lha"));
1018    }
1019
1020    #[tokio::test]
1021    async fn test_find_commands_search_mode_regex() {
1022        let storage = setup_ranking_storage().await;
1023        let filter = SearchCommandsFilter {
1024            search_term: Some(r"git\s.*it".to_string()),
1025            search_mode: SearchMode::Regex,
1026            ..Default::default()
1027        };
1028        let (commands, _) = storage
1029            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1030            .await
1031            .unwrap();
1032        assert_eq!(commands.len(), 2);
1033        assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
1034        assert_eq!(commands[1].cmd, "git commit -m");
1035
1036        let filter_invalid = SearchCommandsFilter {
1037            search_term: Some("[[invalid_regex".to_string()),
1038            search_mode: SearchMode::Regex,
1039            ..Default::default()
1040        };
1041        assert!(matches!(
1042            storage
1043                .find_commands(filter_invalid, "/some/path", &SearchCommandTuning::default())
1044                .await,
1045            Err(SearchError::InvalidRegex(_))
1046        ));
1047    }
1048
1049    #[tokio::test]
1050    async fn test_find_commands_search_mode_fuzzy() {
1051        let storage = setup_ranking_storage().await;
1052        let filter = SearchCommandsFilter {
1053            search_term: Some("gtcomit".to_string()),
1054            search_mode: SearchMode::Fuzzy,
1055            ..Default::default()
1056        };
1057        let (commands, _) = storage
1058            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1059            .await
1060            .unwrap();
1061        assert_eq!(commands.len(), 2);
1062        assert_eq!(commands[0].cmd, "git commit -m '{{message}}'");
1063        assert_eq!(commands[1].cmd, "git commit -m");
1064
1065        let filter_empty_fuzzy = SearchCommandsFilter {
1066            search_term: Some("'' | ^".to_string()),
1067            search_mode: SearchMode::Fuzzy,
1068            ..Default::default()
1069        };
1070        assert!(matches!(
1071            storage
1072                .find_commands(filter_empty_fuzzy, "/some/path", &SearchCommandTuning::default())
1073                .await,
1074            Err(SearchError::InvalidFuzzy)
1075        ));
1076    }
1077
1078    #[tokio::test]
1079    async fn test_find_commands_search_mode_auto() {
1080        let storage = setup_ranking_storage().await;
1081        let default_tuning = SearchCommandTuning::default();
1082
1083        // Helper closure for running a search and making assertions
1084        let run_search = |term: &'static str, path: &'static str| {
1085            let storage = storage.clone();
1086            async move {
1087                let filter = SearchCommandsFilter {
1088                    search_term: Some(term.to_string()),
1089                    search_mode: SearchMode::Auto,
1090                    ..Default::default()
1091                };
1092                storage.find_commands(filter, path, &default_tuning).await.unwrap()
1093            }
1094        };
1095
1096        // Scenario 1: Basic text and description search
1097        let (commands, _) = run_search("list containers", UNRELATED_PATH).await;
1098        assert!(!commands.is_empty(), "Expected results for 'list containers'");
1099        assert_eq!(
1100            commands[0].cmd, "docker ps -a",
1101            "Expected 'docker ps -a' to be the top result for 'list containers'"
1102        );
1103
1104        // Scenario 2: Prefix and usage search
1105        let (commands, _) = run_search("git commit", PROJ_A_PATH).await;
1106        assert!(commands.len() >= 2, "Expected at least two results for 'git commit'");
1107        assert_eq!(
1108            commands[0].cmd, "git commit -m",
1109            "Expected 'git commit -m' to be the top result for 'git commit' due to usage"
1110        );
1111        assert_eq!(
1112            commands[1].cmd, "git commit -m '{{message}}'",
1113            "Expected template command to be second for 'git commit'"
1114        );
1115
1116        // Scenario 3: Template matching
1117        let (commands, _) = run_search("git commit -m 'my new feature'", PROJ_A_PATH).await;
1118        assert!(!commands.is_empty(), "Expected results for template match");
1119        assert_eq!(
1120            commands[0].cmd, "git commit -m '{{message}}'",
1121            "Expected template command to be the top result for a matching search term"
1122        );
1123
1124        // Scenario 4: Path relevance
1125        let (commands, _) = run_search("build", PROJ_A_API_PATH).await;
1126        assert!(!commands.is_empty(), "Expected results for 'build'");
1127        assert_eq!(
1128            commands[0].cmd, "npm run build:prod",
1129            "Expected 'npm run build:prod' to be top result for 'build' in its project path"
1130        );
1131
1132        // Scenario 5: Fuzzy search fallback
1133        let (commands, _) = run_search("gt sta", PROJ_A_PATH).await;
1134        assert!(!commands.is_empty(), "Expected results for fuzzy search 'gt sta'");
1135        assert_eq!(
1136            commands[0].cmd, "git status",
1137            "Expected 'git status' as top result for fuzzy search 'gt sta'"
1138        );
1139
1140        // Scenario 6: Specific description search with low usage
1141        let (commands, _) = run_search("get pod monitoring", UNRELATED_PATH).await;
1142        assert!(!commands.is_empty(), "Expected results for 'get pod monitoring'");
1143        assert_eq!(
1144            commands[0].cmd, "kubectl get pod -n monitoring my-specific-pod-12345",
1145            "Expected specific 'kubectl' command to be found"
1146        );
1147
1148        // Scenario 7: High usage in parent path
1149        let (commands, _) = run_search("status", PROJ_A_API_PATH).await;
1150        assert!(!commands.is_empty(), "Expected results for 'status'");
1151        assert_eq!(
1152            commands[0].cmd, "git status",
1153            "Expected 'git status' to be top due to high usage in parent path"
1154        );
1155    }
1156
1157    #[tokio::test]
1158    async fn test_find_commands_search_mode_auto_hastag_only() {
1159        let storage = setup_ranking_storage().await;
1160
1161        // This test forces fts2 and fts3 to be omited and the final query to contain fts1 only
1162        // If the query planner tries to inline it, it would fail because bm25 functión can't be used on that context
1163        let filter = SearchCommandsFilter {
1164            search_term: Some("#".to_string()),
1165            search_mode: SearchMode::Auto,
1166            ..Default::default()
1167        };
1168
1169        let res = storage
1170            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1171            .await;
1172        assert!(res.is_ok(), "Expected a success response, got: {res:?}")
1173    }
1174
1175    #[tokio::test]
1176    async fn test_find_commands_including_workspace() {
1177        let storage = setup_ranking_storage().await;
1178
1179        storage.setup_workspace_storage().await.unwrap();
1180        let commands_to_import = vec![
1181            Command {
1182                id: Uuid::now_v7(),
1183                cmd: "cmd1".to_string(),
1184                ..Default::default()
1185            },
1186            Command {
1187                id: Uuid::now_v7(),
1188                cmd: "cmd2".to_string(),
1189                ..Default::default()
1190            },
1191        ];
1192        let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1193        storage.import_commands(stream, None, false, true).await.unwrap();
1194
1195        let (commands, _) = storage
1196            .find_commands(
1197                SearchCommandsFilter::default(),
1198                "/some/path",
1199                &SearchCommandTuning::default(),
1200            )
1201            .await
1202            .unwrap();
1203        assert_eq!(commands.len(), 12, "Expected 12 commands including workspace");
1204    }
1205
1206    #[tokio::test]
1207    async fn test_find_commands_with_text_including_workspace() {
1208        let storage = setup_ranking_storage().await;
1209
1210        storage.setup_workspace_storage().await.unwrap();
1211        let commands_to_import = vec![Command {
1212            id: Uuid::now_v7(),
1213            cmd: "git checkout -b feature/{{name:kebab}}".to_string(),
1214            ..Default::default()
1215        }];
1216        let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1217        storage.import_commands(stream, None, false, true).await.unwrap();
1218
1219        let filter = SearchCommandsFilter {
1220            search_term: Some("git".to_string()),
1221            ..Default::default()
1222        };
1223
1224        let (commands, _) = storage
1225            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1226            .await
1227            .unwrap();
1228        assert_eq!(commands.len(), 6, "Expected 6 git commands including workspace");
1229        assert!(
1230            commands
1231                .iter()
1232                .any(|c| c.cmd == "git checkout -b feature/{{name:kebab}}")
1233        );
1234    }
1235
1236    #[tokio::test]
1237    async fn test_import_commands_no_overwrite() {
1238        let storage = SqliteStorage::new_in_memory().await.unwrap();
1239
1240        let commands_to_import = vec![
1241            Command {
1242                id: Uuid::now_v7(),
1243                cmd: "cmd1".to_string(),
1244                ..Default::default()
1245            },
1246            Command {
1247                id: Uuid::now_v7(),
1248                cmd: "cmd2".to_string(),
1249                ..Default::default()
1250            },
1251        ];
1252
1253        let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1254        let (inserted, skipped_or_updated) = storage.import_commands(stream, None, false, false).await.unwrap();
1255
1256        assert_eq!(inserted, 2, "Expected 2 commands inserted");
1257        assert_eq!(skipped_or_updated, 0, "Expected 0 commands skipped or updated");
1258
1259        // Import the same commands again with no overwrite
1260        let stream = iter(commands_to_import.into_iter().map(Ok));
1261        let (inserted, skipped_or_updated) = storage.import_commands(stream, None, false, false).await.unwrap();
1262
1263        assert_eq!(
1264            inserted, 0,
1265            "Expected 0 commands inserted on second import (no overwrite)"
1266        );
1267        assert_eq!(
1268            skipped_or_updated, 2,
1269            "Expected 2 commands skipped on second import (no overwrite)"
1270        );
1271    }
1272
1273    #[tokio::test]
1274    async fn test_import_commands_overwrite() {
1275        let storage = SqliteStorage::new_in_memory().await.unwrap();
1276
1277        let existing_cmd = Command {
1278            id: Uuid::now_v7(),
1279            cmd: "existing_cmd".to_string(),
1280            description: Some("original desc".to_string()),
1281            alias: Some("original_alias".to_string()),
1282            tags: Some(vec!["tag_a".to_string()]),
1283            ..Default::default()
1284        };
1285        storage.insert_command(existing_cmd.clone()).await.unwrap();
1286
1287        let new_cmd = Command {
1288            id: Uuid::now_v7(),
1289            cmd: "new_cmd".to_string(),
1290            ..Default::default()
1291        };
1292
1293        // Import a list containing the existing command (modified) and a new command
1294        let commands_to_import = vec![
1295            Command {
1296                id: Uuid::now_v7(),
1297                cmd: "existing_cmd".to_string(),
1298                description: Some("updated desc".to_string()),
1299                alias: None,
1300                tags: Some(vec!["tag_b".to_string()]),
1301                ..Default::default()
1302            },
1303            new_cmd.clone(),
1304        ];
1305
1306        let stream = iter(commands_to_import.into_iter().map(Ok));
1307        let (inserted, skipped_or_updated) = storage.import_commands(stream, None, true, false).await.unwrap();
1308
1309        assert_eq!(inserted, 1, "Expected 1 new command inserted");
1310        assert_eq!(skipped_or_updated, 1, "Expected 1 existing command updated");
1311
1312        // Verify the existing command was updated
1313        let filter = SearchCommandsFilter {
1314            search_term: Some("existing_cmd".to_string()),
1315            search_mode: SearchMode::Exact,
1316            ..Default::default()
1317        };
1318        let (found_commands, _) = storage
1319            .find_commands(filter, "/some/path", &SearchCommandTuning::default())
1320            .await
1321            .unwrap();
1322        assert_eq!(found_commands.len(), 1);
1323        let updated_cmd_in_db = &found_commands[0];
1324        assert_eq!(
1325            updated_cmd_in_db.description,
1326            Some("updated desc".to_string()),
1327            "Description should be updated"
1328        );
1329        assert_eq!(
1330            updated_cmd_in_db.alias,
1331            Some("original_alias".to_string()),
1332            "Alias should NOT be updated to NULL"
1333        );
1334        assert_eq!(
1335            updated_cmd_in_db.tags,
1336            Some(vec!["tag_b".to_string()]),
1337            "Tags should be updated"
1338        );
1339    }
1340
1341    #[tokio::test]
1342    async fn test_import_commands_with_filter() {
1343        let storage = SqliteStorage::new_in_memory().await.unwrap();
1344
1345        let commands_to_import = vec![
1346            Command {
1347                id: Uuid::now_v7(),
1348                cmd: "git commit".to_string(),
1349                ..Default::default()
1350            },
1351            Command {
1352                id: Uuid::now_v7(),
1353                cmd: "docker ps".to_string(),
1354                ..Default::default()
1355            },
1356            Command {
1357                id: Uuid::now_v7(),
1358                cmd: "git push".to_string(),
1359                ..Default::default()
1360            },
1361        ];
1362
1363        let filter = Some(Regex::new("^git").unwrap());
1364        let stream = iter(commands_to_import.into_iter().map(Ok));
1365        let (inserted, _) = storage.import_commands(stream, filter, false, false).await.unwrap();
1366
1367        assert_eq!(inserted, 2, "Expected 2 commands to be inserted");
1368
1369        let (all_commands, _) = storage
1370            .find_commands(
1371                SearchCommandsFilter::default(),
1372                "/some/path",
1373                &SearchCommandTuning::default(),
1374            )
1375            .await
1376            .unwrap();
1377        assert!(all_commands.iter().all(|c| c.cmd.starts_with("git")));
1378        assert!(!all_commands.iter().any(|c| c.cmd.starts_with("docker")));
1379    }
1380
1381    #[tokio::test]
1382    async fn test_import_workspace_commands() {
1383        let storage = SqliteStorage::new_in_memory().await.unwrap();
1384        storage.setup_workspace_storage().await.unwrap();
1385
1386        let commands_to_import = vec![
1387            Command {
1388                id: Uuid::now_v7(),
1389                cmd: "cmd1".to_string(),
1390                ..Default::default()
1391            },
1392            Command {
1393                id: Uuid::now_v7(),
1394                cmd: "cmd2".to_string(),
1395                ..Default::default()
1396            },
1397        ];
1398
1399        let stream = iter(commands_to_import.clone().into_iter().map(Ok));
1400        let (inserted, skipped_or_updated) = storage.import_commands(stream, None, false, true).await.unwrap();
1401
1402        assert_eq!(inserted, 2, "Expected 2 commands inserted");
1403        assert_eq!(skipped_or_updated, 0, "Expected 0 commands skipped or updated");
1404    }
1405
1406    #[tokio::test]
1407    async fn test_export_user_commands_no_filter() {
1408        let storage = setup_ranking_storage().await;
1409        let mut exported_commands = Vec::new();
1410        let mut stream = storage.export_user_commands(None).await;
1411        while let Some(Ok(cmd)) = stream.next().await {
1412            exported_commands.push(cmd);
1413        }
1414
1415        assert_eq!(exported_commands.len(), 7, "Expected 7 user commands to be exported");
1416    }
1417
1418    #[tokio::test]
1419    async fn test_export_user_commands_with_filter() {
1420        let storage = setup_ranking_storage().await;
1421        let filter = Regex::new(r"^git").unwrap(); // Commands starting with "git"
1422        let mut exported_commands = Vec::new();
1423        let mut stream = storage.export_user_commands(Some(filter)).await;
1424        while let Some(Ok(cmd)) = stream.next().await {
1425            exported_commands.push(cmd);
1426        }
1427
1428        assert_eq!(exported_commands.len(), 3, "Expected 3 git commands to be exported");
1429
1430        let exported_cmd_values: Vec<String> = exported_commands.into_iter().map(|c| c.cmd).collect();
1431        assert!(exported_cmd_values.contains(&"git status".to_string()));
1432        assert!(exported_cmd_values.contains(&"git checkout main".to_string()));
1433    }
1434
1435    #[tokio::test]
1436    async fn test_delete_tldr_commands() {
1437        let storage = SqliteStorage::new_in_memory().await.unwrap();
1438
1439        // Insert some tldr and non-tldr commands
1440        let tldr_cmd1 = Command {
1441            id: Uuid::now_v7(),
1442            category: "git".to_string(),
1443            source: SOURCE_TLDR.to_string(),
1444            cmd: "git status".to_string(),
1445            ..Default::default()
1446        };
1447        let tldr_cmd2 = Command {
1448            id: Uuid::now_v7(),
1449            category: "docker".to_string(),
1450            source: SOURCE_TLDR.to_string(),
1451            cmd: "docker ps".to_string(),
1452            ..Default::default()
1453        };
1454        let user_cmd = Command {
1455            id: Uuid::now_v7(),
1456            category: "git".to_string(),
1457            source: SOURCE_USER.to_string(),
1458            cmd: "git log".to_string(),
1459            ..Default::default()
1460        };
1461
1462        storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1463        storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1464        storage.insert_command(user_cmd.clone()).await.unwrap();
1465
1466        // Delete all tldr commands
1467        let removed = storage.delete_tldr_commands(None).await.unwrap();
1468        assert_eq!(removed, 2, "Should remove both tldr commands");
1469
1470        let (remaining, _) = storage
1471            .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1472            .await
1473            .unwrap();
1474        assert_eq!(remaining.len(), 1, "Only user command should remain");
1475        assert_eq!(remaining[0].cmd, user_cmd.cmd);
1476
1477        // Re-insert tldr commands for category-specific removal
1478        storage.insert_command(tldr_cmd1.clone()).await.unwrap();
1479        storage.insert_command(tldr_cmd2.clone()).await.unwrap();
1480
1481        // Remove only tldr commands in 'git' category
1482        let removed_git = storage.delete_tldr_commands(Some("git".to_string())).await.unwrap();
1483        assert_eq!(removed_git, 1, "Should remove one tldr command in 'git' category");
1484
1485        let (remaining, _) = storage
1486            .find_commands(SearchCommandsFilter::default(), "", &SearchCommandTuning::default())
1487            .await
1488            .unwrap();
1489        let remaining_cmds: Vec<_> = remaining.iter().map(|c| &c.cmd).collect();
1490        assert!(remaining_cmds.contains(&&tldr_cmd2.cmd));
1491        assert!(remaining_cmds.contains(&&user_cmd.cmd));
1492        assert!(!remaining_cmds.contains(&&tldr_cmd1.cmd));
1493    }
1494
1495    #[tokio::test]
1496    async fn test_insert_command() {
1497        let storage = SqliteStorage::new_in_memory().await.unwrap();
1498
1499        let mut cmd = Command {
1500            id: Uuid::now_v7(),
1501            category: "test".to_string(),
1502            cmd: "test_cmd".to_string(),
1503            description: Some("test desc".to_string()),
1504            tags: Some(vec!["tag1".to_string()]),
1505            ..Default::default()
1506        };
1507
1508        let mut inserted = storage.insert_command(cmd.clone()).await.unwrap();
1509        assert_eq!(inserted.cmd, cmd.cmd);
1510
1511        // Test duplicate id insert fails
1512        inserted.cmd = "other_cmd".to_string();
1513        match storage.insert_command(inserted).await {
1514            Err(InsertError::AlreadyExists) => (),
1515            _ => panic!("Expected AlreadyExists error on duplicate id"),
1516        }
1517
1518        // Test duplicate cmd insert fails
1519        cmd.id = Uuid::now_v7();
1520        match storage.insert_command(cmd).await {
1521            Err(InsertError::AlreadyExists) => (),
1522            _ => panic!("Expected AlreadyExists error on duplicate cmd"),
1523        }
1524    }
1525
1526    #[tokio::test]
1527    async fn test_update_command() {
1528        let storage = SqliteStorage::new_in_memory().await.unwrap();
1529
1530        let cmd = Command {
1531            id: Uuid::now_v7(),
1532            cmd: "original".to_string(),
1533            description: Some("desc".to_string()),
1534            ..Default::default()
1535        };
1536
1537        storage.insert_command(cmd.clone()).await.unwrap();
1538
1539        let mut updated = cmd.clone();
1540        updated.cmd = "updated".to_string();
1541        updated.description = Some("new desc".to_string());
1542
1543        let result = storage.update_command(updated.clone()).await.unwrap();
1544        assert_eq!(result.cmd, "updated");
1545        assert_eq!(result.description, Some("new desc".to_string()));
1546
1547        // Test update non-existent fails
1548        let mut non_existent = cmd;
1549        non_existent.id = Uuid::now_v7();
1550        match storage.update_command(non_existent).await {
1551            Err(_) => (),
1552            _ => panic!("Expected error when updating non-existent command"),
1553        }
1554
1555        // Test update to existing cmd fails
1556        let another_cmd = Command {
1557            id: Uuid::now_v7(),
1558            cmd: "another".to_string(),
1559            ..Default::default()
1560        };
1561        let mut result = storage.insert_command(another_cmd.clone()).await.unwrap();
1562        result.cmd = "updated".to_string();
1563        match storage.update_command(result).await {
1564            Err(UpdateError::AlreadyExists) => (),
1565            _ => panic!("Expected AlreadyExists error when updating to existing cmd"),
1566        }
1567    }
1568
1569    #[tokio::test]
1570    async fn test_increment_command_usage() {
1571        let storage = SqliteStorage::new_in_memory().await.unwrap();
1572
1573        // Setup the command
1574        let command = storage
1575            .setup_command(
1576                Command::new(CATEGORY_USER, SOURCE_USER, "gc command interfering"),
1577                [("/some/path", 100)],
1578            )
1579            .await;
1580
1581        // Insert
1582        let count = storage.increment_command_usage(command.id, "/path").await.unwrap();
1583        assert_eq!(count, 1);
1584
1585        // Update
1586        let count = storage.increment_command_usage(command.id, "/some/path").await.unwrap();
1587        assert_eq!(count, 101);
1588    }
1589
1590    #[tokio::test]
1591    async fn test_delete_command() {
1592        let storage = SqliteStorage::new_in_memory().await.unwrap();
1593
1594        let cmd = Command {
1595            id: Uuid::now_v7(),
1596            cmd: "to_delete".to_string(),
1597            ..Default::default()
1598        };
1599
1600        let cmd = storage.insert_command(cmd).await.unwrap();
1601        let res = storage.delete_command(cmd.id).await;
1602        assert!(res.is_ok());
1603
1604        // Test delete non-existent fails
1605        match storage.delete_command(cmd.id).await {
1606            Err(_) => (),
1607            _ => panic!("Expected error when deleting non-existent command"),
1608        }
1609    }
1610
1611    /// Helper to setup a storage instance with a comprehensive suite of commands for testing all scenarios.
1612    async fn setup_ranking_storage() -> SqliteStorage {
1613        let storage = SqliteStorage::new_in_memory().await.unwrap();
1614        storage
1615            .setup_command(
1616                Command::new(
1617                    CATEGORY_USER,
1618                    SOURCE_USER,
1619                    "kubectl get pod -n monitoring my-specific-pod-12345",
1620                )
1621                .with_description(Some(
1622                    "Get a very specific pod by its full name in the monitoring namespace".to_string(),
1623                ))
1624                .with_tags(Some(vec!["#k8s".to_string(), "#pod".to_string()])),
1625                [("/other/path", 1)],
1626            )
1627            .await;
1628        storage
1629            .setup_command(
1630                Command::new(CATEGORY_USER, SOURCE_USER, "git status")
1631                    .with_description(Some("Check the status of the git repository".to_string()))
1632                    .with_tags(Some(vec!["#git".to_string()])),
1633                [(PROJ_A_PATH, 50), (PROJ_B_PATH, 50), (UNRELATED_PATH, 100)],
1634            )
1635            .await;
1636        storage
1637            .setup_command(
1638                Command::new(CATEGORY_USER, SOURCE_USER, "npm run build:prod")
1639                    .with_description(Some("Build the project for production".to_string()))
1640                    .with_tags(Some(vec!["#npm".to_string(), "#build".to_string()])),
1641                [(PROJ_A_API_PATH, 25)],
1642            )
1643            .await;
1644        storage
1645            .setup_command(
1646                Command::new(CATEGORY_USER, SOURCE_USER, "container-image-build.sh")
1647                    .with_description(Some("A generic script to build a container image".to_string()))
1648                    .with_tags(Some(vec!["#docker".to_string(), "#build".to_string()])),
1649                [(UNRELATED_PATH, 35)],
1650            )
1651            .await;
1652        storage
1653            .setup_command(
1654                Command::new(CATEGORY_USER, SOURCE_USER, "git commit -m '{{message}}'")
1655                    .with_description(Some("Commit with a message".to_string()))
1656                    .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1657                [(PROJ_A_PATH, 10), (PROJ_B_PATH, 10)],
1658            )
1659            .await;
1660        storage
1661            .setup_command(
1662                Command::new(CATEGORY_USER, SOURCE_USER, "git checkout main")
1663                    .with_alias(Some("gco".to_string()))
1664                    .with_description(Some("Checkout the main branch".to_string()))
1665                    .with_tags(Some(vec!["#git".to_string()])),
1666                [(PROJ_A_PATH, 30), (PROJ_B_PATH, 30)],
1667            )
1668            .await;
1669        storage
1670            .setup_command(
1671                Command::new("git", SOURCE_TLDR, "git commit -m")
1672                    .with_alias(Some("gc".to_string()))
1673                    .with_description(Some("Commit changes".to_string()))
1674                    .with_tags(Some(vec!["#git".to_string(), "#commit".to_string()])),
1675                [(PROJ_A_PATH, 15)],
1676            )
1677            .await;
1678        storage
1679            .setup_command(
1680                Command::new("docker", SOURCE_TLDR, "docker ps -a")
1681                    .with_description(Some("List all containers".to_string()))
1682                    .with_tags(Some(vec!["#docker".to_string(), "#list".to_string()])),
1683                [(PROJ_A_PATH, 5), (PROJ_B_PATH, 5)],
1684            )
1685            .await;
1686        storage
1687            .setup_command(
1688                Command::new("git", SOURCE_TLDR, "git push")
1689                    .with_description(Some("Push changes".to_string()))
1690                    .with_tags(Some(vec!["#git".to_string(), "#push".to_string()])),
1691                [(PROJ_A_PATH, 20), (PROJ_B_PATH, 20)],
1692            )
1693            .await;
1694        storage
1695            .setup_command(
1696                Command::new(CATEGORY_USER, SOURCE_IMPORT, "ls -lha")
1697                    .with_description(Some("List files".to_string()))
1698                    .with_tags(Some(vec!["#unix".to_string(), "#list".to_string()])),
1699                [(PROJ_A_PATH, 100), (PROJ_B_PATH, 100), (UNRELATED_PATH, 100)],
1700            )
1701            .await;
1702
1703        storage
1704    }
1705
1706    impl SqliteStorage {
1707        /// A helper function to validate the SQLite version
1708        async fn check_sqlite_version(&self) {
1709            let version: String = self
1710                .client
1711                .conn_mut::<_, _, Report>(|conn| {
1712                    conn.query_row("SELECT sqlite_version()", [], |row| row.get(0))
1713                        .map_err(Into::into)
1714                })
1715                .await
1716                .unwrap();
1717            println!("Running with SQLite version: {version}");
1718        }
1719
1720        /// A helper function to make setting up test data cleaner.
1721        /// It inserts a command and then increments its usage.
1722        async fn setup_command(
1723            &self,
1724            command: Command,
1725            usage: impl IntoIterator<Item = (&str, i32)> + Send + 'static,
1726        ) -> Command {
1727            let command = self.insert_command(command).await.unwrap();
1728            self.client
1729                .conn_mut::<_, _, Report>(move |conn| {
1730                    for (path, usage_count) in usage {
1731                        conn.execute(
1732                            r#"
1733                        INSERT INTO command_usage (command_id, path, usage_count)
1734                        VALUES (?1, ?2, ?3)
1735                        ON CONFLICT(command_id, path) DO UPDATE SET
1736                            usage_count = excluded.usage_count"#,
1737                            (&command.id, path, usage_count),
1738                        )?;
1739                    }
1740                    Ok(command)
1741                })
1742                .await
1743                .unwrap()
1744        }
1745    }
1746}