1use crate::config::get_data_dir;
2use anyhow::{Context, Result};
3use cmdhub_shared::{
4 AciCommandContract, DbAciRecord, CREATE_APPS_FTS_TABLE, CREATE_APPS_TABLE,
5 CREATE_ARGUMENTS_TABLE, CREATE_COMMANDS_VEC_TABLE,
6};
7use rusqlite::Connection;
8use std::path::PathBuf;
9
10pub fn resolve_db_path() -> PathBuf {
11 get_data_dir().join("cmdhub.db")
12}
13
/// Bytes of `../assets/starter.db.zst` (path relative to this source file),
/// embedded into the binary at compile time via `include_bytes!`.
///
/// `hydrate_starter_if_empty` decompresses this blob with `zstd::decode_all`
/// and writes the result to the local database path.
static EMBEDDED_STARTER_DB_ZST: &[u8] = include_bytes!("../assets/starter.db.zst");
19
20fn db_is_empty(path: &std::path::Path) -> bool {
25 if !path.exists() {
26 return true;
27 }
28 match Connection::open(path) {
29 Ok(c) => match c.query_row("SELECT count(*) FROM apps", [], |r| r.get::<_, i64>(0)) {
30 Ok(n) => n == 0, Err(_) => false, },
33 Err(_) => false, }
35}
36
37pub fn hydrate_starter_if_empty() -> Result<()> {
41 if std::env::var_os("CMDH_NO_STARTER").is_some() {
44 return Ok(());
45 }
46 let db_path = resolve_db_path();
47 if !db_is_empty(&db_path) {
48 return Ok(());
49 }
50 if let Some(parent) = db_path.parent() {
51 std::fs::create_dir_all(parent)
52 .context("Failed to create database parent directories")?;
53 }
54 for ext in ["-wal", "-shm"] {
56 let p = db_path.with_extension(format!("db{ext}"));
57 let _ = std::fs::remove_file(&p);
58 }
59 let decompressed = zstd::decode_all(EMBEDDED_STARTER_DB_ZST)
60 .context("Failed to decompress embedded starter database")?;
61 std::fs::write(&db_path, &decompressed)
62 .context("Failed to write embedded starter database")?;
63 eprintln!(
64 "Seeded local registry from the built-in starter set. Run `cmdh update` for the full catalog."
65 );
66 Ok(())
67}
68
69pub fn open_db() -> Result<Connection> {
70 let db_path = resolve_db_path();
71 if let Some(parent) = db_path.parent() {
72 std::fs::create_dir_all(parent).context("Failed to create database parent directories")?;
73 }
74
75 unsafe {
76 type SqliteVecInitFn = unsafe extern "C" fn();
77 let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
78 #[allow(clippy::missing_transmute_annotations)]
79 let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
80 }
81
82 let conn = Connection::open(&db_path).context("Failed to open SQLite database file")?;
83 let _ = conn.execute("PRAGMA journal_mode = WAL;", []);
84 let _ = conn.execute("PRAGMA synchronous = NORMAL;", []);
85 let _ = conn.execute("PRAGMA foreign_keys = ON;", []);
86 Ok(conn)
87}
88
89pub fn init_db(conn: &Connection) -> Result<()> {
90 conn.execute(CREATE_APPS_TABLE, [])
91 .context("Failed to create apps table")?;
92 conn.execute(CREATE_ARGUMENTS_TABLE, [])
93 .context("Failed to create arguments table")?;
94 conn.execute(CREATE_APPS_FTS_TABLE, [])
95 .context("Failed to create apps_fts table")?;
96
97 if let Err(e) = conn.execute(CREATE_COMMANDS_VEC_TABLE, []) {
99 eprintln!("Warning: Failed to initialize sqlite-vec commands_vec table: {}. Falling back to FTS5 search.", e);
100 }
101
102 conn.execute(
103 "CREATE TABLE IF NOT EXISTS sync_meta (
104 key TEXT PRIMARY KEY,
105 value TEXT NOT NULL
106 );",
107 [],
108 )
109 .context("Failed to create sync_meta table")?;
110
111 Ok(())
112}
113
/// Looks up the expansion terms associated with a (lowercase) query token.
///
/// Returns the synonym list of the first table entry whose trigger words
/// contain `token` exactly, or an empty slice when the token is not listed.
fn concept_synonyms(token: &str) -> &'static [&'static str] {
    type Expansion = (&'static [&'static str], &'static [&'static str]);

    // (trigger words, terms they expand to)
    const EXPANSIONS: &[Expansion] = &[
        (&["networking", "network"], &["vpc", "subnet", "gateway", "route", "firewall"]),
        (&["firewall"], &["security", "firewall", "acl"]),
        (&["storage"], &["bucket", "volume", "disk", "blob"]),
        (&["database", "db"], &["database", "sql", "table", "rds"]),
        (&["serverless"], &["lambda", "function", "faas"]),
        (&["container", "containers"], &["container", "image", "pod"]),
        (&["kubernetes", "k8s"], &["pod", "deployment", "namespace", "cluster"]),
        (&["secret", "secrets"], &["secret", "credential", "key", "vault"]),
        (&["dns"], &["dns", "domain", "record", "zone"]),
        (&["delete", "erase"], &["remove", "unlink", "trash"]),
        (&["remove"], &["delete", "unlink"]),
        (&["clear", "clean", "cleanup", "purge"], &["prune", "remove", "rm", "delete", "unused"]),
        (&["prune"], &["clean", "remove", "delete", "unused"]),
        (&["view", "read"], &["show", "display"]),
        (&["deploy", "deployment"], &["apply", "install"]),
        (&["history"], &["log", "commits"]),
        (&["cat"], &["bat", "less", "pager"]),
        (&["fuzzy"], &["fzf", "skim", "finder"]),
        (&["finder"], &["find", "fd"]),
        (&["download"], &["curl", "wget"]),
        (&["diff"], &["delta", "difft"]),
        (&["grep"], &["ripgrep", "rg"]),
    ];

    for (triggers, synonyms) in EXPANSIONS {
        if triggers.iter().any(|trigger| *trigger == token) {
            return *synonyms;
        }
    }
    &[]
}
147
/// Maps a (lowercase) query token to concrete tool names associated with it.
///
/// Only the handful of tokens in the table below are recognised; every other
/// token yields an empty slice.
fn tool_alias_synonyms(token: &str) -> &'static [&'static str] {
    // (query word, tool names it stands for)
    const TOOL_ALIASES: &[(&str, &[&str])] = &[
        ("fuzzy", &["fzf", "skim"]),
        ("finder", &["fzf", "fd"]),
        ("download", &["curl", "wget", "aria2"]),
        ("diff", &["delta", "difft"]),
        ("grep", &["ripgrep", "rg"]),
    ];

    TOOL_ALIASES
        .iter()
        .find(|(word, _)| *word == token)
        .map(|(_, tools)| *tools)
        .unwrap_or(&[])
}
163
164fn preprocess_query(query: &str, use_and: bool) -> String {
165 let stop_words: std::collections::HashSet<&str> = [
166 "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
167 "or", "from", "my", "your", "our", "me", "us",
168 ]
169 .iter()
170 .cloned()
171 .collect();
172
173 let base: Vec<String> = query
174 .split(|c: char| !c.is_alphanumeric() && c != '_')
175 .filter(|w| !w.is_empty())
176 .map(|w| w.to_lowercase())
177 .filter(|w| !stop_words.contains(w.as_str()))
178 .collect();
179
180 let mut terms: Vec<String> = base.iter().map(|w| format!("{}*", w)).collect();
181
182 if !use_and {
185 let mut seen: std::collections::HashSet<String> = base.iter().cloned().collect();
186 for w in &base {
187 for syn in concept_synonyms(w) {
188 if seen.insert((*syn).to_string()) {
189 terms.push(format!("{}*", syn));
190 }
191 }
192 }
193 }
194
195 if terms.is_empty() {
196 "*".to_string()
197 } else if use_and {
198 terms.join(" ")
199 } else {
200 terms.join(" OR ")
201 }
202}
203
204fn detect_vec_dim(conn: &Connection) -> Option<usize> {
206 let sql: String = conn
207 .query_row(
208 "SELECT sql FROM sqlite_master WHERE type='table' AND name='commands_vec'",
209 [],
210 |row| row.get(0),
211 )
212 .ok()?;
213 let pos = sql.find("float[")?;
214 let rest = &sql[pos + 6..];
215 let end = rest.find(']')?;
216 rest[..end].parse().ok()
217}
218
219pub(crate) fn provenance_expr(conn: &Connection) -> &'static str {
223 let has = conn
224 .query_row(
225 "SELECT 1 FROM pragma_table_info('arguments') WHERE name = 'provenance'",
226 [],
227 |_| Ok(()),
228 )
229 .is_ok();
230 if has {
231 "arg.provenance"
232 } else {
233 "'inferred'"
234 }
235}
236
/// Maps the best (lowest) vector distance to a confidence label.
///
/// | distance            | `and_match` | label    |
/// |---------------------|-------------|----------|
/// | `<= 0.76`           | any         | `"high"` |
/// | `(0.76, 0.82]`      | any         | `"low"`  |
/// | `> 0.82` or NaN     | `true`      | `"low"`  |
/// | `> 0.82` or NaN     | `false`     | `"none"` |
///
/// `and_match` states whether the strict (AND) full-text query matched
/// anything; it keeps a distant vector result from being reported as `"none"`.
///
/// A NaN distance carries no information, so it is treated like a distance
/// beyond the hard threshold. Without this, every comparison with NaN is
/// `false` and the function would report `"high"`.
pub fn calculate_confidence(lowest_dist: f32, and_match: bool) -> String {
    const HARD: f32 = 0.82;
    const SOFT: f32 = 0.76;

    let beyond_hard = lowest_dist.is_nan() || lowest_dist > HARD;
    let label = if beyond_hard {
        if and_match {
            "low"
        } else {
            "none"
        }
    } else if lowest_dist > SOFT {
        "low"
    } else {
        "high"
    };
    label.to_string()
}
248
249pub fn search_cascading(
250 conn: &Connection,
251 query: &str,
252 query_vector: Option<&[f32]>,
253 limit: usize,
254 enable_vector: bool,
255) -> Result<Vec<AciCommandContract>> {
256 let cleaned_query = crate::robustness::preprocess_robustness(query);
257 let and_query = preprocess_query(&cleaned_query, true);
258 let or_query = preprocess_query(&cleaned_query, false);
259 let mut confidence = "high".to_string();
260 let prov = provenance_expr(conn);
261
262 #[allow(unused_assignments)]
267 let mut adapted_query_vector = None;
268 let query_vector: Option<&[f32]> = if enable_vector {
269 if let (Some(q), Some(db_dim)) = (query_vector, detect_vec_dim(conn)) {
270 if q.len() != db_dim {
271 let mut adapted = vec![0.0f32; db_dim];
272 let copy_len = q.len().min(db_dim);
273 adapted[..copy_len].copy_from_slice(&q[..copy_len]);
274 adapted_query_vector = Some(adapted);
275 adapted_query_vector.as_deref()
276 } else {
277 query_vector
278 }
279 } else {
280 query_vector
281 }
282 } else {
283 query_vector
284 };
285
286 let vec_bytes: Option<Vec<u8>> = if enable_vector {
288 query_vector.map(|q_vec| {
289 let mut bytes = Vec::with_capacity(q_vec.len() * 4);
290 for &val in q_vec {
291 bytes.extend_from_slice(&val.to_le_bytes());
292 }
293 bytes
294 })
295 } else {
296 None
297 };
298
299 let mut and_match = false;
300 if and_query != "*" {
301 if let Ok(count) = conn.query_row::<u64, _, _>(
302 "SELECT count(*) FROM apps_fts WHERE apps_fts MATCH :query",
303 rusqlite::named_params! { ":query": &and_query },
304 |row| row.get(0),
305 ) {
306 if count > 0 {
307 and_match = true;
308 }
309 }
310 }
311
312 let processed_query = if and_match {
313 let base_tokens: Vec<String> = cleaned_query
314 .split(|c: char| !c.is_alphanumeric() && c != '_')
315 .filter(|w| !w.is_empty())
316 .map(|w| w.to_lowercase())
317 .collect();
318
319 let mut syn_terms = Vec::new();
320 let mut seen = std::collections::HashSet::new();
321 for w in &base_tokens {
322 seen.insert(w.clone());
323 }
324
325 for w in &base_tokens {
333 for syn in tool_alias_synonyms(w) {
334 let syn_str = (*syn).to_string();
335 if seen.insert(syn_str.clone()) {
336 syn_terms.push(format!("{}*", syn_str));
337 }
338 }
339 }
340
341 if syn_terms.is_empty() {
342 and_query.clone()
343 } else {
344 format!("({}) OR {}", and_query, syn_terms.join(" OR "))
345 }
346 } else {
347 or_query.clone()
348 };
349 let cand_query = processed_query.clone();
350
351 let trimmed_query = query.trim().to_lowercase();
353 let mut exact_stmt = conn.prepare(&format!(
354 "SELECT \
355 arg.app_id, \
356 app.name, \
357 arg.cmd_path, \
358 arg.node_type, \
359 arg.description, \
360 arg.risk_level, \
361 arg.example_template, \
362 app.os_aliases, \
363 app.install_instructions, \
364 app.popularity, \
365 arg.docker_image, \
366 arg.script_url, \
367 arg.source_url, \
368 {prov} \
369 FROM arguments arg \
370 JOIN apps app ON arg.app_id = app.app_id \
371 WHERE LOWER(arg.cmd_path) = :query \
372 OR (LOWER(app.name) = :query AND arg.node_type = 'root') \
373 LIMIT :limit_num"
374 ))?;
375
376 let exact_rows = exact_stmt.query_map(
377 rusqlite::named_params! {
378 ":query": trimmed_query,
379 ":limit_num": limit,
380 },
381 |row| {
382 Ok(DbAciRecord {
383 app_id: row.get(0)?,
384 name: row.get(1)?,
385 cmd_path: row.get(2)?,
386 node_type: row.get(3)?,
387 description: row.get(4)?,
388 risk_level: row.get(5)?,
389 example_template: row.get(6)?,
390 os_aliases: row.get(7)?,
391 install_instructions: row.get(8)?,
392 popularity: row.get(9)?,
393 docker_image: row.get(10)?,
394 script_url: row.get(11)?,
395 source_url: row.get(12)?,
396 provenance: row.get(13)?,
397 })
398 },
399 )?;
400
401 let mut exact_results = Vec::new();
402 for record in exact_rows.flatten() {
403 if let Ok(contract) = AciCommandContract::try_from(record) {
404 exact_results.push(contract);
405 }
406 }
407
408 if let Some(ref vb) = vec_bytes {
415 let lowest_dist: f32 = conn
416 .query_row(
417 "SELECT v.distance \
418 FROM ( \
419 SELECT cmd_path, distance \
420 FROM commands_vec \
421 WHERE embedding MATCH :query_vector AND k = 100 \
422 ) v \
423 JOIN arguments arg ON v.cmd_path = arg.cmd_path \
424 ORDER BY v.distance ASC \
425 LIMIT 1",
426 rusqlite::named_params! { ":query_vector": vb },
427 |row| row.get(0),
428 )
429 .unwrap_or(f32::MAX);
430
431 confidence = calculate_confidence(lowest_dist, and_match);
432 }
433
434 let qtok_n = content_tokens(query).len();
453 let (pw_lin, pw_cube): (f64, f64) = if qtok_n <= 1 {
454 (0.05, 0.0)
455 } else {
456 (0.0, 0.015)
457 };
458 let cold_floor = std::env::var("CMDH_COLD_FLOOR")
465 .ok()
466 .and_then(|s| s.parse::<f64>().ok())
467 .unwrap_or(1.0);
468 let mut top_apps = Vec::new();
469 if let Some(ref vb) = vec_bytes {
470 let mut app_stmt = conn.prepare(
471 "WITH fts_matched AS ( \
472 SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 5.0, 2.0) ASC) as fts_pos \
473 FROM apps_fts WHERE apps_fts MATCH :query LIMIT 300 \
474 ), \
475 fts_ordered AS ( \
476 SELECT arg.app_id, MIN(m.fts_pos) as fts_pos \
477 FROM fts_matched m JOIN arguments arg ON m.cmd_path = arg.cmd_path \
478 GROUP BY arg.app_id \
479 ), \
480 vec_knn AS ( \
481 SELECT cmd_path, distance FROM commands_vec \
482 WHERE embedding MATCH :query_vector AND k = 200 \
483 ), \
484 vec_rank AS ( \
485 SELECT arg.app_id, row_number() OVER (ORDER BY vk.distance ASC) as vec_pos \
486 FROM vec_knn vk JOIN arguments arg ON vk.cmd_path = arg.cmd_path \
487 WHERE arg.node_type = 'root' \
488 ), \
489 pre_scored AS ( \
490 SELECT \
491 COALESCE(fts.app_id, vec.app_id) as app_id, \
492 fts.fts_pos as fts_pos, vec.vec_pos as vec_pos \
493 FROM (SELECT app_id FROM fts_ordered UNION SELECT app_id FROM vec_rank) u \
494 LEFT JOIN fts_ordered fts ON u.app_id = fts.app_id \
495 LEFT JOIN vec_rank vec ON u.app_id = vec.app_id \
496 ), \
497 pop_ranked AS ( \
498 SELECT ps.app_id, ps.fts_pos, ps.vec_pos, a.name as nm, \
499 COALESCE(a.popularity, 0.0) as pop, \
500 row_number() OVER (ORDER BY COALESCE(a.popularity, 0.0) DESC) as pop_pos \
501 FROM pre_scored ps JOIN apps a ON ps.app_id = a.app_id \
502 ), \
503 scored AS ( \
504 SELECT app_id, nm, \
505 COALESCE((:cold_floor + (1.0 - :cold_floor) * pop) * 1.0 / (60.0 + fts_pos), 0.0) \
506 + COALESCE(1.0 / (60.0 + vec_pos), 0.0) \
507 + :pw_lin * pop + :pw_cube * pop * pop * pop as rrf_score \
508 FROM pop_ranked \
509 ), \
510 name_deduped AS ( \
511 SELECT app_id, rrf_score, \
512 row_number() OVER (PARTITION BY nm ORDER BY rrf_score DESC) as rn \
513 FROM scored \
514 ) \
515 SELECT app_id FROM name_deduped WHERE rn = 1 ORDER BY rrf_score DESC LIMIT 5"
516 )?;
517
518 let app_rows = app_stmt.query_map(
519 rusqlite::named_params! {
520 ":query": &cand_query,
521 ":query_vector": vb,
522 ":pw_lin": pw_lin,
523 ":pw_cube": pw_cube,
524 ":cold_floor": cold_floor,
525 },
526 |row| row.get::<_, String>(0),
527 )?;
528
529 for app_id in app_rows.flatten() {
530 top_apps.push(app_id);
531 }
532 } else {
533 let mut app_stmt = conn.prepare(
534 "WITH fts_matched AS ( \
535 SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 5.0, 2.0) ASC) as fts_pos \
536 FROM apps_fts WHERE apps_fts MATCH :query LIMIT 300 \
537 ), \
538 fts_ordered AS ( \
539 SELECT arg.app_id, MIN(m.fts_pos) as fts_pos \
540 FROM fts_matched m JOIN arguments arg ON m.cmd_path = arg.cmd_path \
541 GROUP BY arg.app_id \
542 ), \
543 pop_ranked AS ( \
544 SELECT ftso.app_id, ftso.fts_pos, a.name as nm, \
545 COALESCE(a.popularity, 0.0) as pop, \
546 row_number() OVER (ORDER BY COALESCE(a.popularity, 0.0) DESC) as pop_pos \
547 FROM fts_ordered ftso JOIN apps a ON ftso.app_id = a.app_id \
548 ), \
549 scored AS ( \
550 SELECT app_id, nm, \
551 COALESCE((:cold_floor + (1.0 - :cold_floor) * pop) * 1.0 / (60.0 + fts_pos), 0.0) \
552 + :pw_lin * pop + :pw_cube * pop * pop * pop as rrf_score \
553 FROM pop_ranked \
554 ), \
555 name_deduped AS ( \
556 SELECT app_id, rrf_score, \
557 row_number() OVER (PARTITION BY nm ORDER BY rrf_score DESC) as rn \
558 FROM scored \
559 ) \
560 SELECT app_id FROM name_deduped WHERE rn = 1 ORDER BY rrf_score DESC LIMIT 5"
561 )?;
562
563 let app_rows = app_stmt.query_map(
564 rusqlite::named_params! {
565 ":query": &cand_query,
566 ":pw_lin": pw_lin,
567 ":pw_cube": pw_cube,
568 ":cold_floor": cold_floor,
569 },
570 |row| row.get::<_, String>(0),
571 )?;
572
573 for app_id in app_rows.flatten() {
574 top_apps.push(app_id);
575 }
576 }
577
578 if top_apps.is_empty() {
579 return Ok(exact_results);
580 }
581
582 if processed_query != "*" {
587 let mut fts_only_stmt = conn.prepare(
588 "WITH fts_matched AS ( \
589 SELECT cmd_path FROM apps_fts WHERE apps_fts MATCH :query LIMIT 100 \
590 ) \
591 SELECT DISTINCT arg.app_id \
592 FROM fts_matched m JOIN arguments arg ON m.cmd_path = arg.cmd_path \
593 LIMIT 5",
594 )?;
595 let fts_app_rows = fts_only_stmt
596 .query_map(rusqlite::named_params! { ":query": &cand_query }, |row| {
597 row.get::<_, String>(0)
598 })?;
599 for app_id in fts_app_rows.flatten() {
600 if !top_apps.contains(&app_id) {
601 top_apps.push(app_id);
602 }
603 }
604 top_apps.truncate(8); }
606
607 while top_apps.len() < 8 {
609 top_apps.push(top_apps[0].clone());
610 }
611
612 let mut pop_map: std::collections::HashMap<String, f64> = std::collections::HashMap::new();
615 {
616 let mut uniq: Vec<&String> = top_apps.iter().collect();
617 uniq.sort();
618 uniq.dedup();
619 let placeholders = uniq.iter().map(|_| "?").collect::<Vec<_>>().join(",");
620 if let Ok(mut pstmt) = conn.prepare(&format!(
621 "SELECT app_id, COALESCE(popularity, 0.0) FROM apps WHERE app_id IN ({placeholders})"
622 )) {
623 let params = rusqlite::params_from_iter(uniq.iter().map(|s| s.as_str()));
624 if let Ok(rows) = pstmt.query_map(params, |r| {
625 Ok((r.get::<_, String>(0)?, r.get::<_, f64>(1)?))
626 }) {
627 for kv in rows.flatten() {
628 pop_map.insert(kv.0, kv.1);
629 }
630 }
631 }
632 }
633
634 let pool = std::cmp::max(limit, 30);
638 let mut results = Vec::new();
639 if let Some(ref vb) = vec_bytes {
640 let mut stmt = conn.prepare(&format!(
641 "WITH fts_rank AS ( \
642 SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC) as fts_pos \
643 FROM apps_fts WHERE apps_fts MATCH :query \
644 LIMIT 100 \
645 ), \
646 vec_rank AS ( \
647 SELECT cmd_path, row_number() OVER (ORDER BY distance ASC) as vec_pos \
648 FROM commands_vec \
649 WHERE embedding MATCH :query_vector AND k = 100 \
650 ) \
651 SELECT \
652 arg.app_id, \
653 app.name, \
654 arg.cmd_path, \
655 arg.node_type, \
656 arg.description, \
657 arg.risk_level, \
658 arg.example_template, \
659 app.os_aliases, \
660 app.install_instructions, \
661 app.popularity, \
662 arg.docker_image, \
663 arg.script_url, \
664 arg.source_url, \
665 {prov} \
666 FROM arguments arg \
667 JOIN apps app ON arg.app_id = app.app_id \
668 LEFT JOIN fts_rank fts ON arg.cmd_path = fts.cmd_path \
669 LEFT JOIN vec_rank vec ON arg.cmd_path = vec.cmd_path \
670 WHERE (fts.cmd_path IS NOT NULL OR vec.cmd_path IS NOT NULL) \
671 AND arg.app_id IN (:app1, :app2, :app3, :app4, :app5, :app6, :app7, :app8) \
672 ORDER BY COALESCE(1.0 / (60.0 + fts.fts_pos), 0.0) + COALESCE(1.0 / (60.0 + vec.vec_pos), 0.0) DESC \
673 LIMIT :limit_num"
674 ))?;
675
676 let rows = stmt.query_map(
677 rusqlite::named_params! {
678 ":query": &processed_query,
679 ":query_vector": vb,
680 ":app1": &top_apps[0],
681 ":app2": &top_apps[1],
682 ":app3": &top_apps[2],
683 ":app4": &top_apps[3],
684 ":app5": &top_apps[4],
685 ":app6": &top_apps[5],
686 ":app7": &top_apps[6],
687 ":app8": &top_apps[7],
688 ":limit_num": pool,
689 },
690 |row| {
691 Ok(DbAciRecord {
692 app_id: row.get(0)?,
693 name: row.get(1)?,
694 cmd_path: row.get(2)?,
695 node_type: row.get(3)?,
696 description: row.get(4)?,
697 risk_level: row.get(5)?,
698 example_template: row.get(6)?,
699 os_aliases: row.get(7)?,
700 install_instructions: row.get(8)?,
701 popularity: row.get(9)?,
702 docker_image: row.get(10)?,
703 script_url: row.get(11)?,
704 source_url: row.get(12)?,
705 provenance: row.get(13)?,
706 })
707 },
708 )?;
709
710 for r in rows {
711 let record = r?;
712 if let Ok(contract) = AciCommandContract::try_from(record) {
713 results.push(contract);
714 }
715 }
716 } else {
717 let mut stmt = conn.prepare(&format!(
718 "SELECT \
719 arg.app_id, \
720 app.name, \
721 arg.cmd_path, \
722 arg.node_type, \
723 arg.description, \
724 arg.risk_level, \
725 arg.example_template, \
726 app.os_aliases, \
727 app.install_instructions, \
728 app.popularity, \
729 arg.docker_image, \
730 arg.script_url, \
731 arg.source_url, \
732 {prov} \
733 FROM arguments arg \
734 JOIN apps app ON arg.app_id = app.app_id \
735 JOIN apps_fts fts ON arg.cmd_path = fts.cmd_path \
736 WHERE apps_fts MATCH :query \
737 AND arg.app_id IN (:app1, :app2, :app3, :app4, :app5, :app6, :app7, :app8) \
738 ORDER BY bm25(apps_fts, 0.0, 5.0, 2.0) ASC \
739 LIMIT :limit_num"
740 ))?;
741
742 let rows = stmt.query_map(
743 rusqlite::named_params! {
744 ":query": &processed_query,
745 ":app1": &top_apps[0],
746 ":app2": &top_apps[1],
747 ":app3": &top_apps[2],
748 ":app4": &top_apps[3],
749 ":app5": &top_apps[4],
750 ":app6": &top_apps[5],
751 ":app7": &top_apps[6],
752 ":app8": &top_apps[7],
753 ":limit_num": pool,
754 },
755 |row| {
756 Ok(DbAciRecord {
757 app_id: row.get(0)?,
758 name: row.get(1)?,
759 cmd_path: row.get(2)?,
760 node_type: row.get(3)?,
761 description: row.get(4)?,
762 risk_level: row.get(5)?,
763 example_template: row.get(6)?,
764 os_aliases: row.get(7)?,
765 install_instructions: row.get(8)?,
766 popularity: row.get(9)?,
767 docker_image: row.get(10)?,
768 script_url: row.get(11)?,
769 source_url: row.get(12)?,
770 provenance: row.get(13)?,
771 })
772 },
773 )?;
774
775 for r in rows {
776 let record = r?;
777 if let Ok(contract) = AciCommandContract::try_from(record) {
778 results.push(contract);
779 }
780 }
781 }
782
783 let q_tokens = content_tokens(query);
789 let q_path_tokens = expand_for_path_match(&q_tokens);
794 if !q_tokens.is_empty() && results.len() > 1 {
795 let n = results.len() as i32;
796 let path_w = if q_tokens.len() >= 2 { 4 } else { 1 };
803 let pop_bonus_w = if q_tokens.len() <= 1 { 15.0 } else { 3.0 };
806 let mut scored: Vec<(i32, usize, AciCommandContract)> = results
807 .drain(..)
808 .enumerate()
809 .map(|(i, c)| {
810 let rrf = n - i as i32; let pop_bonus =
812 (pop_bonus_w * pop_map.get(&c.app_id).copied().unwrap_or(0.0)) as i32;
813 let root_bonus = if matches!(c.node_type, cmdhub_shared::NodeType::Root)
815 && q_tokens.len() <= 1
816 {
817 20
818 } else {
819 0
820 };
821 let verified_bonus = if c.verified { VERIFIED_BONUS } else { 0 };
825 let composite = rrf
826 + path_w * path_match_score(&c.cmd_path, &q_path_tokens)
827 + pop_bonus
828 + root_bonus
829 + verified_bonus;
830 (composite, i, c)
831 })
832 .collect();
833 scored.sort_by(|a, b| b.0.cmp(&a.0).then(a.1.cmp(&b.1)));
835 results = scored.into_iter().map(|(_, _, c)| c).collect();
836 }
837
838 let mut final_results = exact_results.clone();
839 final_results.append(&mut results);
840
841 const PER_APP_CAP: usize = 3;
847 let mut seen_paths = std::collections::HashSet::new();
848 let mut per_app: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
849 final_results.retain(|r| {
850 if !seen_paths.insert(r.cmd_path.clone()) {
851 return false;
852 }
853 let n = per_app.entry(r.app_id.clone()).or_insert(0);
854 *n += 1;
855 *n <= PER_APP_CAP
856 });
857
858 final_results.truncate(limit);
859 for r in &mut final_results {
860 r.confidence = confidence.clone();
861 }
862 Ok(final_results)
863}
864
/// Points added to a command's composite re-ranking score in
/// `search_cascading` when the contract's `verified` flag is set.
const VERIFIED_BONUS: i32 = 3;
869
870fn expand_for_path_match(
876 tokens: &std::collections::HashSet<String>,
877) -> std::collections::HashSet<String> {
878 let mut out: std::collections::HashSet<String> = std::collections::HashSet::new();
879 for t in tokens {
880 out.insert(t.clone());
881 if let Some(stem) = t.strip_suffix('s') {
882 if stem.len() >= 3 {
883 out.insert(stem.to_string());
884 }
885 }
886 for syn in concept_synonyms(t) {
887 out.insert((*syn).to_string());
888 }
889 }
890 out
891}
892
/// Extracts the distinct meaningful words of a query.
///
/// The query is split on every character that is neither alphanumeric nor
/// `_`; each piece is lowercased and kept unless it is one of the filler
/// words listed below.
fn content_tokens(query: &str) -> std::collections::HashSet<String> {
    const FILLER: &[&str] = &[
        "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
        "or", "from", "my", "your", "our", "me", "us", "i", "want", "know", "using", "use", "do",
        "can", "get", "please", "help", "show", "view",
    ];

    let mut kept = std::collections::HashSet::new();
    for piece in query.split(|c: char| !(c.is_alphanumeric() || c == '_')) {
        if piece.is_empty() {
            continue;
        }
        let word = piece.to_lowercase();
        if !FILLER.contains(&word.as_str()) {
            kept.insert(word);
        }
    }
    kept
}
913
/// Scores how well the sub-command part of `cmd_path` overlaps with the query.
///
/// Everything up to and including the first `.` is ignored. The remainder is
/// split on characters that are neither alphanumeric nor `_` and lowercased.
/// A segment counts as a hit when it is in `q_tokens`, or when it ends in `s`
/// and its singular (at least three bytes long) is in `q_tokens`.
///
/// The score is `3 * hits - misses`, clamped to be non-negative. Paths without
/// a `.` or without any segment after it score `0`.
fn path_match_score(cmd_path: &str, q_tokens: &std::collections::HashSet<String>) -> i32 {
    let Some((_binary, after_binary)) = cmd_path.split_once('.') else {
        return 0;
    };

    let mut total: i32 = 0;
    let mut overlap: i32 = 0;
    for segment in after_binary.split(|c: char| !(c.is_alphanumeric() || c == '_')) {
        if segment.is_empty() {
            continue;
        }
        total += 1;
        let lowered = segment.to_lowercase();
        let is_hit = q_tokens.contains(&lowered)
            || lowered
                .strip_suffix('s')
                .is_some_and(|stem| stem.len() >= 3 && q_tokens.contains(stem));
        if is_hit {
            overlap += 1;
        }
    }

    if total == 0 {
        return 0;
    }
    let extra = total - overlap;
    (3 * overlap - extra).max(0)
}
950
951pub fn search_commands(
952 conn: &Connection,
953 query: &str,
954 query_vector: Option<&[f32]>,
955 limit: usize,
956) -> Result<Vec<AciCommandContract>> {
957 let mut has_vector_db = false;
958 if query_vector.is_some() {
959 if let Ok(count) = conn.query_row::<u64, _, _>(
960 "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='commands_vec'",
961 [],
962 |row| row.get(0),
963 ) {
964 if count > 0 {
965 if let Ok(vec_count) =
966 conn.query_row::<u64, _, _>("SELECT count(*) FROM commands_vec", [], |row| {
967 row.get(0)
968 })
969 {
970 if vec_count > 0 {
971 has_vector_db = true;
972 }
973 }
974 }
975 }
976 }
977
978 search_cascading(conn, query, query_vector, limit, has_vector_db)
979}
980
981pub fn search_all(
982 conn: &Connection,
983 query: &str,
984 query_vector: Option<&[f32]>,
985 limit: usize,
986) -> Result<Vec<AciCommandContract>> {
987 let mut results = search_commands(conn, query, query_vector, limit)?;
988
989 let config_dir = crate::config::get_config_dir();
990 let skills_dir = config_dir.join("skills");
991 let local_skill = cmdhub_skills::LocalFileSkill::new(skills_dir);
992
993 let mut registry = cmdhub_skills::SkillRegistry::new();
994 registry.register(Box::new(local_skill));
995
996 if let Ok(mut skill_results) = registry.resolve(query) {
997 results.append(&mut skill_results);
998 }
999
1000 let mut seen = std::collections::HashSet::new();
1001 results.retain(|item| seen.insert(item.cmd_path.clone()));
1002 results.truncate(limit);
1003
1004 Ok(results)
1005}
1006
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database with the current schema applied.
    fn fresh_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        init_db(&conn).unwrap();
        conn
    }

    /// Inserts one app with a single safe root command whose `cmd_path`,
    /// node name and example template all equal `name`, and indexes it in
    /// `apps_fts` with `description` as its capabilities text.
    fn seed_root_command(conn: &Connection, app_id: &str, name: &str, description: &str) {
        conn.execute(
            "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, ?)",
            (app_id, name, "{}"),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level, example_template) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, app_id, name, "root", description, "safe", name),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, ?)",
            (name, name, description),
        )
        .unwrap();
    }

    #[test]
    fn test_calculate_confidence_mapping() {
        let cases = [
            (0.70, false, "high"),
            (0.75, true, "high"),
            (0.78, false, "low"),
            (0.82, false, "low"),
            (0.85, true, "low"),
            (0.83, false, "none"),
            (0.90, false, "none"),
        ];
        for (dist, and_match, expected) in cases {
            assert_eq!(calculate_confidence(dist, and_match), expected);
        }
    }

    #[test]
    fn test_exact_match_priority() {
        let conn = fresh_db();
        seed_root_command(&conn, "org.test.git", "git", "Git version control");
        seed_root_command(&conn, "org.test.gitleaks", "gitleaks", "Detect secrets in git");

        let hits = search_commands(&conn, "git", None, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "git");
    }

    #[test]
    fn test_fts_fallback_and_or() {
        let conn = fresh_db();
        seed_root_command(&conn, "org.test.rm", "rm", "delete local files");

        // Every word present, one stop word mixed in, and one word that is
        // not in the index at all: each query must still surface `rm` first.
        for query in [
            "delete local files",
            "delete my local files",
            "delete missing files",
        ] {
            let hits = search_commands(&conn, query, None, 10).unwrap();
            assert!(!hits.is_empty());
            assert_eq!(hits[0].cmd_path, "rm");
        }
    }

    #[test]
    fn test_hybrid_search_knn_match() {
        // The extension must be registered before the connection is opened.
        unsafe {
            type SqliteVecInitFn = unsafe extern "C" fn();
            let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
            #[allow(clippy::missing_transmute_annotations)]
            let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
        }

        let conn = fresh_db();
        seed_root_command(&conn, "org.test.knn", "knn", "vector search helper");

        let stored = vec![0.1f32; 384];
        let mut stored_bytes = Vec::with_capacity(384 * 4);
        for &component in &stored {
            stored_bytes.extend_from_slice(&component.to_le_bytes());
        }

        conn.execute(
            "INSERT INTO commands_vec (cmd_path, embedding) VALUES (?, ?)",
            ("knn", stored_bytes),
        )
        .unwrap();

        // The text matches nothing, so the hit can only come from the vector.
        let query_vec = vec![0.12f32; 384];
        let hits = search_commands(&conn, "missing_term", Some(&query_vec), 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "knn");
    }

    #[test]
    fn test_clear_maps_to_prune_synonyms() {
        for (token, expected) in [
            ("clear", "prune"),
            ("clean", "prune"),
            ("purge", "prune"),
            ("prune", "unused"),
            ("fuzzy", "fzf"),
            ("finder", "fd"),
            ("download", "curl"),
            ("diff", "delta"),
            ("grep", "ripgrep"),
        ] {
            assert!(concept_synonyms(token).contains(&expected));
        }
    }

    #[test]
    fn test_tool_alias_synonyms_are_tool_names_only_not_generic_concepts() {
        assert!(tool_alias_synonyms("fuzzy").contains(&"fzf"));
        assert!(tool_alias_synonyms("grep").contains(&"rg"));
        assert!(tool_alias_synonyms("download").contains(&"curl"));

        for generic in ["kubernetes", "view", "clear"] {
            assert!(tool_alias_synonyms(generic).is_empty());
        }
    }

    #[test]
    fn test_expand_for_path_match_adds_synonyms_and_singulars() {
        let tokens: std::collections::HashSet<String> =
            ["clear", "images"].iter().map(|s| s.to_string()).collect();
        let expanded = expand_for_path_match(&tokens);
        for word in ["prune", "image", "clear"] {
            assert!(expanded.contains(word));
        }
    }

    #[test]
    fn test_old_schema_db_without_provenance_still_works() {
        let conn = Connection::open_in_memory().unwrap();
        // Hand-written schema whose `arguments` table has no `provenance`
        // column.
        conn.execute_batch(
            "CREATE TABLE apps (app_id TEXT PRIMARY KEY, name TEXT NOT NULL, os_aliases TEXT, \
             install_instructions TEXT, popularity REAL DEFAULT 0.0); \
             CREATE TABLE arguments (cmd_path TEXT PRIMARY KEY, app_id TEXT NOT NULL, \
             node_name TEXT NOT NULL, node_type TEXT NOT NULL, description TEXT NOT NULL, \
             risk_level TEXT NOT NULL, example_template TEXT, docker_image TEXT, \
             script_url TEXT, source_url TEXT); \
             CREATE VIRTUAL TABLE apps_fts USING fts5(cmd_path UNINDEXED, name, capabilities); \
             INSERT INTO apps (app_id, name) VALUES ('org.test.tar', 'tar'); \
             INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level) \
             VALUES ('tar', 'org.test.tar', 'tar', 'root', 'archive files', 'safe'); \
             INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES ('tar', 'tar', 'archive files');",
        )
        .unwrap();

        let hits = search_commands(&conn, "tar", None, 5).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "tar");
        assert!(!hits[0].verified);
    }

    #[test]
    fn test_probe_verified_outranks_inferred_twin() {
        let conn = fresh_db();
        // Two apps with identical text; only their provenance differs.
        let twins = [
            ("org.inferred.tool", "inferred"),
            ("org.probed.tool", "probe"),
        ];
        for (app, prov) in twins {
            let name = if prov == "probe" { "toolp" } else { "tooli" };
            conn.execute(
                "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, '{}')",
                (app, name),
            )
            .unwrap();
            let path = format!("{}.image.prune", name);
            conn.execute(
                "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, \
                 risk_level, provenance) VALUES (?, ?, 'prune', 'sub', \
                 'Remove unused container images to free disk space', 'dangerous', ?)",
                (&path, app, prov),
            )
            .unwrap();
            conn.execute(
                "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, \
                 'Remove unused container images to free disk space')",
                (&path, name),
            )
            .unwrap();
        }

        let hits = search_cascading(&conn, "clear unused images", None, 5, false).unwrap();
        assert!(hits.len() >= 2, "expected both twins, got {}", hits.len());
        assert!(hits[0].verified, "probe-verified twin must rank first");
        assert!(hits[0].cmd_path.starts_with("toolp"));
        assert!(!hits[1].verified);
    }
}