intelli_shell/storage/variable.rs

1use std::collections::BTreeMap;
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{ErrorCode, Row, types::Value};
5use tracing::instrument;
6
7use super::SqliteStorage;
8use crate::{
9    config::SearchVariableTuning,
10    errors::{Result, UserFacingError},
11    model::VariableValue,
12};
13
14impl SqliteStorage {
15    /// Finds variable values for a given root command, variable and context.
16    ///
17    /// The method searches for values matching any of these individual `flat_variable_names` terms, as well as the
18    /// `flat_variable_name` composite variable itself.
19    ///
20    /// Results are returned for the original input variable, even if they don't explicitly exists, ordered to
21    /// prioritize overall relevance.
22    #[instrument(skip_all)]
23    pub async fn find_variable_values(
24        &self,
25        flat_root_cmd: impl Into<String>,
26        flat_variable_name: impl Into<String>,
27        mut flat_variable_names: Vec<String>,
28        working_path: impl Into<String>,
29        context: &BTreeMap<String, String>,
30        tuning: &SearchVariableTuning,
31    ) -> Result<Vec<(VariableValue, f64)>> {
32        // Also search for values of the composite variable name itself
33        let flat_variable_name = flat_variable_name.into();
34        flat_variable_names.push(flat_variable_name.clone());
35        flat_variable_names.dedup();
36
37        // Prepare the query params:
38        // -- ?1~5: tuning params
39        // -- ?7: flat_root_cmd
40        // -- ?8: flat_name of the variable
41        // -- ?9: working_path
42        // -- ?10: context json
43        // -- ?n: all variable flat_names placeholders
44        let mut all_sql_params = Vec::with_capacity(10 + flat_variable_names.len());
45        all_sql_params.push(Value::from(tuning.path.exact));
46        all_sql_params.push(Value::from(tuning.path.ancestor));
47        all_sql_params.push(Value::from(tuning.path.descendant));
48        all_sql_params.push(Value::from(tuning.path.unrelated));
49        all_sql_params.push(Value::from(tuning.path.points));
50        all_sql_params.push(Value::from(tuning.context.points));
51        all_sql_params.push(Value::from(flat_root_cmd.into()));
52        all_sql_params.push(Value::from(flat_variable_name));
53        all_sql_params.push(Value::from(working_path.into()));
54        all_sql_params.push(Value::from(serde_json::to_string(context)?));
55        let prev_params_len = all_sql_params.len();
56        let mut in_placeholders = Vec::new();
57        for (idx, flat_name) in flat_variable_names.into_iter().enumerate() {
58            all_sql_params.push(Value::from(flat_name));
59            in_placeholders.push(format!("?{}", idx + prev_params_len + 1));
60        }
61        let in_placeholders = in_placeholders.join(",");
62
63        // Construct the SQL query
64        let query = format!(
65            r#"WITH
66            -- Pre-calculate the total number of variables in the query context
67            context_info AS (
68                SELECT MAX(CAST(total AS REAL)) AS total_variables
69                FROM (
70                    SELECT COUNT(*) as total FROM json_each(?10)
71                    UNION ALL SELECT 0
72                )
73            ),
74            -- Calculate the individual relevance score for each unique usage record
75            value_scores AS (
76                SELECT
77                    v.value,
78                    u.usage_count,
79                    CASE
80                        -- Exact path match
81                        WHEN u.path = ?9 THEN ?1
82                        -- Ascendant path match (parent)
83                        WHEN ?9 LIKE u.path || '/%' THEN ?2
84                        -- Descendant path match (child)
85                        WHEN u.path LIKE ?9 || '/%' THEN ?3
86                        -- Other/unrelated path
87                        ELSE ?4
88                    END AS path_relevance,
89                    (
90                        SELECT
91                            CASE
92                                WHEN ci.total_variables > 0 THEN (CAST(COUNT(*) AS REAL) / ci.total_variables)
93                                ELSE 0
94                            END
95                        FROM json_each(?10) AS query_ctx
96                        CROSS JOIN context_info ci
97                        WHERE json_extract(u.context_json, '$."' || query_ctx.key || '"') = query_ctx.value
98                    ) AS context_relevance
99                FROM variable_value v
100                JOIN variable_value_usage u ON v.id = u.value_id
101                WHERE v.flat_root_cmd = ?7 AND v.flat_variable IN ({in_placeholders})
102            ),
103            -- Group by values to find the best relevance score and the total usage count
104            agg_values AS (
105                SELECT
106                    vs.value,
107                    MAX(
108                        (vs.path_relevance * ?5)
109                        + (vs.context_relevance * ?6)
110                    ) as relevance_score,
111                    SUM(vs.usage_count) as total_usage
112                FROM value_scores vs
113                GROUP BY vs.value
114            )
115            -- Calculate the final score and join back to find the ID
116            SELECT
117                v.id,
118                ?7 AS flat_root_cmd,
119                ?8 AS flat_variable,
120                a.value,
121                (a.relevance_score + log(a.total_usage + 1)) AS final_score
122            FROM agg_values a
123            LEFT JOIN variable_value v ON v.flat_root_cmd = ?7 AND v.flat_variable = ?8 AND v.value = a.value
124            ORDER BY final_score DESC;"#
125        );
126
127        // Execute the query
128        self.client
129            .conn(move |conn| {
130                tracing::trace!("Querying variable values:\n{query}");
131                tracing::trace!("With parameters:\n{all_sql_params:?}");
132                Ok(conn
133                    .prepare(&query)?
134                    .query_map(rusqlite::params_from_iter(all_sql_params.iter()), |r| {
135                        Ok((VariableValue::try_from(r)?, r.get(4)?))
136                    })?
137                    .collect::<Result<Vec<_>, _>>()?)
138            })
139            .await
140    }
141
142    /// Inserts a new variable value into the database if it doesn't already exist
143    #[instrument(skip_all)]
144    pub async fn insert_variable_value(&self, mut value: VariableValue) -> Result<VariableValue> {
145        // Check if the value already has an ID
146        if value.id.is_some() {
147            return Err(eyre!("ID should not be set when inserting a new value").into());
148        };
149
150        // Insert the value into the database
151        self.client
152            .conn_mut(move |conn| {
153                let query = r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value) 
154                    VALUES (?1, ?2, ?3)
155                    RETURNING id"#;
156                tracing::trace!("Inserting a variable value: {query}");
157                let res = conn.query_row(query, (&value.flat_root_cmd, &value.flat_variable, &value.value), |r| {
158                    r.get(0)
159                });
160                match res {
161                    Ok(id) => {
162                        value.id = Some(id);
163                        Ok(value)
164                    }
165                    Err(err) => match err.sqlite_error_code() {
166                        Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
167                        _ => Err(Report::from(err).into()),
168                    },
169                }
170            })
171            .await
172    }
173
174    /// Updates an existing variable value
175    #[instrument(skip_all)]
176    pub async fn update_variable_value(&self, value: VariableValue) -> Result<VariableValue> {
177        // Check if the value doesn't have an ID to update
178        let Some(value_id) = value.id else {
179            return Err(eyre!("ID must be set when updating a variable value").into());
180        };
181
182        // Update the value in the database
183        self.client
184            .conn_mut(move |conn| {
185                let query = r#"
186                    UPDATE variable_value 
187                    SET flat_root_cmd = ?2, 
188                        flat_variable = ?3, 
189                        value = ?4
190                    WHERE rowid = ?1
191                    "#;
192                tracing::trace!("Updating a variable value: {query}");
193                let res = conn.execute(
194                    query,
195                    (&value_id, &value.flat_root_cmd, &value.flat_variable, &value.value),
196                );
197                match res {
198                    Ok(0) => Err(eyre!("Variable value not found: {value_id}")
199                        .wrap_err("Couldn't update a variable value")
200                        .into()),
201                    Ok(_) => Ok(value),
202                    Err(err) => match err.sqlite_error_code() {
203                        Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
204                        _ => Err(Report::from(err).into()),
205                    },
206                }
207            })
208            .await
209    }
210
211    /// Increments the usage of a variable value
212    #[instrument(skip_all)]
213    pub async fn increment_variable_value_usage(
214        &self,
215        value_id: i32,
216        path: impl AsRef<str> + Send + 'static,
217        context: &BTreeMap<String, String>,
218    ) -> Result<i32> {
219        let context = serde_json::to_string(context)?;
220        self.client
221            .conn_mut(move |conn| {
222                let query = r#"
223                    INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
224                    VALUES (?1, ?2, ?3, 1)
225                    ON CONFLICT(value_id, path, context_json) DO UPDATE SET
226                        usage_count = usage_count + 1
227                    RETURNING usage_count;"#;
228                tracing::trace!("Incrementing variable value usage: {query}");
229                Ok(conn.query_row(query, (&value_id, &path.as_ref(), &context), |r| r.get(0))?)
230            })
231            .await
232    }
233
234    /// Deletes an existing variable value from the database.
235    ///
236    /// If the value to be deleted does not exist, an error will be returned.
237    #[instrument(skip_all)]
238    pub async fn delete_variable_value(&self, value_id: i32) -> Result<()> {
239        self.client
240            .conn_mut(move |conn| {
241                let query = "DELETE FROM variable_value WHERE rowid = ?1";
242                tracing::trace!("Deleting a variable value: {query}");
243                let res = conn.execute(query, (&value_id,));
244                match res {
245                    Ok(0) => Err(eyre!("Variable value not found: {value_id}").into()),
246                    Ok(_) => Ok(()),
247                    Err(err) => Err(Report::from(err).into()),
248                }
249            })
250            .await
251    }
252}
253
254impl<'a> TryFrom<&'a Row<'a>> for VariableValue {
255    type Error = rusqlite::Error;
256
257    fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
258        Ok(Self {
259            id: row.get(0)?,
260            flat_root_cmd: row.get(1)?,
261            flat_variable: row.get(2)?,
262            value: row.get(3)?,
263        })
264    }
265}
266
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::errors::AppError;

    #[tokio::test]
    async fn test_find_variable_values_empty() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let values = storage
            .find_variable_values(
                "cmd",
                "variable",
                Vec::new(),
                "/some/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();
        assert!(values.is_empty());
    }

    #[tokio::test]
    async fn test_find_variable_values_path_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "docker";
        let variable = "image";
        let current_path = "/home/user/project-a/api";

        // Setup values with different path relevance, but identical usage and context
        storage
            .setup_variable_value(root, variable, "unrelated-path", "/var/www", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "child-path", "/home/user/project-a/api/db", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "parent-path", "/home/user/project-a", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "exact-path", current_path, [], 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                Vec::new(),
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert the order based on path relevance
        assert_eq!(matches.len(), 4);
        assert_eq!(matches[0].0.value, "exact-path");
        assert_eq!(matches[1].0.value, "parent-path");
        assert_eq!(matches[2].0.value, "child-path");
        assert_eq!(matches[3].0.value, "unrelated-path");
    }

    #[tokio::test]
    async fn test_find_variable_values_context_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable = "port";
        let current_path = "/home/user/k8s";
        let query_context = [("namespace", "prod"), ("service", "api-gateway")];

        // Setup values with different context relevance, but identical paths and usage
        storage
            .setup_variable_value(root, variable, "no-context", current_path, [], 1)
            .await;
        storage
            .setup_variable_value(
                root,
                variable,
                "partial-context",
                current_path,
                [("namespace", "prod")],
                1,
            )
            .await;
        storage
            .setup_variable_value(root, variable, "full-context", current_path, query_context, 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                Vec::new(),
                current_path,
                &BTreeMap::from_iter(query_context.into_iter().map(|(k, v)| (k.to_owned(), v.to_owned()))),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert the order based on context relevance
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].0.value, "full-context");
        assert_eq!(matches[1].0.value, "partial-context");
        assert_eq!(matches[2].0.value, "no-context");
    }

    #[tokio::test]
    async fn test_find_variable_values_usage_count_is_tiebreaker_only() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "git";
        let variable = "branch";
        let current_path = "/home/user/project";

        // Setup two values with identical path/context, but different usage
        storage
            .setup_variable_value(root, variable, "feature-a", current_path, [], 5)
            .await;
        storage
            .setup_variable_value(root, variable, "feature-b", current_path, [], 50)
            .await;
        // A third value with worse path relevance but massive usage
        storage
            .setup_variable_value(root, variable, "release-1.0", "/other/path", [], 9999)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                Vec::new(),
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert that usage count correctly breaks the tie, but doesn't override relevance
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].0.value, "feature-b");
        assert_eq!(matches[1].0.value, "feature-a");
        assert_eq!(matches[2].0.value, "release-1.0");
    }

    #[tokio::test]
    async fn test_find_variable_values_aggregates_from_multiple_variables() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable_composite = "pod|service";
        let variable_composite_names = variable_composite.split("|").map(String::from).collect::<Vec<_>>();

        // Setup values for the individual variables
        storage
            .setup_variable_value(root, "pod", "api-pod-123", "/path", [], 4)
            .await;
        storage
            .setup_variable_value(root, "service", "api-service", "/path", [], 5)
            .await;
        // Setup a value that also exists for the composite variable
        let sug_composite = storage
            .setup_variable_value(root, variable_composite, "api-pod-123", "/path", [], 4)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable_composite,
                variable_composite_names,
                "/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].0.value, "api-pod-123");
        assert_eq!(matches[0].0.id, sug_composite.id);
        assert_eq!(matches[1].0.value, "api-service");
        assert!(matches[1].0.id.is_none());
    }

    #[tokio::test]
    async fn test_insert_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable", "value");

        let inserted_sug = storage.insert_variable_value(sug.clone()).await.unwrap();
        assert_eq!(inserted_sug.value, sug.value);

        // Try inserting the same value again
        match storage.insert_variable_value(sug.clone()).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_insert_variable_value_with_id_fails() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let mut sug = VariableValue::new("cmd", "variable", "value");
        sug.id = Some(1);

        // Values that already carry an ID must be rejected before touching the database
        let res = storage.insert_variable_value(sug).await;
        assert!(res.is_err(), "Expected error when inserting a value with an ID, got {res:?}");
    }

    #[tokio::test]
    async fn test_update_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug1 = VariableValue::new("cmd", "variable", "value_orig");

        // Insert initial value
        let mut var1 = storage.insert_variable_value(sug1).await.unwrap();

        // Test successful update
        var1.value = "value_updated".to_string();
        let res = storage.update_variable_value(var1.clone()).await;
        assert!(res.is_ok(), "Expected successful update, got {res:?}");
        let sug1 = res.unwrap();
        assert_eq!(sug1.value, "value_updated");

        // Test update non-existent value (wrong ID)
        let mut non_existent_sug = sug1.clone();
        non_existent_sug.id = Some(999);
        match storage.update_variable_value(non_existent_sug).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }

        // Test update causing constraint violation
        let var2 = VariableValue::new("cmd", "variable", "value_other");
        let mut sug2 = storage.insert_variable_value(var2).await.unwrap();
        sug2.value = "value_updated".to_string();
        match storage.update_variable_value(sug2).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error for constraint violation, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_update_variable_value_without_id_fails() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable", "value");

        // A value that was never inserted has no ID to update
        let res = storage.update_variable_value(sug).await;
        assert!(res.is_err(), "Expected error when updating a value without an ID, got {res:?}");
    }

    #[tokio::test]
    async fn test_increment_variable_value_usage() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        // Setup the value
        let val = storage
            .insert_variable_value(VariableValue::new("root", "variable", "value"))
            .await
            .unwrap();
        let val_id = val.id.unwrap();

        // Insert
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 1);

        // Update
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_delete_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable_del", "value_to_delete");

        // Insert the value
        let sug = storage.insert_variable_value(sug).await.unwrap();
        let id_to_delete = sug.id.unwrap();

        // Test successful deletion
        let res = storage.delete_variable_value(id_to_delete).await;
        assert!(res.is_ok(), "Expected successful deletion, got {res:?}");

        // Test deleting a non-existent value
        match storage.delete_variable_value(id_to_delete).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }
    }

    impl SqliteStorage {
        /// A helper function to make setting up test data cleaner.
        /// It inserts a variable value if it doesn't exist and then sets the usage count of its
        /// `(path, context)` usage record to `usage_count`.
        async fn setup_variable_value(
            &self,
            root: &'static str,
            variable: &'static str,
            value: &'static str,
            path: &'static str,
            context: impl IntoIterator<Item = (&str, &str)>,
            usage_count: i32,
        ) -> VariableValue {
            let context = serde_json::to_string(&BTreeMap::<String, String>::from_iter(
                context.into_iter().map(|(k, v)| (k.to_string(), v.to_string())),
            ))
            .unwrap();

            self.client
                .conn_mut(move |conn| {
                    // The no-op `DO UPDATE` makes `RETURNING` yield the existing row on conflict
                    let sug = conn.query_row(
                        r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value) 
                    VALUES (?1, ?2, ?3)
                    ON CONFLICT (flat_root_cmd, flat_variable, value) DO UPDATE SET
                        value = excluded.value
                    RETURNING id, flat_root_cmd, flat_variable, value"#,
                        (root, variable, value),
                        |r| VariableValue::try_from(r),
                    )?;
                    conn.execute(
                        r#"INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
                        VALUES (?1, ?2, ?3, ?4)
                        ON CONFLICT(value_id, path, context_json) DO UPDATE SET
                            usage_count = excluded.usage_count;
                        "#,
                        (&sug.id, path, &context, usage_count),
                    )?;
                    Ok(sug)
                })
                .await
                .unwrap()
        }
    }
}