1use std::collections::BTreeMap;
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{ErrorCode, Row, types::Value};
5use tracing::instrument;
6
7use super::SqliteStorage;
8use crate::{
9 config::SearchVariableTuning,
10 errors::{Result, UserFacingError},
11 model::VariableValue,
12 utils::{flatten_str, flatten_variable},
13};
14
15impl SqliteStorage {
16 #[instrument(skip_all)]
25 pub async fn find_variable_values(
26 &self,
27 root_cmd: impl AsRef<str>,
28 variable_name: impl AsRef<str>,
29 working_path: impl Into<String>,
30 context: &BTreeMap<String, String>,
31 tuning: &SearchVariableTuning,
32 ) -> Result<Vec<VariableValue>> {
33 let flat_root_cmd = flatten_str(root_cmd);
36 let flat_variable = flatten_variable(variable_name);
37 let mut flat_variable_values = flat_variable.split('|').map(str::to_owned).collect::<Vec<_>>();
38 flat_variable_values.push(flat_variable.clone());
40 flat_variable_values.dedup();
41
42 let mut all_sql_params = Vec::with_capacity(2 + flat_variable_values.len());
50 all_sql_params.push(Value::from(tuning.path.exact));
51 all_sql_params.push(Value::from(tuning.path.ancestor));
52 all_sql_params.push(Value::from(tuning.path.descendant));
53 all_sql_params.push(Value::from(tuning.path.unrelated));
54 all_sql_params.push(Value::from(tuning.path.points));
55 all_sql_params.push(Value::from(tuning.context.points));
56 all_sql_params.push(Value::from(flat_root_cmd));
57 all_sql_params.push(Value::from(flat_variable));
58 all_sql_params.push(Value::from(working_path.into()));
59 all_sql_params.push(Value::from(serde_json::to_string(context)?));
60 let prev_params_len = all_sql_params.len();
61 let mut in_placeholders = Vec::new();
62 for (idx, variable_param) in flat_variable_values.into_iter().enumerate() {
63 all_sql_params.push(Value::from(variable_param));
64 in_placeholders.push(format!("?{}", idx + prev_params_len + 1));
65 }
66 let in_placeholders = in_placeholders.join(",");
67
68 let query = format!(
70 r#"WITH
71 -- Pre-calculate the total number of variables in the query context
72 context_info AS (
73 SELECT MAX(CAST(total AS REAL)) AS total_variables
74 FROM (
75 SELECT COUNT(*) as total FROM json_each(?10)
76 UNION ALL SELECT 0
77 )
78 ),
79 -- Calculate the individual relevance score for each unique usage record
80 value_scores AS (
81 SELECT
82 v.value,
83 u.usage_count,
84 CASE
85 -- Exact path match
86 WHEN u.path = ?9 THEN ?1
87 -- Ascendant path match (parent)
88 WHEN ?9 LIKE u.path || '/%' THEN ?2
89 -- Descendant path match (child)
90 WHEN u.path LIKE ?9 || '/%' THEN ?3
91 -- Other/unrelated path
92 ELSE ?4
93 END AS path_relevance,
94 (
95 SELECT
96 CASE
97 WHEN ci.total_variables > 0 THEN (CAST(COUNT(*) AS REAL) / ci.total_variables)
98 ELSE 0
99 END
100 FROM json_each(?10) AS query_ctx
101 CROSS JOIN context_info ci
102 WHERE json_extract(u.context_json, '$."' || query_ctx.key || '"') = query_ctx.value
103 ) AS context_relevance
104 FROM variable_value v
105 JOIN variable_value_usage u ON v.id = u.value_id
106 WHERE v.flat_root_cmd = ?7 AND v.flat_variable IN ({in_placeholders})
107 ),
108 -- Group by values to find the best relevance score and the total usage count
109 agg_values AS (
110 SELECT
111 vs.value,
112 MAX(
113 (vs.path_relevance * ?5)
114 + (vs.context_relevance * ?6)
115 ) as relevance_score,
116 SUM(vs.usage_count) as total_usage
117 FROM value_scores vs
118 GROUP BY vs.value
119 )
120 -- Calculate the final score and join back to find the ID
121 SELECT
122 v.id,
123 ?7 AS flat_root_cmd,
124 ?8 AS flat_variable,
125 a.value,
126 (a.relevance_score + log(a.total_usage + 1)) AS final_score
127 FROM agg_values a
128 LEFT JOIN variable_value v ON v.flat_root_cmd = ?7 AND v.flat_variable = ?8 AND v.value = a.value
129 ORDER BY final_score DESC;"#
130 );
131
132 self.client
134 .conn(move |conn| {
135 tracing::trace!("Querying variable values:\n{query}");
136 tracing::trace!("With parameters:\n{all_sql_params:?}");
137 Ok(conn
138 .prepare(&query)?
139 .query(rusqlite::params_from_iter(all_sql_params.iter()))?
140 .and_then(|r| VariableValue::try_from(r))
141 .collect::<Result<Vec<_>, _>>()?)
142 })
143 .await
144 }
145
146 #[instrument(skip_all)]
148 pub async fn insert_variable_value(&self, mut value: VariableValue) -> Result<VariableValue> {
149 if value.id.is_some() {
151 return Err(eyre!("ID should not be set when inserting a new value").into());
152 };
153
154 self.client
156 .conn_mut(move |conn| {
157 let res = conn.query_row(
158 r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
159 VALUES (?1, ?2, ?3)
160 RETURNING id"#,
161 (&value.flat_root_cmd, &value.flat_variable, &value.value),
162 |r| r.get(0),
163 );
164 match res {
165 Ok(id) => {
166 value.id = Some(id);
167 Ok(value)
168 }
169 Err(err) => match err.sqlite_error_code() {
170 Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
171 _ => Err(Report::from(err).into()),
172 },
173 }
174 })
175 .await
176 }
177
178 #[instrument(skip_all)]
180 pub async fn update_variable_value(&self, value: VariableValue) -> Result<VariableValue> {
181 let Some(value_id) = value.id else {
183 return Err(eyre!("ID must be set when updating a variable value").into());
184 };
185
186 self.client
188 .conn_mut(move |conn| {
189 let res = conn.execute(
190 r#"
191 UPDATE variable_value
192 SET flat_root_cmd = ?2,
193 flat_variable = ?3,
194 value = ?4
195 WHERE rowid = ?1
196 "#,
197 (&value_id, &value.flat_root_cmd, &value.flat_variable, &value.value),
198 );
199 match res {
200 Ok(0) => Err(eyre!("Variable value not found: {value_id}")
201 .wrap_err("Couldn't update a variable value")
202 .into()),
203 Ok(_) => Ok(value),
204 Err(err) => match err.sqlite_error_code() {
205 Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
206 _ => Err(Report::from(err).into()),
207 },
208 }
209 })
210 .await
211 }
212
213 #[instrument(skip_all)]
215 pub async fn increment_variable_value_usage(
216 &self,
217 value_id: i32,
218 path: impl AsRef<str> + Send + 'static,
219 context: &BTreeMap<String, String>,
220 ) -> Result<i32> {
221 let context = serde_json::to_string(context)?;
222 self.client
223 .conn_mut(move |conn| {
224 Ok(conn.query_row(
225 r#"
226 INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
227 VALUES (?1, ?2, ?3, 1)
228 ON CONFLICT(value_id, path, context_json) DO UPDATE SET
229 usage_count = usage_count + 1
230 RETURNING usage_count;"#,
231 (&value_id, &path.as_ref(), &context),
232 |r| r.get(0),
233 )?)
234 })
235 .await
236 }
237
238 #[instrument(skip_all)]
242 pub async fn delete_variable_value(&self, value_id: i32) -> Result<()> {
243 self.client
244 .conn_mut(move |conn| {
245 let res = conn.execute("DELETE FROM variable_value WHERE rowid = ?1", (&value_id,));
246 match res {
247 Ok(0) => Err(eyre!("Variable value not found: {value_id}").into()),
248 Ok(_) => Ok(()),
249 Err(err) => Err(Report::from(err).into()),
250 }
251 })
252 .await
253 }
254}
255
256impl<'a> TryFrom<&'a Row<'a>> for VariableValue {
257 type Error = rusqlite::Error;
258
259 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
260 Ok(Self {
261 id: row.get(0)?,
262 flat_root_cmd: row.get(1)?,
263 flat_variable: row.get(2)?,
264 value: row.get(3)?,
265 })
266 }
267}
268
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::errors::AppError;

    #[tokio::test]
    async fn test_find_variable_values_empty() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let values = storage
            .find_variable_values(
                "cmd",
                "variable",
                "/some/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();
        assert!(values.is_empty());
    }

    #[tokio::test]
    async fn test_find_variable_values_path_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "docker";
        let variable = "image";
        let current_path = "/home/user/project-a/api";

        // One value per path relationship, all with the same usage count and no context
        storage
            .setup_variable_value(root, variable, "unrelated-path", "/var/www", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "child-path", "/home/user/project-a/api/db", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "parent-path", "/home/user/project-a", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "exact-path", current_path, [], 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 4);
        // Ranking follows the path relationship: exact > parent > child > unrelated
        assert_eq!(matches[0].value, "exact-path");
        assert_eq!(matches[1].value, "parent-path");
        assert_eq!(matches[2].value, "child-path");
        assert_eq!(matches[3].value, "unrelated-path");
    }

    #[tokio::test]
    async fn test_find_variable_values_context_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable = "port";
        let current_path = "/home/user/k8s";
        let query_context = [("namespace", "prod"), ("service", "api-gateway")];

        // Same path and usage for every value, only the stored context differs
        storage
            .setup_variable_value(root, variable, "no-context", current_path, [], 1)
            .await;
        storage
            .setup_variable_value(
                root,
                variable,
                "partial-context",
                current_path,
                [("namespace", "prod")],
                1,
            )
            .await;
        storage
            .setup_variable_value(root, variable, "full-context", current_path, query_context, 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::from_iter(query_context.into_iter().map(|(k, v)| (k.to_owned(), v.to_owned()))),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 3);
        // More matching context entries rank higher
        assert_eq!(matches[0].value, "full-context");
        assert_eq!(matches[1].value, "partial-context");
        assert_eq!(matches[2].value, "no-context");
    }

    #[tokio::test]
    async fn test_find_variable_values_usage_count_is_tiebreaker_only() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "git";
        let variable = "branch";
        let current_path = "/home/user/project";

        // Two values on the exact path with different usage counts
        storage
            .setup_variable_value(root, variable, "feature-a", current_path, [], 5)
            .await;
        storage
            .setup_variable_value(root, variable, "feature-b", current_path, [], 50)
            .await;
        // A heavily used value on an unrelated path must not outrank the exact-path ones
        storage
            .setup_variable_value(root, variable, "release-1.0", "/other/path", [], 9999)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 3);
        // Usage only breaks the tie between equally relevant values
        assert_eq!(matches[0].value, "feature-b");
        assert_eq!(matches[1].value, "feature-a");
        assert_eq!(matches[2].value, "release-1.0");
    }

    #[tokio::test]
    async fn test_find_variable_values_aggregates_from_multiple_variables() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable_composite = "pod|service";

        // Values stored under each part of the composite variable
        storage
            .setup_variable_value(root, "pod", "api-pod-123", "/path", [], 4)
            .await;
        storage
            .setup_variable_value(root, "service", "api-service", "/path", [], 5)
            .await;
        // The same value as a part, but also stored under the composite variable itself
        let sug_composite = storage
            .setup_variable_value(root, variable_composite, "api-pod-123", "/path", [], 4)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable_composite,
                "/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 2);
        // Usages are summed across variables; the id is only known for the composite variable
        assert_eq!(matches[0].value, "api-pod-123");
        assert_eq!(matches[0].id, sug_composite.id);
        assert_eq!(matches[1].value, "api-service");
        assert!(matches[1].id.is_none());
    }

    #[tokio::test]
    async fn test_insert_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable", "value");

        let inserted_sug = storage.insert_variable_value(sug.clone()).await.unwrap();
        assert_eq!(inserted_sug.value, sug.value);

        // Inserting the very same value again must be reported as a duplicate
        match storage.insert_variable_value(sug.clone()).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_update_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug1 = VariableValue::new("cmd", "variable", "value_orig");

        let mut var1 = storage.insert_variable_value(sug1).await.unwrap();

        // Updating an existing value succeeds
        var1.value = "value_updated".to_string();
        let res = storage.update_variable_value(var1.clone()).await;
        assert!(res.is_ok(), "Expected successful update, got {res:?}");
        let sug1 = res.unwrap();
        assert_eq!(sug1.value, "value_updated");

        // Updating a non-existent id fails
        let mut non_existent_sug = sug1.clone();
        non_existent_sug.id = Some(999);
        match storage.update_variable_value(non_existent_sug).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }

        // Updating a value into one that already exists is reported as a duplicate
        let var2 = VariableValue::new("cmd", "variable", "value_other");
        let mut sug2 = storage.insert_variable_value(var2).await.unwrap();
        sug2.value = "value_updated".to_string();
        match storage.update_variable_value(sug2).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error for constraint violation, got {res:?}"),
        }
    }

    #[tokio::test]
    async fn test_increment_variable_value_usage() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        let val = storage
            .insert_variable_value(VariableValue::new("root", "variable", "value"))
            .await
            .unwrap();
        let val_id = val.id.unwrap();

        // First use creates the usage record
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 1);

        // Same value, path and context increments the existing record
        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_delete_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable_del", "value_to_delete");

        let sug = storage.insert_variable_value(sug).await.unwrap();
        let id_to_delete = sug.id.unwrap();

        let res = storage.delete_variable_value(id_to_delete).await;
        assert!(res.is_ok(), "Expected successful delete, got {res:?}");

        // Deleting the same id twice fails because the row is gone
        match storage.delete_variable_value(id_to_delete).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }
    }

    impl SqliteStorage {
        /// Test helper: upserts a variable value and sets its usage count for the given path and context.
        ///
        /// Returns the stored value (with its id), panicking on any database error.
        async fn setup_variable_value(
            &self,
            root: &'static str,
            variable: &'static str,
            value: &'static str,
            path: &'static str,
            context: impl IntoIterator<Item = (&str, &str)>,
            usage_count: i32,
        ) -> VariableValue {
            let context = serde_json::to_string(&BTreeMap::<String, String>::from_iter(
                context.into_iter().map(|(k, v)| (k.to_string(), v.to_string())),
            ))
            .unwrap();

            self.client
                .conn_mut(move |conn| {
                    // The no-op update on conflict makes RETURNING yield the existing row as well
                    let sug = conn.query_row(
                        r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
                        VALUES (?1, ?2, ?3)
                        ON CONFLICT (flat_root_cmd, flat_variable, value) DO UPDATE SET
                            value = excluded.value
                        RETURNING id, flat_root_cmd, flat_variable, value"#,
                        (root, variable, value),
                        |r| VariableValue::try_from(r),
                    )?;
                    conn.execute(
                        r#"INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
                        VALUES (?1, ?2, ?3, ?4)
                        ON CONFLICT(value_id, path, context_json) DO UPDATE SET
                            usage_count = excluded.usage_count;
                        "#,
                        (&sug.id, path, &context, usage_count),
                    )?;
                    Ok(sug)
                })
                .await
                .unwrap()
        }
    }
}