intelli_shell/storage/
variable.rs

1use std::collections::BTreeMap;
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{ErrorCode, Row, types::Value};
5use tracing::instrument;
6
7use super::SqliteStorage;
8use crate::{
9    config::SearchVariableTuning,
10    errors::{Result, UserFacingError},
11    model::VariableValue,
12    utils::{flatten_str, flatten_variable},
13};
14
15impl SqliteStorage {
16    /// Finds variable values for a given root command, variable name and context.
17    ///
18    /// The `variable_name` input can be a single term or multiple terms delimited by `|`.
19    /// The method searches for values matching any of these individual (flattened) terms, as well as the (flattened)
20    /// composite variable itself.
21    ///
22    /// Results are returned for the original input variable, even if they don't explicitly exists, ordered to
23    /// prioritize overall relevance.
24    #[instrument(skip_all)]
25    pub async fn find_variable_values(
26        &self,
27        root_cmd: impl AsRef<str>,
28        variable_name: impl AsRef<str>,
29        working_path: impl Into<String>,
30        context: &BTreeMap<String, String>,
31        tuning: &SearchVariableTuning,
32    ) -> Result<Vec<VariableValue>> {
33        // Prepare flattened inputs
34        // If the variable contains any `|` char, we will consider it as a list of variables to find values for
35        let flat_root_cmd = flatten_str(root_cmd);
36        let flat_variable = flatten_variable(variable_name);
37        let mut flat_variable_values = flat_variable.split('|').map(str::to_owned).collect::<Vec<_>>();
38        // Including values for the entire variable itself
39        flat_variable_values.push(flat_variable.clone());
40        flat_variable_values.dedup();
41
42        // Prepare the query params:
43        // -- ?1~5: tuning params
44        // -- ?7: flat_root_cmd
45        // -- ?8: original flat_variable
46        // -- ?9: working_path
47        // -- ?10: context json
48        // -- ?n: all flat_variable placeholders
49        let mut all_sql_params = Vec::with_capacity(2 + flat_variable_values.len());
50        all_sql_params.push(Value::from(tuning.path.exact));
51        all_sql_params.push(Value::from(tuning.path.ancestor));
52        all_sql_params.push(Value::from(tuning.path.descendant));
53        all_sql_params.push(Value::from(tuning.path.unrelated));
54        all_sql_params.push(Value::from(tuning.path.points));
55        all_sql_params.push(Value::from(tuning.context.points));
56        all_sql_params.push(Value::from(flat_root_cmd));
57        all_sql_params.push(Value::from(flat_variable));
58        all_sql_params.push(Value::from(working_path.into()));
59        all_sql_params.push(Value::from(serde_json::to_string(context)?));
60        let prev_params_len = all_sql_params.len();
61        let mut in_placeholders = Vec::new();
62        for (idx, variable_param) in flat_variable_values.into_iter().enumerate() {
63            all_sql_params.push(Value::from(variable_param));
64            in_placeholders.push(format!("?{}", idx + prev_params_len + 1));
65        }
66        let in_placeholders = in_placeholders.join(",");
67
68        // Construct the SQL query
69        let query = format!(
70            r#"WITH
71            -- Pre-calculate the total number of variables in the query context
72            context_info AS (
73                SELECT MAX(CAST(total AS REAL)) AS total_variables
74                FROM (
75                    SELECT COUNT(*) as total FROM json_each(?10)
76                    UNION ALL SELECT 0
77                )
78            ),
79            -- Calculate the individual relevance score for each unique usage record
80            value_scores AS (
81                SELECT
82                    v.value,
83                    u.usage_count,
84                    CASE
85                        -- Exact path match
86                        WHEN u.path = ?9 THEN ?1
87                        -- Ascendant path match (parent)
88                        WHEN ?9 LIKE u.path || '/%' THEN ?2
89                        -- Descendant path match (child)
90                        WHEN u.path LIKE ?9 || '/%' THEN ?3
91                        -- Other/unrelated path
92                        ELSE ?4
93                    END AS path_relevance,
94                    (
95                        SELECT
96                            CASE
97                                WHEN ci.total_variables > 0 THEN (CAST(COUNT(*) AS REAL) / ci.total_variables)
98                                ELSE 0
99                            END
100                        FROM json_each(?10) AS query_ctx
101                        CROSS JOIN context_info ci
102                        WHERE json_extract(u.context_json, '$."' || query_ctx.key || '"') = query_ctx.value
103                    ) AS context_relevance
104                FROM variable_value v
105                JOIN variable_value_usage u ON v.id = u.value_id
106                WHERE v.flat_root_cmd = ?7 AND v.flat_variable IN ({in_placeholders})
107            ),
108            -- Group by values to find the best relevance score and the total usage count
109            agg_values AS (
110                SELECT
111                    vs.value,
112                    MAX(
113                        (vs.path_relevance * ?5)
114                        + (vs.context_relevance * ?6)
115                    ) as relevance_score,
116                    SUM(vs.usage_count) as total_usage
117                FROM value_scores vs
118                GROUP BY vs.value
119            )
120            -- Calculate the final score and join back to find the ID
121            SELECT
122                v.id,
123                ?7 AS flat_root_cmd,
124                ?8 AS flat_variable,
125                a.value,
126                (a.relevance_score + log(a.total_usage + 1)) AS final_score
127            FROM agg_values a
128            LEFT JOIN variable_value v ON v.flat_root_cmd = ?7 AND v.flat_variable = ?8 AND v.value = a.value
129            ORDER BY final_score DESC;"#
130        );
131
132        // Execute the query
133        self.client
134            .conn(move |conn| {
135                tracing::trace!("Querying variable values:\n{query}");
136                tracing::trace!("With parameters:\n{all_sql_params:?}");
137                Ok(conn
138                    .prepare(&query)?
139                    .query(rusqlite::params_from_iter(all_sql_params.iter()))?
140                    .and_then(|r| VariableValue::try_from(r))
141                    .collect::<Result<Vec<_>, _>>()?)
142            })
143            .await
144    }
145
146    /// Inserts a new variable value into the database if it doesn't already exist
147    #[instrument(skip_all)]
148    pub async fn insert_variable_value(&self, mut value: VariableValue) -> Result<VariableValue> {
149        // Check if the value already has an ID
150        if value.id.is_some() {
151            return Err(eyre!("ID should not be set when inserting a new value").into());
152        };
153
154        // Insert the value into the database
155        self.client
156            .conn_mut(move |conn| {
157                let res = conn.query_row(
158                    r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value) 
159                    VALUES (?1, ?2, ?3)
160                    RETURNING id"#,
161                    (&value.flat_root_cmd, &value.flat_variable, &value.value),
162                    |r| r.get(0),
163                );
164                match res {
165                    Ok(id) => {
166                        value.id = Some(id);
167                        Ok(value)
168                    }
169                    Err(err) => match err.sqlite_error_code() {
170                        Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
171                        _ => Err(Report::from(err).into()),
172                    },
173                }
174            })
175            .await
176    }
177
178    /// Updates an existing variable value
179    #[instrument(skip_all)]
180    pub async fn update_variable_value(&self, value: VariableValue) -> Result<VariableValue> {
181        // Check if the value doesn't have an ID to update
182        let Some(value_id) = value.id else {
183            return Err(eyre!("ID must be set when updating a variable value").into());
184        };
185
186        // Update the value in the database
187        self.client
188            .conn_mut(move |conn| {
189                let res = conn.execute(
190                    r#"
191                    UPDATE variable_value 
192                    SET flat_root_cmd = ?2, 
193                        flat_variable = ?3, 
194                        value = ?4
195                    WHERE rowid = ?1
196                    "#,
197                    (&value_id, &value.flat_root_cmd, &value.flat_variable, &value.value),
198                );
199                match res {
200                    Ok(0) => Err(eyre!("Variable value not found: {value_id}")
201                        .wrap_err("Couldn't update a variable value")
202                        .into()),
203                    Ok(_) => Ok(value),
204                    Err(err) => match err.sqlite_error_code() {
205                        Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
206                        _ => Err(Report::from(err).into()),
207                    },
208                }
209            })
210            .await
211    }
212
213    /// Increments the usage of a variable value
214    #[instrument(skip_all)]
215    pub async fn increment_variable_value_usage(
216        &self,
217        value_id: i32,
218        path: impl AsRef<str> + Send + 'static,
219        context: &BTreeMap<String, String>,
220    ) -> Result<i32> {
221        let context = serde_json::to_string(context)?;
222        self.client
223            .conn_mut(move |conn| {
224                Ok(conn.query_row(
225                    r#"
226                    INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
227                    VALUES (?1, ?2, ?3, 1)
228                    ON CONFLICT(value_id, path, context_json) DO UPDATE SET
229                        usage_count = usage_count + 1
230                    RETURNING usage_count;"#,
231                    (&value_id, &path.as_ref(), &context),
232                    |r| r.get(0),
233                )?)
234            })
235            .await
236    }
237
238    /// Deletes an existing variable value from the database.
239    ///
240    /// If the value to be deleted does not exist, an error will be returned.
241    #[instrument(skip_all)]
242    pub async fn delete_variable_value(&self, value_id: i32) -> Result<()> {
243        self.client
244            .conn_mut(move |conn| {
245                let res = conn.execute("DELETE FROM variable_value WHERE rowid = ?1", (&value_id,));
246                match res {
247                    Ok(0) => Err(eyre!("Variable value not found: {value_id}").into()),
248                    Ok(_) => Ok(()),
249                    Err(err) => Err(Report::from(err).into()),
250                }
251            })
252            .await
253    }
254}
255
256impl<'a> TryFrom<&'a Row<'a>> for VariableValue {
257    type Error = rusqlite::Error;
258
259    fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
260        Ok(Self {
261            id: row.get(0)?,
262            flat_root_cmd: row.get(1)?,
263            flat_variable: row.get(2)?,
264            value: row.get(3)?,
265        })
266    }
267}
268
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::errors::AppError;

    #[tokio::test]
    async fn test_find_variable_values_empty() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let values = storage
            .find_variable_values(
                "cmd",
                "variable",
                "/some/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();
        assert!(values.is_empty());
    }

    #[tokio::test]
    async fn test_find_variable_values_path_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "docker";
        let variable = "image";
        let current_path = "/home/user/project-a/api";

        // Setup values with different path relevance, but identical usage and context
        storage
            .setup_variable_value(root, variable, "unrelated-path", "/var/www", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "child-path", "/home/user/project-a/api/db", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "parent-path", "/home/user/project-a", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "exact-path", current_path, [], 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert the order based on path relevance
        assert_eq!(matches.len(), 4);
        assert_eq!(matches[0].value, "exact-path");
        assert_eq!(matches[1].value, "parent-path");
        assert_eq!(matches[2].value, "child-path");
        assert_eq!(matches[3].value, "unrelated-path");
    }

    #[tokio::test]
    async fn test_find_variable_values_context_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable = "port";
        let current_path = "/home/user/k8s";
        let query_context = [("namespace", "prod"), ("service", "api-gateway")];

        // Setup values with different context relevance, but identical paths and usage
        storage
            .setup_variable_value(root, variable, "no-context", current_path, [], 1)
            .await;
        storage
            .setup_variable_value(
                root,
                variable,
                "partial-context",
                current_path,
                [("namespace", "prod")],
                1,
            )
            .await;
        storage
            .setup_variable_value(root, variable, "full-context", current_path, query_context, 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::from_iter(query_context.into_iter().map(|(k, v)| (k.to_owned(), v.to_owned()))),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert the order based on context relevance
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].value, "full-context");
        assert_eq!(matches[1].value, "partial-context");
        assert_eq!(matches[2].value, "no-context");
    }

    #[tokio::test]
    async fn test_find_variable_values_usage_count_is_tiebreaker_only() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "git";
        let variable = "branch";
        let current_path = "/home/user/project";

        // Setup two values with identical path/context, but different usage
        storage
            .setup_variable_value(root, variable, "feature-a", current_path, [], 5)
            .await;
        storage
            .setup_variable_value(root, variable, "feature-b", current_path, [], 50)
            .await;
        // A third value with worse path relevance but massive usage
        storage
            .setup_variable_value(root, variable, "release-1.0", "/other/path", [], 9999)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert that usage count correctly breaks the tie, but doesn't override relevance
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].value, "feature-b");
        assert_eq!(matches[1].value, "feature-a");
        assert_eq!(matches[2].value, "release-1.0");
    }

    #[tokio::test]
    async fn test_find_variable_values_aggregates_from_multiple_variables() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable_composite = "pod|service";

        // Setup values for the individual variables
        storage
            .setup_variable_value(root, "pod", "api-pod-123", "/path", [], 4)
            .await;
        storage
            .setup_variable_value(root, "service", "api-service", "/path", [], 5)
            .await;
        // Setup a value that also exists for the composite variable
        let sug_composite = storage
            .setup_variable_value(root, variable_composite, "api-pod-123", "/path", [], 4)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable_composite,
                "/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].value, "api-pod-123");
        assert_eq!(matches[0].id, sug_composite.id);
        assert_eq!(matches[1].value, "api-service");
        assert!(matches[1].id.is_none());
    }

    #[tokio::test]
    async fn test_insert_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable", "value");

        let inserted_sug = storage.insert_variable_value(sug.clone()).await.unwrap();
        assert_eq!(inserted_sug.value, sug.value);

        // Try inserting the same value again
        match storage.insert_variable_value(sug.clone()).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_update_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug1 = VariableValue::new("cmd", "variable", "value_orig");

        // Insert initial value
        let mut var1 = storage.insert_variable_value(sug1).await.unwrap();

        // Test successful update
        var1.value = "value_updated".to_string();
        let res = storage.update_variable_value(var1.clone()).await;
        assert!(res.is_ok(), "Expected successful update, got {res:?}");
        let sug1 = res.unwrap();
        assert_eq!(sug1.value, "value_updated");

        // Test update non-existent value (wrong ID)
        let mut non_existent_sug = sug1.clone();
        non_existent_sug.id = Some(999);
        match storage.update_variable_value(non_existent_sug).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }

        // Test update causing constraint violation
        let var2 = VariableValue::new("cmd", "variable", "value_other");
        let mut sug2 = storage.insert_variable_value(var2).await.unwrap();
        sug2.value = "value_updated".to_string();
        match storage.update_variable_value(sug2).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error for constraint violation, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_increment_variable_value_usage() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        // Setup the value
        let val = storage
            .insert_variable_value(VariableValue::new("root", "variable", "value"))
            .await
            .unwrap();
        let val_id = val.id.unwrap();

        // Insert
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 1);

        // Update
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_delete_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable_del", "value_to_delete");

        // Insert the value
        let sug = storage.insert_variable_value(sug).await.unwrap();
        let id_to_delete = sug.id.unwrap();

        // Test successful deletion
        let res = storage.delete_variable_value(id_to_delete).await;
        assert!(res.is_ok(), "Expected successful deletion, got {res:?}");

        // Test deleting a non-existent value
        match storage.delete_variable_value(id_to_delete).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }
    }

    impl SqliteStorage {
        /// A helper function to make setting up test data cleaner.
        ///
        /// It inserts a variable value if it doesn't exist yet and then sets its usage count for the given path and
        /// context to `usage_count`, overwriting any previous count for that same path and context.
        async fn setup_variable_value(
            &self,
            root: &'static str,
            variable: &'static str,
            value: &'static str,
            path: &'static str,
            context: impl IntoIterator<Item = (&str, &str)>,
            usage_count: i32,
        ) -> VariableValue {
            let context = serde_json::to_string(&BTreeMap::<String, String>::from_iter(
                context.into_iter().map(|(k, v)| (k.to_string(), v.to_string())),
            ))
            .unwrap();

            self.client
                .conn_mut(move |conn| {
                    // The no-op update on conflict (instead of `DO NOTHING`) makes `RETURNING` also yield the row when
                    // the value already exists
                    let sug = conn.query_row(
                        r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
                        VALUES (?1, ?2, ?3)
                        ON CONFLICT (flat_root_cmd, flat_variable, value) DO UPDATE SET
                            value = excluded.value
                        RETURNING id, flat_root_cmd, flat_variable, value"#,
                        (root, variable, value),
                        |r| VariableValue::try_from(r),
                    )?;
                    conn.execute(
                        r#"INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
                        VALUES (?1, ?2, ?3, ?4)
                        ON CONFLICT(value_id, path, context_json) DO UPDATE SET
                            usage_count = excluded.usage_count;
                        "#,
                        (&sug.id, path, &context, usage_count),
                    )?;
                    Ok(sug)
                })
                .await
                .unwrap()
        }
    }
}