// intelli_shell/storage/variable.rs

1use std::collections::BTreeMap;
2
3use color_eyre::{
4    Report, Result,
5    eyre::{Context, eyre},
6};
7use rusqlite::{ErrorCode, Row, types::Value};
8use tracing::instrument;
9
10use super::SqliteStorage;
11use crate::{
12    config::SearchVariableTuning,
13    errors::{InsertError, UpdateError},
14    model::VariableValue,
15    utils::{flatten_str, flatten_variable},
16};
17
18impl SqliteStorage {
19    /// Finds variable values for a given root command, variable name and context.
20    ///
21    /// The `variable_name` input can be a single term or multiple terms delimited by `|`.
22    /// The method searches for values matching any of these individual (flattened) terms, as well as the (flattened)
23    /// composite variable itself.
24    ///
25    /// Results are returned for the original input variable, even if they don't explicitly exists, ordered to
26    /// prioritize overall relevance.
27    #[instrument(skip_all)]
28    pub async fn find_variable_values(
29        &self,
30        root_cmd: impl AsRef<str>,
31        variable_name: impl AsRef<str>,
32        working_path: impl Into<String>,
33        context: &BTreeMap<String, String>,
34        tuning: &SearchVariableTuning,
35    ) -> Result<Vec<VariableValue>> {
36        // Prepare flattened inputs
37        // If the variable contains any `|` char, we will consider it as a list of variables to find values for
38        let flat_root_cmd = flatten_str(root_cmd);
39        let flat_variable = flatten_variable(variable_name);
40        let mut flat_variable_values = flat_variable.split('|').map(str::to_owned).collect::<Vec<_>>();
41        // Including values for the entire variable itself
42        flat_variable_values.push(flat_variable.clone());
43        flat_variable_values.dedup();
44
45        // Prepare the query params:
46        // -- ?1~5: tuning params
47        // -- ?7: flat_root_cmd
48        // -- ?8: original flat_variable
49        // -- ?9: working_path
50        // -- ?10: context json
51        // -- ?n: all flat_variable placeholders
52        let mut all_sql_params = Vec::with_capacity(2 + flat_variable_values.len());
53        all_sql_params.push(Value::from(tuning.path.exact));
54        all_sql_params.push(Value::from(tuning.path.ancestor));
55        all_sql_params.push(Value::from(tuning.path.descendant));
56        all_sql_params.push(Value::from(tuning.path.unrelated));
57        all_sql_params.push(Value::from(tuning.path.points));
58        all_sql_params.push(Value::from(tuning.context.points));
59        all_sql_params.push(Value::from(flat_root_cmd));
60        all_sql_params.push(Value::from(flat_variable));
61        all_sql_params.push(Value::from(working_path.into()));
62        all_sql_params.push(Value::from(serde_json::to_string(context)?));
63        let prev_params_len = all_sql_params.len();
64        let mut in_placeholders = Vec::new();
65        for (idx, variable_param) in flat_variable_values.into_iter().enumerate() {
66            all_sql_params.push(Value::from(variable_param));
67            in_placeholders.push(format!("?{}", idx + prev_params_len + 1));
68        }
69        let in_placeholders = in_placeholders.join(",");
70
71        // Construct the SQL query
72        let query = format!(
73            r#"WITH
74            -- Pre-calculate the total number of variables in the query context
75            context_info AS (
76                SELECT MAX(CAST(total AS REAL)) AS total_variables
77                FROM (
78                    SELECT COUNT(*) as total FROM json_each(?10)
79                    UNION ALL SELECT 0
80                )
81            ),
82            -- Calculate the individual relevance score for each unique usage record
83            value_scores AS (
84                SELECT
85                    v.value,
86                    u.usage_count,
87                    CASE
88                        -- Exact path match
89                        WHEN u.path = ?9 THEN ?1
90                        -- Ascendant path match (parent)
91                        WHEN ?9 LIKE u.path || '/%' THEN ?2
92                        -- Descendant path match (child)
93                        WHEN u.path LIKE ?9 || '/%' THEN ?3
94                        -- Other/unrelated path
95                        ELSE ?4
96                    END AS path_relevance,
97                    (
98                        SELECT
99                            CASE
100                                WHEN ci.total_variables > 0 THEN (CAST(COUNT(*) AS REAL) / ci.total_variables)
101                                ELSE 0
102                            END
103                        FROM json_each(?10) AS query_ctx
104                        CROSS JOIN context_info ci
105                        WHERE json_extract(u.context_json, '$."' || query_ctx.key || '"') = query_ctx.value
106                    ) AS context_relevance
107                FROM variable_value v
108                JOIN variable_value_usage u ON v.id = u.value_id
109                WHERE v.flat_root_cmd = ?7 AND v.flat_variable IN ({in_placeholders})
110            ),
111            -- Group by values to find the best relevance score and the total usage count
112            agg_values AS (
113                SELECT
114                    vs.value,
115                    MAX(
116                        (vs.path_relevance * ?5)
117                        + (vs.context_relevance * ?6)
118                    ) as relevance_score,
119                    SUM(vs.usage_count) as total_usage
120                FROM value_scores vs
121                GROUP BY vs.value
122            )
123            -- Calculate the final score and join back to find the ID
124            SELECT
125                v.id,
126                ?7 AS flat_root_cmd,
127                ?8 AS flat_variable,
128                a.value,
129                (a.relevance_score + log(a.total_usage + 1)) AS final_score
130            FROM agg_values a
131            LEFT JOIN variable_value v ON v.flat_root_cmd = ?7 AND v.flat_variable = ?8 AND v.value = a.value
132            ORDER BY final_score DESC;"#
133        );
134
135        // Execute the query
136        self.client
137            .conn::<_, _, Report>(move |conn| {
138                tracing::trace!("Querying variable values:\n{query}");
139                tracing::trace!("With parameters:\n{all_sql_params:?}");
140                Ok(conn
141                    .prepare(&query)?
142                    .query(rusqlite::params_from_iter(all_sql_params.iter()))?
143                    .and_then(|r| VariableValue::try_from(r))
144                    .collect::<Result<Vec<_>, _>>()?)
145            })
146            .await
147            .wrap_err("Couldn't find variable values")
148    }
149
150    /// Inserts a new variable value into the database if it doesn't already exist
151    #[instrument(skip_all)]
152    pub async fn insert_variable_value(&self, mut value: VariableValue) -> Result<VariableValue, InsertError> {
153        // Check if the value already has an ID
154        if value.id.is_some() {
155            return Err(eyre!("ID should not be set when inserting a new value").into());
156        };
157
158        // Insert the value into the database
159        self.client
160            .conn_mut(move |conn| {
161                let res = conn.query_row(
162                    r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value) 
163                    VALUES (?1, ?2, ?3)
164                    RETURNING id"#,
165                    (&value.flat_root_cmd, &value.flat_variable, &value.value),
166                    |r| r.get(0),
167                );
168                match res {
169                    Ok(id) => {
170                        value.id = Some(id);
171                        Ok(value)
172                    }
173                    Err(err) => match err.sqlite_error_code() {
174                        Some(ErrorCode::ConstraintViolation) => Err(InsertError::AlreadyExists),
175                        _ => Err(Report::from(err).wrap_err("Couldn't insert a variable value").into()),
176                    },
177                }
178            })
179            .await
180    }
181
182    /// Updates an existing variable value
183    #[instrument(skip_all)]
184    pub async fn update_variable_value(&self, value: VariableValue) -> Result<VariableValue, UpdateError> {
185        // Check if the value doesn't have an ID to update
186        let Some(value_id) = value.id else {
187            return Err(eyre!("ID must be set when updating a variable value").into());
188        };
189
190        // Update the value in the database
191        self.client
192            .conn_mut(move |conn| {
193                let res = conn.execute(
194                    r#"
195                    UPDATE variable_value 
196                    SET flat_root_cmd = ?2, 
197                        flat_variable = ?3, 
198                        value = ?4
199                    WHERE rowid = ?1
200                    "#,
201                    (&value_id, &value.flat_root_cmd, &value.flat_variable, &value.value),
202                );
203                match res {
204                    Ok(0) => Err(eyre!("Variable value not found: {value_id}")
205                        .wrap_err("Couldn't update a variable value")
206                        .into()),
207                    Ok(_) => Ok(value),
208                    Err(err) => match err.sqlite_error_code() {
209                        Some(ErrorCode::ConstraintViolation) => Err(UpdateError::AlreadyExists),
210                        _ => Err(Report::from(err).wrap_err("Couldn't update a variable value").into()),
211                    },
212                }
213            })
214            .await
215    }
216
217    /// Increments the usage of a variable value
218    #[instrument(skip_all)]
219    pub async fn increment_variable_value_usage(
220        &self,
221        value_id: i32,
222        path: impl AsRef<str> + Send + 'static,
223        context: &BTreeMap<String, String>,
224    ) -> Result<i32, UpdateError> {
225        let context = serde_json::to_string(context)?;
226        self.client
227            .conn_mut(move |conn| {
228                let res = conn.query_row(
229                    r#"
230                    INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
231                    VALUES (?1, ?2, ?3, 1)
232                    ON CONFLICT(value_id, path, context_json) DO UPDATE SET
233                        usage_count = usage_count + 1
234                    RETURNING usage_count;"#,
235                    (&value_id, &path.as_ref(), &context),
236                    |r| r.get(0),
237                );
238                match res {
239                    Ok(u) => Ok(u),
240                    Err(err) => Err(Report::from(err)
241                        .wrap_err("Couldn't update a variable value usage")
242                        .into()),
243                }
244            })
245            .await
246    }
247
248    /// Deletes an existing variable value from the database.
249    ///
250    /// If the value to be deleted does not exist, an error will be returned.
251    #[instrument(skip_all)]
252    pub async fn delete_variable_value(&self, value_id: i32) -> Result<()> {
253        self.client
254            .conn_mut(move |conn| {
255                let res = conn.execute("DELETE FROM variable_value WHERE rowid = ?1", (&value_id,));
256                match res {
257                    Ok(0) => {
258                        Err(eyre!("Variable value not found: {value_id}").wrap_err("Couldn't delete a variable value"))
259                    }
260                    Ok(_) => Ok(()),
261                    Err(err) => Err(Report::from(err).wrap_err("Couldn't delete a variable value")),
262                }
263            })
264            .await
265    }
266}
267
268impl<'a> TryFrom<&'a Row<'a>> for VariableValue {
269    type Error = rusqlite::Error;
270
271    fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
272        Ok(Self {
273            id: row.get(0)?,
274            flat_root_cmd: row.get(1)?,
275            flat_variable: row.get(2)?,
276            value: row.get(3)?,
277        })
278    }
279}
280
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;

    #[tokio::test]
    async fn test_find_variable_values_empty() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let values = storage
            .find_variable_values(
                "cmd",
                "variable",
                "/some/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();
        assert!(values.is_empty());
    }

    #[tokio::test]
    async fn test_find_variable_values_path_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "docker";
        let variable = "image";
        let current_path = "/home/user/project-a/api";

        // Setup values with different path relevance, but identical usage and context
        storage
            .setup_variable_value(root, variable, "unrelated-path", "/var/www", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "child-path", "/home/user/project-a/api/db", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "parent-path", "/home/user/project-a", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "exact-path", current_path, [], 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert the order based on path relevance
        assert_eq!(matches.len(), 4);
        assert_eq!(matches[0].value, "exact-path");
        assert_eq!(matches[1].value, "parent-path");
        assert_eq!(matches[2].value, "child-path");
        assert_eq!(matches[3].value, "unrelated-path");
    }

    #[tokio::test]
    async fn test_find_variable_values_context_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable = "port";
        let current_path = "/home/user/k8s";
        let query_context = [("namespace", "prod"), ("service", "api-gateway")];

        // Setup values with different context relevance, but identical paths and usage
        storage
            .setup_variable_value(root, variable, "no-context", current_path, [], 1)
            .await;
        storage
            .setup_variable_value(
                root,
                variable,
                "partial-context",
                current_path,
                [("namespace", "prod")],
                1,
            )
            .await;
        storage
            .setup_variable_value(root, variable, "full-context", current_path, query_context, 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::from_iter(query_context.into_iter().map(|(k, v)| (k.to_owned(), v.to_owned()))),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert the order based on context relevance
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].value, "full-context");
        assert_eq!(matches[1].value, "partial-context");
        assert_eq!(matches[2].value, "no-context");
    }

    #[tokio::test]
    async fn test_find_variable_values_usage_count_is_tiebreaker_only() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "git";
        let variable = "branch";
        let current_path = "/home/user/project";

        // Setup two values with identical path/context, but different usage
        storage
            .setup_variable_value(root, variable, "feature-a", current_path, [], 5)
            .await;
        storage
            .setup_variable_value(root, variable, "feature-b", current_path, [], 50)
            .await;
        // A third value with worse path relevance but massive usage
        storage
            .setup_variable_value(root, variable, "release-1.0", "/other/path", [], 9999)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Assert that usage count correctly breaks the tie, but doesn't override relevance
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].value, "feature-b");
        assert_eq!(matches[1].value, "feature-a");
        assert_eq!(matches[2].value, "release-1.0");
    }

    #[tokio::test]
    async fn test_find_variable_values_aggregates_from_multiple_variables() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable_composite = "pod|service";

        // Setup values for the individual variables
        storage
            .setup_variable_value(root, "pod", "api-pod-123", "/path", [], 4)
            .await;
        storage
            .setup_variable_value(root, "service", "api-service", "/path", [], 5)
            .await;
        // Setup a value that also exists for the composite variable
        let sug_composite = storage
            .setup_variable_value(root, variable_composite, "api-pod-123", "/path", [], 4)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable_composite,
                "/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        // Values are aggregated across variables, but only the composite record provides an id
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].value, "api-pod-123");
        assert_eq!(matches[0].id, sug_composite.id);
        assert_eq!(matches[1].value, "api-service");
        assert!(matches[1].id.is_none());
    }

    #[tokio::test]
    async fn test_insert_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable", "value");

        let inserted_sug = storage.insert_variable_value(sug.clone()).await.unwrap();
        assert_eq!(inserted_sug.value, sug.value);

        // Try inserting the same value again
        match storage.insert_variable_value(sug.clone()).await {
            Err(InsertError::AlreadyExists) => (),
            res => panic!("Expected AlreadyExists error, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_update_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug1 = VariableValue::new("cmd", "variable", "value_orig");

        // Insert initial value
        let mut var1 = storage.insert_variable_value(sug1).await.unwrap();

        // Test successful update
        var1.value = "value_updated".to_string();
        let res = storage.update_variable_value(var1.clone()).await;
        assert!(res.is_ok(), "Expected successful update, got {res:?}");
        let sug1 = res.unwrap();
        assert_eq!(sug1.value, "value_updated");

        // Test update non-existent value (wrong ID)
        let mut non_existent_sug = sug1.clone();
        non_existent_sug.id = Some(999);
        match storage.update_variable_value(non_existent_sug).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }

        // Test update causing constraint violation
        let var2 = VariableValue::new("cmd", "variable", "value_other");
        let mut sug2 = storage.insert_variable_value(var2).await.unwrap();
        sug2.value = "value_updated".to_string();
        match storage.update_variable_value(sug2).await {
            Err(UpdateError::AlreadyExists) => (),
            res => panic!("Expected AlreadyExists error for constraint violation, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_increment_variable_value_usage() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        // Setup the value
        let val = storage
            .insert_variable_value(VariableValue::new("root", "variable", "value"))
            .await
            .unwrap();
        let val_id = val.id.unwrap();

        // Insert
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 1);

        // Update
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_delete_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable_del", "value_to_delete");

        // Insert values
        let sug = storage.insert_variable_value(sug).await.unwrap();
        let id_to_delete = sug.id.unwrap();

        // Test successful deletion
        let res = storage.delete_variable_value(id_to_delete).await;
        assert!(res.is_ok(), "Expected successful deletion, got {res:?}");

        // Test deleting a non-existent value
        match storage.delete_variable_value(id_to_delete).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }
    }

    impl SqliteStorage {
        /// A helper function to make setting up test data cleaner.
        ///
        /// It inserts a variable value if it doesn't exist and then sets (not increments) the usage count for the
        /// given path and context to `usage_count`. Returns the stored variable value, including its id.
        async fn setup_variable_value(
            &self,
            root: &'static str,
            variable: &'static str,
            value: &'static str,
            path: &'static str,
            context: impl IntoIterator<Item = (&str, &str)>,
            usage_count: i32,
        ) -> VariableValue {
            let context = serde_json::to_string(&BTreeMap::<String, String>::from_iter(
                context.into_iter().map(|(k, v)| (k.to_string(), v.to_string())),
            ))
            .unwrap();

            self.client
                .conn_mut::<_, _, Report>(move |conn| {
                    // The no-op `DO UPDATE` lets `RETURNING` yield the existing row when the value already exists
                    let sug = conn.query_row(
                        r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
                        VALUES (?1, ?2, ?3)
                        ON CONFLICT (flat_root_cmd, flat_variable, value) DO UPDATE SET
                            value = excluded.value
                        RETURNING id, flat_root_cmd, flat_variable, value"#,
                        (root, variable, value),
                        |r| VariableValue::try_from(r),
                    )?;
                    conn.execute(
                        r#"INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
                        VALUES (?1, ?2, ?3, ?4)
                        ON CONFLICT(value_id, path, context_json) DO UPDATE SET
                            usage_count = excluded.usage_count;
                        "#,
                        (&sug.id, path, &context, usage_count),
                    )?;
                    Ok(sug)
                })
                .await
                .unwrap()
        }
    }
}