1use std::collections::BTreeMap;
2
3use color_eyre::{Report, eyre::eyre};
4use rusqlite::{ErrorCode, Row, types::Value};
5use tracing::instrument;
6
7use super::SqliteStorage;
8use crate::{
9 config::SearchVariableTuning,
10 errors::{Result, UserFacingError},
11 model::VariableValue,
12};
13
14impl SqliteStorage {
15 #[instrument(skip_all)]
23 pub async fn find_variable_values(
24 &self,
25 flat_root_cmd: impl Into<String>,
26 flat_variable_name: impl Into<String>,
27 mut flat_variable_names: Vec<String>,
28 working_path: impl Into<String>,
29 context: &BTreeMap<String, String>,
30 tuning: &SearchVariableTuning,
31 ) -> Result<Vec<(VariableValue, f64)>> {
32 let flat_variable_name = flat_variable_name.into();
34 flat_variable_names.push(flat_variable_name.clone());
35 flat_variable_names.dedup();
36
37 let mut all_sql_params = Vec::with_capacity(10 + flat_variable_names.len());
45 all_sql_params.push(Value::from(tuning.path.exact));
46 all_sql_params.push(Value::from(tuning.path.ancestor));
47 all_sql_params.push(Value::from(tuning.path.descendant));
48 all_sql_params.push(Value::from(tuning.path.unrelated));
49 all_sql_params.push(Value::from(tuning.path.points));
50 all_sql_params.push(Value::from(tuning.context.points));
51 all_sql_params.push(Value::from(flat_root_cmd.into()));
52 all_sql_params.push(Value::from(flat_variable_name));
53 all_sql_params.push(Value::from(working_path.into()));
54 all_sql_params.push(Value::from(serde_json::to_string(context)?));
55 let prev_params_len = all_sql_params.len();
56 let mut in_placeholders = Vec::new();
57 for (idx, flat_name) in flat_variable_names.into_iter().enumerate() {
58 all_sql_params.push(Value::from(flat_name));
59 in_placeholders.push(format!("?{}", idx + prev_params_len + 1));
60 }
61 let in_placeholders = in_placeholders.join(",");
62
63 let query = format!(
65 r#"WITH
66 -- Pre-calculate the total number of variables in the query context
67 context_info AS (
68 SELECT MAX(CAST(total AS REAL)) AS total_variables
69 FROM (
70 SELECT COUNT(*) as total FROM json_each(?10)
71 UNION ALL SELECT 0
72 )
73 ),
74 -- Calculate the individual relevance score for each unique usage record
75 value_scores AS (
76 SELECT
77 v.value,
78 u.usage_count,
79 CASE
80 -- Exact path match
81 WHEN u.path = ?9 THEN ?1
82 -- Ascendant path match (parent)
83 WHEN ?9 LIKE u.path || '/%' THEN ?2
84 -- Descendant path match (child)
85 WHEN u.path LIKE ?9 || '/%' THEN ?3
86 -- Other/unrelated path
87 ELSE ?4
88 END AS path_relevance,
89 (
90 SELECT
91 CASE
92 WHEN ci.total_variables > 0 THEN (CAST(COUNT(*) AS REAL) / ci.total_variables)
93 ELSE 0
94 END
95 FROM json_each(?10) AS query_ctx
96 CROSS JOIN context_info ci
97 WHERE json_extract(u.context_json, '$."' || query_ctx.key || '"') = query_ctx.value
98 ) AS context_relevance
99 FROM variable_value v
100 JOIN variable_value_usage u ON v.id = u.value_id
101 WHERE v.flat_root_cmd = ?7 AND v.flat_variable IN ({in_placeholders})
102 ),
103 -- Group by values to find the best relevance score and the total usage count
104 agg_values AS (
105 SELECT
106 vs.value,
107 MAX(
108 (vs.path_relevance * ?5)
109 + (vs.context_relevance * ?6)
110 ) as relevance_score,
111 SUM(vs.usage_count) as total_usage
112 FROM value_scores vs
113 GROUP BY vs.value
114 )
115 -- Calculate the final score and join back to find the ID
116 SELECT
117 v.id,
118 ?7 AS flat_root_cmd,
119 ?8 AS flat_variable,
120 a.value,
121 (a.relevance_score + log(a.total_usage + 1)) AS final_score
122 FROM agg_values a
123 LEFT JOIN variable_value v ON v.flat_root_cmd = ?7 AND v.flat_variable = ?8 AND v.value = a.value
124 ORDER BY final_score DESC;"#
125 );
126
127 self.client
129 .conn(move |conn| {
130 tracing::trace!("Querying variable values:\n{query}");
131 tracing::trace!("With parameters:\n{all_sql_params:?}");
132 Ok(conn
133 .prepare(&query)?
134 .query_map(rusqlite::params_from_iter(all_sql_params.iter()), |r| {
135 Ok((VariableValue::try_from(r)?, r.get(4)?))
136 })?
137 .collect::<Result<Vec<_>, _>>()?)
138 })
139 .await
140 }
141
142 #[instrument(skip_all)]
144 pub async fn insert_variable_value(&self, mut value: VariableValue) -> Result<VariableValue> {
145 if value.id.is_some() {
147 return Err(eyre!("ID should not be set when inserting a new value").into());
148 };
149
150 self.client
152 .conn_mut(move |conn| {
153 let query = r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
154 VALUES (?1, ?2, ?3)
155 RETURNING id"#;
156 tracing::trace!("Inserting a variable value: {query}");
157 let res = conn.query_row(query, (&value.flat_root_cmd, &value.flat_variable, &value.value), |r| {
158 r.get(0)
159 });
160 match res {
161 Ok(id) => {
162 value.id = Some(id);
163 Ok(value)
164 }
165 Err(err) => match err.sqlite_error_code() {
166 Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
167 _ => Err(Report::from(err).into()),
168 },
169 }
170 })
171 .await
172 }
173
174 #[instrument(skip_all)]
176 pub async fn update_variable_value(&self, value: VariableValue) -> Result<VariableValue> {
177 let Some(value_id) = value.id else {
179 return Err(eyre!("ID must be set when updating a variable value").into());
180 };
181
182 self.client
184 .conn_mut(move |conn| {
185 let query = r#"
186 UPDATE variable_value
187 SET flat_root_cmd = ?2,
188 flat_variable = ?3,
189 value = ?4
190 WHERE rowid = ?1
191 "#;
192 tracing::trace!("Updating a variable value: {query}");
193 let res = conn.execute(
194 query,
195 (&value_id, &value.flat_root_cmd, &value.flat_variable, &value.value),
196 );
197 match res {
198 Ok(0) => Err(eyre!("Variable value not found: {value_id}")
199 .wrap_err("Couldn't update a variable value")
200 .into()),
201 Ok(_) => Ok(value),
202 Err(err) => match err.sqlite_error_code() {
203 Some(ErrorCode::ConstraintViolation) => Err(UserFacingError::VariableValueAlreadyExists.into()),
204 _ => Err(Report::from(err).into()),
205 },
206 }
207 })
208 .await
209 }
210
211 #[instrument(skip_all)]
213 pub async fn increment_variable_value_usage(
214 &self,
215 value_id: i32,
216 path: impl AsRef<str> + Send + 'static,
217 context: &BTreeMap<String, String>,
218 ) -> Result<i32> {
219 let context = serde_json::to_string(context)?;
220 self.client
221 .conn_mut(move |conn| {
222 let query = r#"
223 INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
224 VALUES (?1, ?2, ?3, 1)
225 ON CONFLICT(value_id, path, context_json) DO UPDATE SET
226 usage_count = usage_count + 1
227 RETURNING usage_count;"#;
228 tracing::trace!("Incrementing variable value usage: {query}");
229 Ok(conn.query_row(query, (&value_id, &path.as_ref(), &context), |r| r.get(0))?)
230 })
231 .await
232 }
233
234 #[instrument(skip_all)]
238 pub async fn delete_variable_value(&self, value_id: i32) -> Result<()> {
239 self.client
240 .conn_mut(move |conn| {
241 let query = "DELETE FROM variable_value WHERE rowid = ?1";
242 tracing::trace!("Deleting a variable value: {query}");
243 let res = conn.execute(query, (&value_id,));
244 match res {
245 Ok(0) => Err(eyre!("Variable value not found: {value_id}").into()),
246 Ok(_) => Ok(()),
247 Err(err) => Err(Report::from(err).into()),
248 }
249 })
250 .await
251 }
252}
253
254impl<'a> TryFrom<&'a Row<'a>> for VariableValue {
255 type Error = rusqlite::Error;
256
257 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
258 Ok(Self {
259 id: row.get(0)?,
260 flat_root_cmd: row.get(1)?,
261 flat_variable: row.get(2)?,
262 value: row.get(3)?,
263 })
264 }
265}
266
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::errors::AppError;

    /// Searching an empty database yields no values.
    #[tokio::test]
    async fn test_find_variable_values_empty() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let values = storage
            .find_variable_values(
                "cmd",
                "variable",
                Vec::new(),
                "/some/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();
        assert!(values.is_empty());
    }

    /// With equal usage and no context, values rank exact > parent > child > unrelated path.
    #[tokio::test]
    async fn test_find_variable_values_path_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "docker";
        let variable = "image";
        let current_path = "/home/user/project-a/api";

        storage
            .setup_variable_value(root, variable, "unrelated-path", "/var/www", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "child-path", "/home/user/project-a/api/db", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "parent-path", "/home/user/project-a", [], 1)
            .await;
        storage
            .setup_variable_value(root, variable, "exact-path", current_path, [], 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                Vec::new(),
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 4);
        assert_eq!(matches[0].0.value, "exact-path");
        assert_eq!(matches[1].0.value, "parent-path");
        assert_eq!(matches[2].0.value, "child-path");
        assert_eq!(matches[3].0.value, "unrelated-path");
    }

    /// On the same path, values whose stored context matches more query context entries rank higher.
    #[tokio::test]
    async fn test_find_variable_values_context_relevance_ranking() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable = "port";
        let current_path = "/home/user/k8s";
        let query_context = [("namespace", "prod"), ("service", "api-gateway")];

        storage
            .setup_variable_value(root, variable, "no-context", current_path, [], 1)
            .await;
        storage
            .setup_variable_value(
                root,
                variable,
                "partial-context",
                current_path,
                [("namespace", "prod")],
                1,
            )
            .await;
        storage
            .setup_variable_value(root, variable, "full-context", current_path, query_context, 1)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                Vec::new(),
                current_path,
                &BTreeMap::from_iter(query_context.into_iter().map(|(k, v)| (k.to_owned(), v.to_owned()))),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].0.value, "full-context");
        assert_eq!(matches[1].0.value, "partial-context");
        assert_eq!(matches[2].0.value, "no-context");
    }

    /// Usage count orders values with equal relevance, but can't outrank a more relevant path.
    #[tokio::test]
    async fn test_find_variable_values_usage_count_is_tiebreaker_only() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "git";
        let variable = "branch";
        let current_path = "/home/user/project";

        storage
            .setup_variable_value(root, variable, "feature-a", current_path, [], 5)
            .await;
        storage
            .setup_variable_value(root, variable, "feature-b", current_path, [], 50)
            .await;
        storage
            .setup_variable_value(root, variable, "release-1.0", "/other/path", [], 9999)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable,
                Vec::new(),
                current_path,
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].0.value, "feature-b");
        assert_eq!(matches[1].0.value, "feature-a");
        assert_eq!(matches[2].0.value, "release-1.0");
    }

    /// Values from the component variables are merged with the composite variable's own values.
    /// Only values stored under the composite variable keep an id.
    #[tokio::test]
    async fn test_find_variable_values_aggregates_from_multiple_variables() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let root = "kubectl";
        let variable_composite = "pod|service";
        let variable_composite_names = variable_composite.split("|").map(String::from).collect::<Vec<_>>();

        storage
            .setup_variable_value(root, "pod", "api-pod-123", "/path", [], 4)
            .await;
        storage
            .setup_variable_value(root, "service", "api-service", "/path", [], 5)
            .await;
        let sug_composite = storage
            .setup_variable_value(root, variable_composite, "api-pod-123", "/path", [], 4)
            .await;

        let matches = storage
            .find_variable_values(
                root,
                variable_composite,
                variable_composite_names,
                "/path",
                &BTreeMap::new(),
                &SearchVariableTuning::default(),
            )
            .await
            .unwrap();

        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].0.value, "api-pod-123");
        assert_eq!(matches[0].0.id, sug_composite.id);
        assert_eq!(matches[1].0.value, "api-service");
        assert!(matches[1].0.id.is_none());
    }

    /// Inserting works once; inserting the same value again reports it already exists.
    #[tokio::test]
    async fn test_insert_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable", "value");

        let inserted_sug = storage.insert_variable_value(sug.clone()).await.unwrap();
        assert_eq!(inserted_sug.value, sug.value);

        match storage.insert_variable_value(sug.clone()).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error, got {res:?}"),
        }
    }

    /// Updates succeed for existing rows, fail for unknown ids, and report constraint clashes.
    #[tokio::test]
    async fn test_update_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug1 = VariableValue::new("cmd", "variable", "value_orig");

        let mut var1 = storage.insert_variable_value(sug1).await.unwrap();

        var1.value = "value_updated".to_string();
        let res = storage.update_variable_value(var1.clone()).await;
        assert!(res.is_ok(), "Expected successful update, got {res:?}");
        let sug1 = res.unwrap();
        assert_eq!(sug1.value, "value_updated");

        let mut non_existent_sug = sug1.clone();
        non_existent_sug.id = Some(999);
        match storage.update_variable_value(non_existent_sug).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }

        let var2 = VariableValue::new("cmd", "variable", "value_other");
        let mut sug2 = storage.insert_variable_value(var2).await.unwrap();
        sug2.value = "value_updated".to_string();
        match storage.update_variable_value(sug2).await {
            Err(AppError::UserFacing(UserFacingError::VariableValueAlreadyExists)) => (),
            res => panic!("Expected VariableValueAlreadyExists error for constraint violation, got {res:?}"),
        }
    }

    /// Repeated usages on the same path and context increment the stored count.
    #[tokio::test]
    async fn test_increment_variable_value_usage() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();

        let val = storage
            .insert_variable_value(VariableValue::new("root", "variable", "value"))
            .await
            .unwrap();
        let val_id = val.id.unwrap();

        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 1);

        let count = storage
            .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    /// Deleting works once; deleting the same id again fails.
    #[tokio::test]
    async fn test_delete_variable_value() {
        let storage = SqliteStorage::new_in_memory().await.unwrap();
        let sug = VariableValue::new("cmd", "variable_del", "value_to_delete");

        let sug = storage.insert_variable_value(sug).await.unwrap();
        let id_to_delete = sug.id.unwrap();

        let res = storage.delete_variable_value(id_to_delete).await;
        assert!(res.is_ok(), "Expected successful delete, got {res:?}");

        match storage.delete_variable_value(id_to_delete).await {
            Err(_) => (),
            res => panic!("Expected error, got {res:?}"),
        }
    }

    impl SqliteStorage {
        /// Test helper: upserts a variable value and sets its usage count for `path` + `context`.
        ///
        /// The usage count is overwritten (not incremented) if the usage row already exists.
        /// Returns the stored value, including its id.
        async fn setup_variable_value(
            &self,
            root: &'static str,
            variable: &'static str,
            value: &'static str,
            path: &'static str,
            context: impl IntoIterator<Item = (&str, &str)>,
            usage_count: i32,
        ) -> VariableValue {
            let context = serde_json::to_string(&BTreeMap::<String, String>::from_iter(
                context.into_iter().map(|(k, v)| (k.to_string(), v.to_string())),
            ))
            .unwrap();

            self.client
                .conn_mut(move |conn| {
                    // The no-op DO UPDATE makes RETURNING yield the existing row on conflict
                    let sug = conn.query_row(
                        r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
                        VALUES (?1, ?2, ?3)
                        ON CONFLICT (flat_root_cmd, flat_variable, value) DO UPDATE SET
                            value = excluded.value
                        RETURNING id, flat_root_cmd, flat_variable, value"#,
                        (root, variable, value),
                        |r| VariableValue::try_from(r),
                    )?;
                    conn.execute(
                        r#"INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
                        VALUES (?1, ?2, ?3, ?4)
                        ON CONFLICT(value_id, path, context_json) DO UPDATE SET
                            usage_count = excluded.usage_count;
                        "#,
                        (&sug.id, path, &context, usage_count),
                    )?;
                    Ok(sug)
                })
                .await
                .unwrap()
        }
    }
}