1use std::collections::BTreeMap;
2
3use color_eyre::{
4 Report, Result,
5 eyre::{Context, eyre},
6};
7use rusqlite::{ErrorCode, Row, types::Value};
8use tracing::instrument;
9
10use super::SqliteStorage;
11use crate::{
12 config::SearchVariableTuning,
13 errors::{InsertError, UpdateError},
14 model::VariableValue,
15 utils::{flatten_str, flatten_variable},
16};
17
18impl SqliteStorage {
19 #[instrument(skip_all)]
28 pub async fn find_variable_values(
29 &self,
30 root_cmd: impl AsRef<str>,
31 variable_name: impl AsRef<str>,
32 working_path: impl Into<String>,
33 context: &BTreeMap<String, String>,
34 tuning: &SearchVariableTuning,
35 ) -> Result<Vec<VariableValue>> {
36 let flat_root_cmd = flatten_str(root_cmd);
39 let flat_variable = flatten_variable(variable_name);
40 let mut flat_variable_values = flat_variable.split('|').map(str::to_owned).collect::<Vec<_>>();
41 flat_variable_values.push(flat_variable.clone());
43 flat_variable_values.dedup();
44
45 let mut all_sql_params = Vec::with_capacity(2 + flat_variable_values.len());
53 all_sql_params.push(Value::from(tuning.path.exact));
54 all_sql_params.push(Value::from(tuning.path.ancestor));
55 all_sql_params.push(Value::from(tuning.path.descendant));
56 all_sql_params.push(Value::from(tuning.path.unrelated));
57 all_sql_params.push(Value::from(tuning.path.points));
58 all_sql_params.push(Value::from(tuning.context.points));
59 all_sql_params.push(Value::from(flat_root_cmd));
60 all_sql_params.push(Value::from(flat_variable));
61 all_sql_params.push(Value::from(working_path.into()));
62 all_sql_params.push(Value::from(serde_json::to_string(context)?));
63 let prev_params_len = all_sql_params.len();
64 let mut in_placeholders = Vec::new();
65 for (idx, variable_param) in flat_variable_values.into_iter().enumerate() {
66 all_sql_params.push(Value::from(variable_param));
67 in_placeholders.push(format!("?{}", idx + prev_params_len + 1));
68 }
69 let in_placeholders = in_placeholders.join(",");
70
71 let query = format!(
73 r#"WITH
74 -- Pre-calculate the total number of variables in the query context
75 context_info AS (
76 SELECT MAX(CAST(total AS REAL)) AS total_variables
77 FROM (
78 SELECT COUNT(*) as total FROM json_each(?10)
79 UNION ALL SELECT 0
80 )
81 ),
82 -- Calculate the individual relevance score for each unique usage record
83 value_scores AS (
84 SELECT
85 v.value,
86 u.usage_count,
87 CASE
88 -- Exact path match
89 WHEN u.path = ?9 THEN ?1
90 -- Ascendant path match (parent)
91 WHEN ?9 LIKE u.path || '/%' THEN ?2
92 -- Descendant path match (child)
93 WHEN u.path LIKE ?9 || '/%' THEN ?3
94 -- Other/unrelated path
95 ELSE ?4
96 END AS path_relevance,
97 (
98 SELECT
99 CASE
100 WHEN ci.total_variables > 0 THEN (CAST(COUNT(*) AS REAL) / ci.total_variables)
101 ELSE 0
102 END
103 FROM json_each(?10) AS query_ctx
104 CROSS JOIN context_info ci
105 WHERE json_extract(u.context_json, '$."' || query_ctx.key || '"') = query_ctx.value
106 ) AS context_relevance
107 FROM variable_value v
108 JOIN variable_value_usage u ON v.id = u.value_id
109 WHERE v.flat_root_cmd = ?7 AND v.flat_variable IN ({in_placeholders})
110 ),
111 -- Group by values to find the best relevance score and the total usage count
112 agg_values AS (
113 SELECT
114 vs.value,
115 MAX(
116 (vs.path_relevance * ?5)
117 + (vs.context_relevance * ?6)
118 ) as relevance_score,
119 SUM(vs.usage_count) as total_usage
120 FROM value_scores vs
121 GROUP BY vs.value
122 )
123 -- Calculate the final score and join back to find the ID
124 SELECT
125 v.id,
126 ?7 AS flat_root_cmd,
127 ?8 AS flat_variable,
128 a.value,
129 (a.relevance_score + log(a.total_usage + 1)) AS final_score
130 FROM agg_values a
131 LEFT JOIN variable_value v ON v.flat_root_cmd = ?7 AND v.flat_variable = ?8 AND v.value = a.value
132 ORDER BY final_score DESC;"#
133 );
134
135 self.client
137 .conn::<_, _, Report>(move |conn| {
138 tracing::trace!("Querying variable values:\n{query}");
139 tracing::trace!("With parameters:\n{all_sql_params:?}");
140 Ok(conn
141 .prepare(&query)?
142 .query(rusqlite::params_from_iter(all_sql_params.iter()))?
143 .and_then(|r| VariableValue::try_from(r))
144 .collect::<Result<Vec<_>, _>>()?)
145 })
146 .await
147 .wrap_err("Couldn't find variable values")
148 }
149
150 #[instrument(skip_all)]
152 pub async fn insert_variable_value(&self, mut value: VariableValue) -> Result<VariableValue, InsertError> {
153 if value.id.is_some() {
155 return Err(eyre!("ID should not be set when inserting a new value").into());
156 };
157
158 self.client
160 .conn_mut(move |conn| {
161 let res = conn.query_row(
162 r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
163 VALUES (?1, ?2, ?3)
164 RETURNING id"#,
165 (&value.flat_root_cmd, &value.flat_variable, &value.value),
166 |r| r.get(0),
167 );
168 match res {
169 Ok(id) => {
170 value.id = Some(id);
171 Ok(value)
172 }
173 Err(err) => match err.sqlite_error_code() {
174 Some(ErrorCode::ConstraintViolation) => Err(InsertError::AlreadyExists),
175 _ => Err(Report::from(err).wrap_err("Couldn't insert a variable value").into()),
176 },
177 }
178 })
179 .await
180 }
181
182 #[instrument(skip_all)]
184 pub async fn update_variable_value(&self, value: VariableValue) -> Result<VariableValue, UpdateError> {
185 let Some(value_id) = value.id else {
187 return Err(eyre!("ID must be set when updating a variable value").into());
188 };
189
190 self.client
192 .conn_mut(move |conn| {
193 let res = conn.execute(
194 r#"
195 UPDATE variable_value
196 SET flat_root_cmd = ?2,
197 flat_variable = ?3,
198 value = ?4
199 WHERE rowid = ?1
200 "#,
201 (&value_id, &value.flat_root_cmd, &value.flat_variable, &value.value),
202 );
203 match res {
204 Ok(0) => Err(eyre!("Variable value not found: {value_id}")
205 .wrap_err("Couldn't update a variable value")
206 .into()),
207 Ok(_) => Ok(value),
208 Err(err) => match err.sqlite_error_code() {
209 Some(ErrorCode::ConstraintViolation) => Err(UpdateError::AlreadyExists),
210 _ => Err(Report::from(err).wrap_err("Couldn't update a variable value").into()),
211 },
212 }
213 })
214 .await
215 }
216
217 #[instrument(skip_all)]
219 pub async fn increment_variable_value_usage(
220 &self,
221 value_id: i32,
222 path: impl AsRef<str> + Send + 'static,
223 context: &BTreeMap<String, String>,
224 ) -> Result<i32, UpdateError> {
225 let context = serde_json::to_string(context)?;
226 self.client
227 .conn_mut(move |conn| {
228 let res = conn.query_row(
229 r#"
230 INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
231 VALUES (?1, ?2, ?3, 1)
232 ON CONFLICT(value_id, path, context_json) DO UPDATE SET
233 usage_count = usage_count + 1
234 RETURNING usage_count;"#,
235 (&value_id, &path.as_ref(), &context),
236 |r| r.get(0),
237 );
238 match res {
239 Ok(u) => Ok(u),
240 Err(err) => Err(Report::from(err)
241 .wrap_err("Couldn't update a variable value usage")
242 .into()),
243 }
244 })
245 .await
246 }
247
248 #[instrument(skip_all)]
252 pub async fn delete_variable_value(&self, value_id: i32) -> Result<()> {
253 self.client
254 .conn_mut(move |conn| {
255 let res = conn.execute("DELETE FROM variable_value WHERE rowid = ?1", (&value_id,));
256 match res {
257 Ok(0) => {
258 Err(eyre!("Variable value not found: {value_id}").wrap_err("Couldn't delete a variable value"))
259 }
260 Ok(_) => Ok(()),
261 Err(err) => Err(Report::from(err).wrap_err("Couldn't delete a variable value")),
262 }
263 })
264 .await
265 }
266}
267
268impl<'a> TryFrom<&'a Row<'a>> for VariableValue {
269 type Error = rusqlite::Error;
270
271 fn try_from(row: &'a Row<'a>) -> Result<Self, Self::Error> {
272 Ok(Self {
273 id: row.get(0)?,
274 flat_root_cmd: row.get(1)?,
275 flat_variable: row.get(2)?,
276 value: row.get(3)?,
277 })
278 }
279}
280
281#[cfg(test)]
282mod tests {
283 use pretty_assertions::assert_eq;
284
285 use super::*;
286
287 #[tokio::test]
288 async fn test_find_variable_values_empty() {
289 let storage = SqliteStorage::new_in_memory().await.unwrap();
290 let values = storage
291 .find_variable_values(
292 "cmd",
293 "variable",
294 "/some/path",
295 &BTreeMap::new(),
296 &SearchVariableTuning::default(),
297 )
298 .await
299 .unwrap();
300 assert!(values.is_empty());
301 }
302
303 #[tokio::test]
304 async fn test_find_variable_values_path_relevance_ranking() {
305 let storage = SqliteStorage::new_in_memory().await.unwrap();
306 let root = "docker";
307 let variable = "image";
308 let current_path = "/home/user/project-a/api";
309
310 storage
312 .setup_variable_value(root, variable, "unrelated-path", "/var/www", [], 1)
313 .await;
314 storage
315 .setup_variable_value(root, variable, "child-path", "/home/user/project-a/api/db", [], 1)
316 .await;
317 storage
318 .setup_variable_value(root, variable, "parent-path", "/home/user/project-a", [], 1)
319 .await;
320 storage
321 .setup_variable_value(root, variable, "exact-path", current_path, [], 1)
322 .await;
323
324 let matches = storage
325 .find_variable_values(
326 root,
327 variable,
328 current_path,
329 &BTreeMap::new(),
330 &SearchVariableTuning::default(),
331 )
332 .await
333 .unwrap();
334
335 assert_eq!(matches.len(), 4);
337 assert_eq!(matches[0].value, "exact-path");
338 assert_eq!(matches[1].value, "parent-path");
339 assert_eq!(matches[2].value, "child-path");
340 assert_eq!(matches[3].value, "unrelated-path");
341 }
342
343 #[tokio::test]
344 async fn test_find_variable_values_context_relevance_ranking() {
345 let storage = SqliteStorage::new_in_memory().await.unwrap();
346 let root = "kubectl";
347 let variable = "port";
348 let current_path = "/home/user/k8s";
349 let query_context = [("namespace", "prod"), ("service", "api-gateway")];
350
351 storage
353 .setup_variable_value(root, variable, "no-context", current_path, [], 1)
354 .await;
355 storage
356 .setup_variable_value(
357 root,
358 variable,
359 "partial-context",
360 current_path,
361 [("namespace", "prod")],
362 1,
363 )
364 .await;
365 storage
366 .setup_variable_value(root, variable, "full-context", current_path, query_context, 1)
367 .await;
368
369 let matches = storage
370 .find_variable_values(
371 root,
372 variable,
373 current_path,
374 &BTreeMap::from_iter(query_context.into_iter().map(|(k, v)| (k.to_owned(), v.to_owned()))),
375 &SearchVariableTuning::default(),
376 )
377 .await
378 .unwrap();
379
380 assert_eq!(matches.len(), 3);
382 assert_eq!(matches[0].value, "full-context");
383 assert_eq!(matches[1].value, "partial-context");
384 assert_eq!(matches[2].value, "no-context");
385 }
386
387 #[tokio::test]
388 async fn test_find_variable_values_usage_count_is_tiebreaker_only() {
389 let storage = SqliteStorage::new_in_memory().await.unwrap();
390 let root = "git";
391 let variable = "branch";
392 let current_path = "/home/user/project";
393
394 storage
396 .setup_variable_value(root, variable, "feature-a", current_path, [], 5)
397 .await;
398 storage
399 .setup_variable_value(root, variable, "feature-b", current_path, [], 50)
400 .await;
401 storage
403 .setup_variable_value(root, variable, "release-1.0", "/other/path", [], 9999)
404 .await;
405
406 let matches = storage
407 .find_variable_values(
408 root,
409 variable,
410 current_path,
411 &BTreeMap::new(),
412 &SearchVariableTuning::default(),
413 )
414 .await
415 .unwrap();
416
417 assert_eq!(matches.len(), 3);
419 assert_eq!(matches[0].value, "feature-b");
420 assert_eq!(matches[1].value, "feature-a");
421 assert_eq!(matches[2].value, "release-1.0");
422 }
423
424 #[tokio::test]
425 async fn test_find_variable_values_aggregates_from_multiple_variables() {
426 let storage = SqliteStorage::new_in_memory().await.unwrap();
427 let root = "kubectl";
428 let variable_composite = "pod|service";
429
430 storage
432 .setup_variable_value(root, "pod", "api-pod-123", "/path", [], 4)
433 .await;
434 storage
435 .setup_variable_value(root, "service", "api-service", "/path", [], 5)
436 .await;
437 let sug_composite = storage
439 .setup_variable_value(root, variable_composite, "api-pod-123", "/path", [], 4)
440 .await;
441
442 let matches = storage
443 .find_variable_values(
444 root,
445 variable_composite,
446 "/path",
447 &BTreeMap::new(),
448 &SearchVariableTuning::default(),
449 )
450 .await
451 .unwrap();
452
453 assert_eq!(matches.len(), 2);
454 assert_eq!(matches[0].value, "api-pod-123");
455 assert_eq!(matches[0].id, sug_composite.id);
456 assert_eq!(matches[1].value, "api-service");
457 assert!(matches[1].id.is_none());
458 }
459
460 #[tokio::test]
461 async fn test_insert_variable_value() {
462 let storage = SqliteStorage::new_in_memory().await.unwrap();
463 let sug = VariableValue::new("cmd", "variable", "value");
464
465 let inserted_sug = storage.insert_variable_value(sug.clone()).await.unwrap();
466 assert_eq!(inserted_sug.value, sug.value);
467
468 match storage.insert_variable_value(sug.clone()).await {
470 Err(InsertError::AlreadyExists) => (),
471 res => panic!("Expected AlreadyExists error, got {res:?}"),
472 }
473 }
474
475 #[tokio::test]
476 async fn test_update_variable_value() {
477 let storage = SqliteStorage::new_in_memory().await.unwrap();
478 let sug1 = VariableValue::new("cmd", "variable", "value_orig");
479
480 let mut var1 = storage.insert_variable_value(sug1).await.unwrap();
482
483 var1.value = "value_updated".to_string();
485 let res = storage.update_variable_value(var1.clone()).await;
486 assert!(res.is_ok(), "Expected successful update, got {res:?}");
487 let sug1 = res.unwrap();
488 assert_eq!(sug1.value, "value_updated");
489
490 let mut non_existent_sug = sug1.clone();
492 non_existent_sug.id = Some(999);
493 match storage.update_variable_value(non_existent_sug).await {
494 Err(_) => (),
495 res => panic!("Expected error, got {res:?}"),
496 }
497
498 let var2 = VariableValue::new("cmd", "variable", "value_other");
500 let mut sug2 = storage.insert_variable_value(var2).await.unwrap();
501 sug2.value = "value_updated".to_string();
502 match storage.update_variable_value(sug2).await {
503 Err(UpdateError::AlreadyExists) => (),
504 res => panic!("Expected AlreadyExists error for constraint violation, got {res:?}"),
505 }
506 }
507
508 #[tokio::test]
509 async fn test_increment_variable_value_usage() {
510 let storage = SqliteStorage::new_in_memory().await.unwrap();
511
512 let val = storage
514 .insert_variable_value(VariableValue::new("root", "variable", "value"))
515 .await
516 .unwrap();
517 let val_id = val.id.unwrap();
518
519 let count = storage
521 .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
522 .await
523 .unwrap();
524 assert_eq!(count, 1);
525
526 let count = storage
528 .increment_variable_value_usage(val_id, "/path", &BTreeMap::new())
529 .await
530 .unwrap();
531 assert_eq!(count, 2);
532 }
533
534 #[tokio::test]
535 async fn test_delete_variable_value() {
536 let storage = SqliteStorage::new_in_memory().await.unwrap();
537 let sug = VariableValue::new("cmd", "variable_del", "value_to_delete");
538
539 let sug = storage.insert_variable_value(sug).await.unwrap();
541 let id_to_delete = sug.id.unwrap();
542
543 let res = storage.delete_variable_value(id_to_delete).await;
545 assert!(res.is_ok(), "Expected successful update, got {res:?}");
546
547 match storage.delete_variable_value(id_to_delete).await {
549 Err(_) => (),
550 res => panic!("Expected error, got {res:?}"),
551 }
552 }
553
554 impl SqliteStorage {
555 async fn setup_variable_value(
558 &self,
559 root: &'static str,
560 variable: &'static str,
561 value: &'static str,
562 path: &'static str,
563 context: impl IntoIterator<Item = (&str, &str)>,
564 usage_count: i32,
565 ) -> VariableValue {
566 let context = serde_json::to_string(&BTreeMap::<String, String>::from_iter(
567 context.into_iter().map(|(k, v)| (k.to_string(), v.to_string())),
568 ))
569 .unwrap();
570
571 self.client
572 .conn_mut::<_, _, Report>(move |conn| {
573 let sug = conn.query_row(
574 r#"INSERT INTO variable_value (flat_root_cmd, flat_variable, value)
575 VALUES (?1, ?2, ?3)
576 ON CONFLICT (flat_root_cmd, flat_variable, value) DO UPDATE SET
577 value = excluded.value
578 RETURNING id, flat_root_cmd, flat_variable, value"#,
579 (root, variable, value),
580 |r| VariableValue::try_from(r),
581 )?;
582 conn.execute(
583 r#"INSERT INTO variable_value_usage (value_id, path, context_json, usage_count)
584 VALUES (?1, ?2, ?3, ?4)
585 ON CONFLICT(value_id, path, context_json) DO UPDATE SET
586 usage_count = excluded.usage_count;
587 "#,
588 (&sug.id, path, &context, usage_count),
589 )?;
590 Ok(sug)
591 })
592 .await
593 .unwrap()
594 }
595 }
596}