use crate::config::get_data_dir;
use anyhow::{Context, Result};
use cmdhub_shared::{
AciCommandContract, DbAciRecord, CREATE_APPS_FTS_TABLE, CREATE_APPS_TABLE,
CREATE_ARGUMENTS_TABLE, CREATE_COMMANDS_VEC_TABLE,
};
use rusqlite::Connection;
use std::path::PathBuf;
/// Returns the path of the local registry database (`cmdhub.db` inside the data directory).
pub fn resolve_db_path() -> PathBuf {
    const DB_FILE_NAME: &str = "cmdhub.db";
    let data_dir = get_data_dir();
    data_dir.join(DB_FILE_NAME)
}
/// zstd-compressed SQLite database compiled into the binary. `hydrate_starter_if_empty`
/// decompresses it and writes it to the registry path to seed an empty local registry.
static EMBEDDED_STARTER_DB_ZST: &[u8] = include_bytes!("../assets/starter.db.zst");
/// Reports whether the database at `path` holds no registry content yet and may
/// therefore be overwritten with the embedded starter set.
///
/// Returns `true` when:
/// * the file does not exist,
/// * it contains no `apps` table (e.g. a zero-byte file or a database that was never
///   initialised), or
/// * the `apps` table has zero rows.
///
/// Any other failure (the file cannot be opened, is not a SQLite database, is locked,
/// ...) yields `false`, so existing data is never clobbered because of an unexpected
/// error.
fn db_is_empty(path: &std::path::Path) -> bool {
    if !path.exists() {
        return true;
    }
    let Ok(conn) = Connection::open(path) else {
        return false;
    };
    let count = |sql: &str| conn.query_row(sql, [], |r| r.get::<_, i64>(0));
    // A missing `apps` table means nothing was ever imported. Previously the resulting
    // "no such table" error was indistinguishable from a real failure and was read as
    // "not empty", so such a database was never seeded.
    match count("SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'apps'") {
        Ok(0) => return true,
        Ok(_) => {}
        Err(_) => return false,
    }
    matches!(count("SELECT count(*) FROM apps"), Ok(0))
}
pub fn hydrate_starter_if_empty() -> Result<()> {
if std::env::var_os("CMDH_NO_STARTER").is_some() {
return Ok(());
}
let db_path = resolve_db_path();
if !db_is_empty(&db_path) {
return Ok(());
}
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)
.context("Failed to create database parent directories")?;
}
for ext in ["-wal", "-shm"] {
let p = db_path.with_extension(format!("db{ext}"));
let _ = std::fs::remove_file(&p);
}
let decompressed = zstd::decode_all(EMBEDDED_STARTER_DB_ZST)
.context("Failed to decompress embedded starter database")?;
std::fs::write(&db_path, &decompressed)
.context("Failed to write embedded starter database")?;
eprintln!(
"Seeded local registry from the built-in starter set. Run `cmdh update` for the full catalog."
);
Ok(())
}
/// Opens (creating if necessary) the local registry database with sqlite-vec available.
///
/// It first registers the sqlite-vec entry point as an SQLite auto-extension, so
/// connections opened afterwards (including this one) have it loaded. It then applies
/// these connection pragmas on a best-effort basis:
/// * WAL journaling;
/// * `synchronous = NORMAL`;
/// * foreign-key enforcement.
///
/// # Errors
/// Fails if the parent directory cannot be created or the database file cannot be
/// opened.
pub fn open_db() -> Result<Connection> {
    let db_path = resolve_db_path();
    if let Some(parent) = db_path.parent() {
        std::fs::create_dir_all(parent).context("Failed to create database parent directories")?;
    }
    // SAFETY: `sqlite3_vec_init` is sqlite-vec's C extension entry point. The crate
    // exposes it as a zero-argument `extern "C" fn`, but its real ABI is the
    // `(sqlite3*, char**, const sqlite3_api_routines*) -> int` signature that
    // `sqlite3_auto_extension` expects. The transmute only restores that signature.
    // Registering an entry point that is already registered is a documented no-op,
    // so calling this on every `open_db` is harmless.
    unsafe {
        type SqliteVecInitFn = unsafe extern "C" fn();
        let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
        // The target type is fixed by `sqlite3_auto_extension`'s parameter.
        #[allow(clippy::missing_transmute_annotations)]
        let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
    }
    let conn = Connection::open(&db_path).context("Failed to open SQLite database file")?;
    // `PRAGMA journal_mode` reports the resulting mode as a row, so it is run as a
    // query. `execute` treats row-returning statements as an error
    // (`ExecuteReturnedResults`). All pragmas are best-effort: on failure SQLite's
    // defaults stay in effect.
    let _ = conn.query_row("PRAGMA journal_mode = WAL;", [], |row| row.get::<_, String>(0));
    let _ = conn.execute("PRAGMA synchronous = NORMAL;", []);
    let _ = conn.execute("PRAGMA foreign_keys = ON;", []);
    Ok(conn)
}
/// Creates the registry schema if it is not present yet.
///
/// The `apps`, `arguments`, `apps_fts` and `sync_meta` tables are mandatory; failing
/// to create any of them returns an error with context. Creating the sqlite-vec
/// `commands_vec` table is optional: on failure a warning is printed to stderr and
/// initialisation continues.
pub fn init_db(conn: &Connection) -> Result<()> {
    let required_tables = [
        (CREATE_APPS_TABLE, "Failed to create apps table"),
        (CREATE_ARGUMENTS_TABLE, "Failed to create arguments table"),
        (CREATE_APPS_FTS_TABLE, "Failed to create apps_fts table"),
    ];
    for (ddl, failure) in required_tables {
        conn.execute(ddl, []).context(failure)?;
    }

    match conn.execute(CREATE_COMMANDS_VEC_TABLE, []) {
        Ok(_) => {}
        Err(e) => eprintln!(
            "Warning: Failed to initialize sqlite-vec commands_vec table: {}. Falling back to FTS5 search.",
            e
        ),
    }

    const SYNC_META_DDL: &str = "CREATE TABLE IF NOT EXISTS sync_meta (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);";
    conn.execute(SYNC_META_DDL, [])
        .context("Failed to create sync_meta table")?;
    Ok(())
}
/// Expands a lowercase query token into related concept terms.
///
/// Used to broaden OR-mode full-text queries and command-path matching. Lookup is
/// exact and case-sensitive; unknown tokens yield an empty slice.
fn concept_synonyms(token: &str) -> &'static [&'static str] {
    // Each entry maps a set of trigger tokens to the terms they expand to. The trigger
    // sets are pairwise disjoint, so at most one entry can match.
    const TABLE: &[(&[&str], &[&str])] = &[
        (&["networking", "network"], &["vpc", "subnet", "gateway", "route", "firewall"]),
        (&["firewall"], &["security", "firewall", "acl"]),
        (&["storage"], &["bucket", "volume", "disk", "blob"]),
        (&["database", "db"], &["database", "sql", "table", "rds"]),
        (&["serverless"], &["lambda", "function", "faas"]),
        (&["container", "containers"], &["container", "image", "pod"]),
        (&["kubernetes", "k8s"], &["pod", "deployment", "namespace", "cluster"]),
        (&["secret", "secrets"], &["secret", "credential", "key", "vault"]),
        (&["dns"], &["dns", "domain", "record", "zone"]),
        (&["delete", "erase"], &["remove", "unlink", "trash"]),
        (&["remove"], &["delete", "unlink"]),
        (
            &["clear", "clean", "cleanup", "purge"],
            &["prune", "remove", "rm", "delete", "unused"],
        ),
        (&["prune"], &["clean", "remove", "delete", "unused"]),
        (&["view", "read"], &["show", "display"]),
        (&["deploy", "deployment"], &["apply", "install"]),
        (&["history"], &["log", "commits"]),
        (&["cat"], &["bat", "less", "pager"]),
        (&["fuzzy"], &["fzf", "skim", "finder"]),
        (&["finder"], &["find", "fd"]),
        (&["download"], &["curl", "wget"]),
        (&["diff"], &["delta", "difft"]),
        (&["grep"], &["ripgrep", "rg"]),
    ];
    TABLE
        .iter()
        .find(|(triggers, _)| triggers.iter().any(|&t| t == token))
        .map(|&(_, expansions)| expansions)
        .unwrap_or(&[])
}
/// Maps a lowercase query token to concrete tool names commonly used for that task.
///
/// Unlike `concept_synonyms`, this only ever yields tool/binary names. Unknown tokens
/// yield an empty slice.
fn tool_alias_synonyms(token: &str) -> &'static [&'static str] {
    const ALIASES: [(&str, &[&str]); 5] = [
        ("fuzzy", &["fzf", "skim"]),
        ("finder", &["fzf", "fd"]),
        ("download", &["curl", "wget", "aria2"]),
        ("diff", &["delta", "difft"]),
        ("grep", &["ripgrep", "rg"]),
    ];
    ALIASES
        .iter()
        .find_map(|&(trigger, tools)| (trigger == token).then_some(tools))
        .unwrap_or_default()
}
/// Builds an FTS5 `MATCH` expression from free text.
///
/// Steps:
/// 1. Split the query on every character that is neither alphanumeric nor `_`.
/// 2. Lowercase the words and drop stop words.
/// 3. Turn each remaining word into a prefix term (`word*`).
///
/// With `use_and`, the terms are space-separated (implicit AND). Without it, the
/// concept synonyms of every word are appended as further prefix terms (skipping any
/// term already present), and all terms are joined with ` OR `.
///
/// Returns `"*"` when no terms remain; callers compare against that value to detect
/// an unusable query.
fn preprocess_query(query: &str, use_and: bool) -> String {
    const STOP_WORDS: [&str; 21] = [
        "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
        "or", "from", "my", "your", "our", "me", "us",
    ];
    let words: Vec<String> = query
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|w| !w.is_empty())
        .map(str::to_lowercase)
        .filter(|w| !STOP_WORDS.contains(&w.as_str()))
        .collect();

    let mut terms: Vec<String> = Vec::with_capacity(words.len());
    terms.extend(words.iter().map(|w| format!("{w}*")));

    if !use_and {
        let mut already_present: std::collections::HashSet<&str> =
            words.iter().map(String::as_str).collect();
        for word in &words {
            for &syn in concept_synonyms(word) {
                if already_present.insert(syn) {
                    terms.push(format!("{syn}*"));
                }
            }
        }
    }

    match (terms.is_empty(), use_and) {
        (true, _) => "*".to_string(),
        (false, true) => terms.join(" "),
        (false, false) => terms.join(" OR "),
    }
}
/// Reads the embedding dimension declared for the `commands_vec` table.
///
/// Fetches the table's `CREATE` statement from `sqlite_master` and parses the number
/// between the first `float[` and the next `]`. Returns `None` if the table does not
/// exist, its SQL is unavailable, or the declaration cannot be parsed.
fn detect_vec_dim(conn: &Connection) -> Option<usize> {
    const MARKER: &str = "float[";
    let schema = conn
        .query_row(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='commands_vec'",
            [],
            |row| row.get::<_, String>(0),
        )
        .ok()?;
    let (_, after_marker) = schema.split_once(MARKER)?;
    let (digits, _) = after_marker.split_once(']')?;
    digits.parse().ok()
}
/// Returns the SQL expression that yields a command's provenance.
///
/// If the `arguments` table has a `provenance` column, the expression is
/// `arg.provenance`, so queries using it must alias `arguments` as `arg`. Older
/// databases without that column get the constant `'inferred'`, which keeps the same
/// queries valid.
pub(crate) fn provenance_expr(conn: &Connection) -> &'static str {
    let column_present = conn
        .query_row(
            "SELECT 1 FROM pragma_table_info('arguments') WHERE name = 'provenance'",
            [],
            |_| Ok(()),
        )
        .is_ok();
    if !column_present {
        return "'inferred'";
    }
    "arg.provenance"
}
/// Maps the smallest vector distance of a search to a confidence label.
///
/// * `lowest_dist <= 0.76` → `"high"`
/// * `0.76 < lowest_dist <= 0.82` → `"low"`
/// * `lowest_dist > 0.82` → `"low"` if the strict AND full-text query matched
///   (`and_match`), otherwise `"none"`
///
/// A NaN distance carries no similarity information and is treated like a distance
/// beyond the hard threshold. Previously every comparison against NaN was false,
/// so NaN fell through to `"high"`.
pub fn calculate_confidence(lowest_dist: f32, and_match: bool) -> String {
    const HARD: f32 = 0.82;
    const SOFT: f32 = 0.76;
    let beyond_hard = lowest_dist.is_nan() || lowest_dist > HARD;
    let label = if beyond_hard {
        if and_match {
            "low"
        } else {
            "none"
        }
    } else if lowest_dist > SOFT {
        "low"
    } else {
        "high"
    };
    label.to_string()
}
pub fn search_cascading(
conn: &Connection,
query: &str,
query_vector: Option<&[f32]>,
limit: usize,
enable_vector: bool,
) -> Result<Vec<AciCommandContract>> {
let cleaned_query = crate::robustness::preprocess_robustness(query);
let and_query = preprocess_query(&cleaned_query, true);
let or_query = preprocess_query(&cleaned_query, false);
let mut confidence = "high".to_string();
let prov = provenance_expr(conn);
#[allow(unused_assignments)]
let mut adapted_query_vector = None;
let query_vector: Option<&[f32]> = if enable_vector {
if let (Some(q), Some(db_dim)) = (query_vector, detect_vec_dim(conn)) {
if q.len() != db_dim {
let mut adapted = vec![0.0f32; db_dim];
let copy_len = q.len().min(db_dim);
adapted[..copy_len].copy_from_slice(&q[..copy_len]);
adapted_query_vector = Some(adapted);
adapted_query_vector.as_deref()
} else {
query_vector
}
} else {
query_vector
}
} else {
query_vector
};
let vec_bytes: Option<Vec<u8>> = if enable_vector {
query_vector.map(|q_vec| {
let mut bytes = Vec::with_capacity(q_vec.len() * 4);
for &val in q_vec {
bytes.extend_from_slice(&val.to_le_bytes());
}
bytes
})
} else {
None
};
let mut and_match = false;
if and_query != "*" {
if let Ok(count) = conn.query_row::<u64, _, _>(
"SELECT count(*) FROM apps_fts WHERE apps_fts MATCH :query",
rusqlite::named_params! { ":query": &and_query },
|row| row.get(0),
) {
if count > 0 {
and_match = true;
}
}
}
let processed_query = if and_match {
let base_tokens: Vec<String> = cleaned_query
.split(|c: char| !c.is_alphanumeric() && c != '_')
.filter(|w| !w.is_empty())
.map(|w| w.to_lowercase())
.collect();
let mut syn_terms = Vec::new();
let mut seen = std::collections::HashSet::new();
for w in &base_tokens {
seen.insert(w.clone());
}
for w in &base_tokens {
for syn in tool_alias_synonyms(w) {
let syn_str = (*syn).to_string();
if seen.insert(syn_str.clone()) {
syn_terms.push(format!("{}*", syn_str));
}
}
}
if syn_terms.is_empty() {
and_query.clone()
} else {
format!("({}) OR {}", and_query, syn_terms.join(" OR "))
}
} else {
or_query.clone()
};
let cand_query = processed_query.clone();
let trimmed_query = query.trim().to_lowercase();
let mut exact_stmt = conn.prepare(&format!(
"SELECT \
arg.app_id, \
app.name, \
arg.cmd_path, \
arg.node_type, \
arg.description, \
arg.risk_level, \
arg.example_template, \
app.os_aliases, \
app.install_instructions, \
app.popularity, \
arg.docker_image, \
arg.script_url, \
arg.source_url, \
{prov} \
FROM arguments arg \
JOIN apps app ON arg.app_id = app.app_id \
WHERE LOWER(arg.cmd_path) = :query \
OR (LOWER(app.name) = :query AND arg.node_type = 'root') \
LIMIT :limit_num"
))?;
let exact_rows = exact_stmt.query_map(
rusqlite::named_params! {
":query": trimmed_query,
":limit_num": limit,
},
|row| {
Ok(DbAciRecord {
app_id: row.get(0)?,
name: row.get(1)?,
cmd_path: row.get(2)?,
node_type: row.get(3)?,
description: row.get(4)?,
risk_level: row.get(5)?,
example_template: row.get(6)?,
os_aliases: row.get(7)?,
install_instructions: row.get(8)?,
popularity: row.get(9)?,
docker_image: row.get(10)?,
script_url: row.get(11)?,
source_url: row.get(12)?,
provenance: row.get(13)?,
})
},
)?;
let mut exact_results = Vec::new();
for record in exact_rows.flatten() {
if let Ok(contract) = AciCommandContract::try_from(record) {
exact_results.push(contract);
}
}
if let Some(ref vb) = vec_bytes {
let lowest_dist: f32 = conn
.query_row(
"SELECT v.distance \
FROM ( \
SELECT cmd_path, distance \
FROM commands_vec \
WHERE embedding MATCH :query_vector AND k = 100 \
) v \
JOIN arguments arg ON v.cmd_path = arg.cmd_path \
ORDER BY v.distance ASC \
LIMIT 1",
rusqlite::named_params! { ":query_vector": vb },
|row| row.get(0),
)
.unwrap_or(f32::MAX);
confidence = calculate_confidence(lowest_dist, and_match);
}
let qtok_n = content_tokens(query).len();
let (pw_lin, pw_cube): (f64, f64) = if qtok_n <= 1 {
(0.05, 0.0)
} else {
(0.0, 0.015)
};
let cold_floor = std::env::var("CMDH_COLD_FLOOR")
.ok()
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(1.0);
let mut top_apps = Vec::new();
if let Some(ref vb) = vec_bytes {
let mut app_stmt = conn.prepare(
"WITH fts_matched AS ( \
SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 5.0, 2.0) ASC) as fts_pos \
FROM apps_fts WHERE apps_fts MATCH :query LIMIT 300 \
), \
fts_ordered AS ( \
SELECT arg.app_id, MIN(m.fts_pos) as fts_pos \
FROM fts_matched m JOIN arguments arg ON m.cmd_path = arg.cmd_path \
GROUP BY arg.app_id \
), \
vec_knn AS ( \
SELECT cmd_path, distance FROM commands_vec \
WHERE embedding MATCH :query_vector AND k = 200 \
), \
vec_rank AS ( \
SELECT arg.app_id, row_number() OVER (ORDER BY vk.distance ASC) as vec_pos \
FROM vec_knn vk JOIN arguments arg ON vk.cmd_path = arg.cmd_path \
WHERE arg.node_type = 'root' \
), \
pre_scored AS ( \
SELECT \
COALESCE(fts.app_id, vec.app_id) as app_id, \
fts.fts_pos as fts_pos, vec.vec_pos as vec_pos \
FROM (SELECT app_id FROM fts_ordered UNION SELECT app_id FROM vec_rank) u \
LEFT JOIN fts_ordered fts ON u.app_id = fts.app_id \
LEFT JOIN vec_rank vec ON u.app_id = vec.app_id \
), \
pop_ranked AS ( \
SELECT ps.app_id, ps.fts_pos, ps.vec_pos, a.name as nm, \
COALESCE(a.popularity, 0.0) as pop, \
row_number() OVER (ORDER BY COALESCE(a.popularity, 0.0) DESC) as pop_pos \
FROM pre_scored ps JOIN apps a ON ps.app_id = a.app_id \
), \
scored AS ( \
SELECT app_id, nm, \
COALESCE((:cold_floor + (1.0 - :cold_floor) * pop) * 1.0 / (60.0 + fts_pos), 0.0) \
+ COALESCE(1.0 / (60.0 + vec_pos), 0.0) \
+ :pw_lin * pop + :pw_cube * pop * pop * pop as rrf_score \
FROM pop_ranked \
), \
name_deduped AS ( \
SELECT app_id, rrf_score, \
row_number() OVER (PARTITION BY nm ORDER BY rrf_score DESC) as rn \
FROM scored \
) \
SELECT app_id FROM name_deduped WHERE rn = 1 ORDER BY rrf_score DESC LIMIT 5"
)?;
let app_rows = app_stmt.query_map(
rusqlite::named_params! {
":query": &cand_query,
":query_vector": vb,
":pw_lin": pw_lin,
":pw_cube": pw_cube,
":cold_floor": cold_floor,
},
|row| row.get::<_, String>(0),
)?;
for app_id in app_rows.flatten() {
top_apps.push(app_id);
}
} else {
let mut app_stmt = conn.prepare(
"WITH fts_matched AS ( \
SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 5.0, 2.0) ASC) as fts_pos \
FROM apps_fts WHERE apps_fts MATCH :query LIMIT 300 \
), \
fts_ordered AS ( \
SELECT arg.app_id, MIN(m.fts_pos) as fts_pos \
FROM fts_matched m JOIN arguments arg ON m.cmd_path = arg.cmd_path \
GROUP BY arg.app_id \
), \
pop_ranked AS ( \
SELECT ftso.app_id, ftso.fts_pos, a.name as nm, \
COALESCE(a.popularity, 0.0) as pop, \
row_number() OVER (ORDER BY COALESCE(a.popularity, 0.0) DESC) as pop_pos \
FROM fts_ordered ftso JOIN apps a ON ftso.app_id = a.app_id \
), \
scored AS ( \
SELECT app_id, nm, \
COALESCE((:cold_floor + (1.0 - :cold_floor) * pop) * 1.0 / (60.0 + fts_pos), 0.0) \
+ :pw_lin * pop + :pw_cube * pop * pop * pop as rrf_score \
FROM pop_ranked \
), \
name_deduped AS ( \
SELECT app_id, rrf_score, \
row_number() OVER (PARTITION BY nm ORDER BY rrf_score DESC) as rn \
FROM scored \
) \
SELECT app_id FROM name_deduped WHERE rn = 1 ORDER BY rrf_score DESC LIMIT 5"
)?;
let app_rows = app_stmt.query_map(
rusqlite::named_params! {
":query": &cand_query,
":pw_lin": pw_lin,
":pw_cube": pw_cube,
":cold_floor": cold_floor,
},
|row| row.get::<_, String>(0),
)?;
for app_id in app_rows.flatten() {
top_apps.push(app_id);
}
}
if top_apps.is_empty() {
return Ok(exact_results);
}
if processed_query != "*" {
let mut fts_only_stmt = conn.prepare(
"WITH fts_matched AS ( \
SELECT cmd_path FROM apps_fts WHERE apps_fts MATCH :query LIMIT 100 \
) \
SELECT DISTINCT arg.app_id \
FROM fts_matched m JOIN arguments arg ON m.cmd_path = arg.cmd_path \
LIMIT 5",
)?;
let fts_app_rows = fts_only_stmt
.query_map(rusqlite::named_params! { ":query": &cand_query }, |row| {
row.get::<_, String>(0)
})?;
for app_id in fts_app_rows.flatten() {
if !top_apps.contains(&app_id) {
top_apps.push(app_id);
}
}
top_apps.truncate(8); }
while top_apps.len() < 8 {
top_apps.push(top_apps[0].clone());
}
let mut pop_map: std::collections::HashMap<String, f64> = std::collections::HashMap::new();
{
let mut uniq: Vec<&String> = top_apps.iter().collect();
uniq.sort();
uniq.dedup();
let placeholders = uniq.iter().map(|_| "?").collect::<Vec<_>>().join(",");
if let Ok(mut pstmt) = conn.prepare(&format!(
"SELECT app_id, COALESCE(popularity, 0.0) FROM apps WHERE app_id IN ({placeholders})"
)) {
let params = rusqlite::params_from_iter(uniq.iter().map(|s| s.as_str()));
if let Ok(rows) = pstmt.query_map(params, |r| {
Ok((r.get::<_, String>(0)?, r.get::<_, f64>(1)?))
}) {
for kv in rows.flatten() {
pop_map.insert(kv.0, kv.1);
}
}
}
}
let pool = std::cmp::max(limit, 30);
let mut results = Vec::new();
if let Some(ref vb) = vec_bytes {
let mut stmt = conn.prepare(&format!(
"WITH fts_rank AS ( \
SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC) as fts_pos \
FROM apps_fts WHERE apps_fts MATCH :query \
LIMIT 100 \
), \
vec_rank AS ( \
SELECT cmd_path, row_number() OVER (ORDER BY distance ASC) as vec_pos \
FROM commands_vec \
WHERE embedding MATCH :query_vector AND k = 100 \
) \
SELECT \
arg.app_id, \
app.name, \
arg.cmd_path, \
arg.node_type, \
arg.description, \
arg.risk_level, \
arg.example_template, \
app.os_aliases, \
app.install_instructions, \
app.popularity, \
arg.docker_image, \
arg.script_url, \
arg.source_url, \
{prov} \
FROM arguments arg \
JOIN apps app ON arg.app_id = app.app_id \
LEFT JOIN fts_rank fts ON arg.cmd_path = fts.cmd_path \
LEFT JOIN vec_rank vec ON arg.cmd_path = vec.cmd_path \
WHERE (fts.cmd_path IS NOT NULL OR vec.cmd_path IS NOT NULL) \
AND arg.app_id IN (:app1, :app2, :app3, :app4, :app5, :app6, :app7, :app8) \
ORDER BY COALESCE(1.0 / (60.0 + fts.fts_pos), 0.0) + COALESCE(1.0 / (60.0 + vec.vec_pos), 0.0) DESC \
LIMIT :limit_num"
))?;
let rows = stmt.query_map(
rusqlite::named_params! {
":query": &processed_query,
":query_vector": vb,
":app1": &top_apps[0],
":app2": &top_apps[1],
":app3": &top_apps[2],
":app4": &top_apps[3],
":app5": &top_apps[4],
":app6": &top_apps[5],
":app7": &top_apps[6],
":app8": &top_apps[7],
":limit_num": pool,
},
|row| {
Ok(DbAciRecord {
app_id: row.get(0)?,
name: row.get(1)?,
cmd_path: row.get(2)?,
node_type: row.get(3)?,
description: row.get(4)?,
risk_level: row.get(5)?,
example_template: row.get(6)?,
os_aliases: row.get(7)?,
install_instructions: row.get(8)?,
popularity: row.get(9)?,
docker_image: row.get(10)?,
script_url: row.get(11)?,
source_url: row.get(12)?,
provenance: row.get(13)?,
})
},
)?;
for r in rows {
let record = r?;
if let Ok(contract) = AciCommandContract::try_from(record) {
results.push(contract);
}
}
} else {
let mut stmt = conn.prepare(&format!(
"SELECT \
arg.app_id, \
app.name, \
arg.cmd_path, \
arg.node_type, \
arg.description, \
arg.risk_level, \
arg.example_template, \
app.os_aliases, \
app.install_instructions, \
app.popularity, \
arg.docker_image, \
arg.script_url, \
arg.source_url, \
{prov} \
FROM arguments arg \
JOIN apps app ON arg.app_id = app.app_id \
JOIN apps_fts fts ON arg.cmd_path = fts.cmd_path \
WHERE apps_fts MATCH :query \
AND arg.app_id IN (:app1, :app2, :app3, :app4, :app5, :app6, :app7, :app8) \
ORDER BY bm25(apps_fts, 0.0, 5.0, 2.0) ASC \
LIMIT :limit_num"
))?;
let rows = stmt.query_map(
rusqlite::named_params! {
":query": &processed_query,
":app1": &top_apps[0],
":app2": &top_apps[1],
":app3": &top_apps[2],
":app4": &top_apps[3],
":app5": &top_apps[4],
":app6": &top_apps[5],
":app7": &top_apps[6],
":app8": &top_apps[7],
":limit_num": pool,
},
|row| {
Ok(DbAciRecord {
app_id: row.get(0)?,
name: row.get(1)?,
cmd_path: row.get(2)?,
node_type: row.get(3)?,
description: row.get(4)?,
risk_level: row.get(5)?,
example_template: row.get(6)?,
os_aliases: row.get(7)?,
install_instructions: row.get(8)?,
popularity: row.get(9)?,
docker_image: row.get(10)?,
script_url: row.get(11)?,
source_url: row.get(12)?,
provenance: row.get(13)?,
})
},
)?;
for r in rows {
let record = r?;
if let Ok(contract) = AciCommandContract::try_from(record) {
results.push(contract);
}
}
}
let q_tokens = content_tokens(query);
let q_path_tokens = expand_for_path_match(&q_tokens);
if !q_tokens.is_empty() && results.len() > 1 {
let n = results.len() as i32;
let path_w = if q_tokens.len() >= 2 { 4 } else { 1 };
let pop_bonus_w = if q_tokens.len() <= 1 { 15.0 } else { 3.0 };
let mut scored: Vec<(i32, usize, AciCommandContract)> = results
.drain(..)
.enumerate()
.map(|(i, c)| {
let rrf = n - i as i32; let pop_bonus =
(pop_bonus_w * pop_map.get(&c.app_id).copied().unwrap_or(0.0)) as i32;
let root_bonus = if matches!(c.node_type, cmdhub_shared::NodeType::Root)
&& q_tokens.len() <= 1
{
20
} else {
0
};
let verified_bonus = if c.verified { VERIFIED_BONUS } else { 0 };
let composite = rrf
+ path_w * path_match_score(&c.cmd_path, &q_path_tokens)
+ pop_bonus
+ root_bonus
+ verified_bonus;
(composite, i, c)
})
.collect();
scored.sort_by(|a, b| b.0.cmp(&a.0).then(a.1.cmp(&b.1)));
results = scored.into_iter().map(|(_, _, c)| c).collect();
}
let mut final_results = exact_results.clone();
final_results.append(&mut results);
const PER_APP_CAP: usize = 3;
let mut seen_paths = std::collections::HashSet::new();
let mut per_app: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
final_results.retain(|r| {
if !seen_paths.insert(r.cmd_path.clone()) {
return false;
}
let n = per_app.entry(r.app_id.clone()).or_insert(0);
*n += 1;
*n <= PER_APP_CAP
});
final_results.truncate(limit);
for r in &mut final_results {
r.confidence = confidence.clone();
}
Ok(final_results)
}
/// Score added to a result's composite rank in `search_cascading` when its contract
/// is marked `verified`.
const VERIFIED_BONUS: i32 = 3;
/// Expands query tokens for matching against command-path segments.
///
/// For every input token the result contains:
/// * the token itself;
/// * its singular form, when the token ends in `s` and the remaining stem is at
///   least three bytes long;
/// * all of its concept synonyms.
fn expand_for_path_match(
    tokens: &std::collections::HashSet<String>,
) -> std::collections::HashSet<String> {
    tokens
        .iter()
        .flat_map(|token| {
            let singular = token
                .strip_suffix('s')
                .filter(|stem| stem.len() >= 3)
                .map(str::to_string);
            let synonyms = concept_synonyms(token).iter().map(|syn| (*syn).to_string());
            std::iter::once(token.clone()).chain(singular).chain(synonyms)
        })
        .collect()
}
/// Extracts the set of lowercase content words from a query, used for ranking.
///
/// Words are split on any character that is neither alphanumeric nor `_`. A stop-word
/// list is removed; it is broader than the FTS one and also covers conversational
/// words such as "want", "please", "show". Repeated words collapse because the
/// result is a set.
fn content_tokens(query: &str) -> std::collections::HashSet<String> {
    const STOP: &[&str] = &[
        "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
        "or", "from", "my", "your", "our", "me", "us", "i", "want", "know", "using", "use", "do",
        "can", "get", "please", "help", "show", "view",
    ];
    let mut tokens = std::collections::HashSet::new();
    for raw in query.split(|c: char| !(c.is_alphanumeric() || c == '_')) {
        if raw.is_empty() {
            continue;
        }
        let word = raw.to_lowercase();
        if !STOP.contains(&word.as_str()) {
            tokens.insert(word);
        }
    }
    tokens
}
/// Scores how well the sub-command part of a dotted `cmd_path` matches query tokens.
///
/// The segment before the first `.` (the binary) is ignored. Paths without a `.`
/// score 0.
///
/// The remainder is split on non-alphanumeric, non-`_` characters and lowercased. A
/// segment counts as a hit when it, or its singular form (trailing `s` removed, stem
/// of at least three bytes), is in `q_tokens`.
///
/// The score is `3 * hits - misses`, floored at 0.
fn path_match_score(cmd_path: &str, q_tokens: &std::collections::HashSet<String>) -> i32 {
    let Some((_, subcommand_part)) = cmd_path.split_once('.') else {
        return 0;
    };
    let is_hit = |segment: &str| {
        q_tokens.contains(segment)
            || segment
                .strip_suffix('s')
                .is_some_and(|stem| stem.len() >= 3 && q_tokens.contains(stem))
    };
    let (hits, misses) = subcommand_part
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|segment| !segment.is_empty())
        .map(str::to_lowercase)
        .fold((0i32, 0i32), |(hits, misses), segment| {
            if is_hit(&segment) {
                (hits + 1, misses)
            } else {
                (hits, misses + 1)
            }
        });
    (3 * hits - misses).max(0)
}
/// Searches the registry, enabling vector ranking only when it can actually be used.
///
/// Vector ranking is turned on only if all of these hold:
/// * `query_vector` is provided;
/// * a `commands_vec` table exists;
/// * that table contains at least one row.
///
/// Any error during these checks disables it. The search itself is delegated to
/// `search_cascading`.
pub fn search_commands(
    conn: &Connection,
    query: &str,
    query_vector: Option<&[f32]>,
    limit: usize,
) -> Result<Vec<AciCommandContract>> {
    let scalar_count = |sql: &str| conn.query_row::<u64, _, _>(sql, [], |row| row.get(0)).ok();
    let has_vector_db = query_vector.is_some()
        && scalar_count(
            "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='commands_vec'",
        )
        .is_some_and(|tables| tables > 0)
        && scalar_count("SELECT count(*) FROM commands_vec").is_some_and(|rows| rows > 0);
    search_cascading(conn, query, query_vector, limit, has_vector_db)
}
/// Searches the command registry and local skills together.
///
/// Steps:
/// 1. Run `search_commands`.
/// 2. Append whatever the local-file skill registry (rooted at `<config dir>/skills`)
///    resolves for `query`; errors from the skill registry are ignored.
/// 3. Drop later entries whose `cmd_path` was already seen.
/// 4. Truncate to `limit`.
///
/// # Errors
/// Propagates errors from `search_commands`.
pub fn search_all(
    conn: &Connection,
    query: &str,
    query_vector: Option<&[f32]>,
    limit: usize,
) -> Result<Vec<AciCommandContract>> {
    let mut combined = search_commands(conn, query, query_vector, limit)?;

    let mut registry = cmdhub_skills::SkillRegistry::new();
    registry.register(Box::new(cmdhub_skills::LocalFileSkill::new(
        crate::config::get_config_dir().join("skills"),
    )));
    if let Ok(skill_hits) = registry.resolve(query) {
        combined.extend(skill_hits);
    }

    let mut seen_paths = std::collections::HashSet::new();
    combined.retain(|item| seen_paths.insert(item.cmd_path.clone()));
    combined.truncate(limit);
    Ok(combined)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers sqlite-vec as an SQLite auto-extension so in-memory connections opened
    /// afterwards can create and query `vec0` tables.
    fn register_sqlite_vec() {
        // SAFETY: `sqlite3_vec_init` is sqlite-vec's C extension entry point with the
        // ABI `sqlite3_auto_extension` expects. The crate exposes it as a zero-argument
        // `extern "C" fn`, so the transmute only restores its real signature.
        // Re-registering the same entry point is a no-op in SQLite.
        unsafe {
            type SqliteVecInitFn = unsafe extern "C" fn();
            let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
            #[allow(clippy::missing_transmute_annotations)]
            let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
        }
    }

    /// Inserts a "safe" root command plus its app and FTS rows.
    ///
    /// `name` is used as the app name, command path, node name and example template;
    /// `description` doubles as the FTS capabilities text.
    fn seed_root_command(conn: &Connection, app_id: &str, name: &str, description: &str) {
        conn.execute(
            "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, ?)",
            (app_id, name, "{}"),
        )
        .unwrap();
        conn.execute(
            "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level, example_template) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, app_id, name, "root", description, "safe", name),
        )
        .unwrap();
        conn.execute(
            "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, ?)",
            (name, name, description),
        )
        .unwrap();
    }

    /// Opens an in-memory database with the current schema applied.
    fn fresh_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        init_db(&conn).unwrap();
        conn
    }

    #[test]
    fn test_calculate_confidence_mapping() {
        let cases = [
            (0.70, false, "high"),
            (0.75, true, "high"),
            (0.78, false, "low"),
            (0.82, false, "low"),
            (0.85, true, "low"),
            (0.83, false, "none"),
            (0.90, false, "none"),
        ];
        for (dist, and_match, expected) in cases {
            assert_eq!(
                calculate_confidence(dist, and_match),
                expected,
                "dist={dist} and_match={and_match}"
            );
        }
    }

    #[test]
    fn test_exact_match_priority() {
        let conn = fresh_db();
        seed_root_command(&conn, "org.test.git", "git", "Git version control");
        seed_root_command(&conn, "org.test.gitleaks", "gitleaks", "Detect secrets in git");
        let res = search_commands(&conn, "git", None, 10).unwrap();
        assert!(!res.is_empty());
        assert_eq!(res[0].cmd_path, "git");
    }

    #[test]
    fn test_fts_fallback_and_or() {
        let conn = fresh_db();
        seed_root_command(&conn, "org.test.rm", "rm", "delete local files");
        for query in ["delete local files", "delete my local files", "delete missing files"] {
            let res = search_commands(&conn, query, None, 10).unwrap();
            assert!(!res.is_empty(), "no results for {query:?}");
            assert_eq!(res[0].cmd_path, "rm", "wrong top hit for {query:?}");
        }
    }

    #[test]
    fn test_hybrid_search_knn_match() {
        register_sqlite_vec();
        let conn = fresh_db();
        seed_root_command(&conn, "org.test.knn", "knn", "vector search helper");
        let v_bytes: Vec<u8> = std::iter::repeat(0.1f32)
            .take(384)
            .flat_map(f32::to_le_bytes)
            .collect();
        conn.execute(
            "INSERT INTO commands_vec (cmd_path, embedding) VALUES (?, ?)",
            ("knn", v_bytes),
        )
        .unwrap();
        let query_vec = vec![0.12f32; 384];
        let res = search_commands(&conn, "missing_term", Some(&query_vec), 10).unwrap();
        assert!(!res.is_empty());
        assert_eq!(res[0].cmd_path, "knn");
    }

    #[test]
    fn test_clear_maps_to_prune_synonyms() {
        let expectations = [
            ("clear", "prune"),
            ("clean", "prune"),
            ("purge", "prune"),
            ("prune", "unused"),
            ("fuzzy", "fzf"),
            ("finder", "fd"),
            ("download", "curl"),
            ("diff", "delta"),
            ("grep", "ripgrep"),
        ];
        for (token, expected) in expectations {
            assert!(
                concept_synonyms(token).contains(&expected),
                "{token:?} should expand to {expected:?}"
            );
        }
    }

    #[test]
    fn test_tool_alias_synonyms_are_tool_names_only_not_generic_concepts() {
        assert!(tool_alias_synonyms("fuzzy").contains(&"fzf"));
        assert!(tool_alias_synonyms("grep").contains(&"rg"));
        assert!(tool_alias_synonyms("download").contains(&"curl"));
        for generic in ["kubernetes", "view", "clear"] {
            assert!(tool_alias_synonyms(generic).is_empty(), "{generic:?} must have no aliases");
        }
    }

    #[test]
    fn test_expand_for_path_match_adds_synonyms_and_singulars() {
        let tokens: std::collections::HashSet<String> =
            ["clear", "images"].iter().map(|s| s.to_string()).collect();
        let expanded = expand_for_path_match(&tokens);
        for expected in ["prune", "image", "clear"] {
            assert!(expanded.contains(expected), "missing {expected:?}");
        }
    }

    #[test]
    fn test_old_schema_db_without_provenance_still_works() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE apps (app_id TEXT PRIMARY KEY, name TEXT NOT NULL, os_aliases TEXT, \
             install_instructions TEXT, popularity REAL DEFAULT 0.0); \
             CREATE TABLE arguments (cmd_path TEXT PRIMARY KEY, app_id TEXT NOT NULL, \
             node_name TEXT NOT NULL, node_type TEXT NOT NULL, description TEXT NOT NULL, \
             risk_level TEXT NOT NULL, example_template TEXT, docker_image TEXT, \
             script_url TEXT, source_url TEXT); \
             CREATE VIRTUAL TABLE apps_fts USING fts5(cmd_path UNINDEXED, name, capabilities); \
             INSERT INTO apps (app_id, name) VALUES ('org.test.tar', 'tar'); \
             INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level) \
             VALUES ('tar', 'org.test.tar', 'tar', 'root', 'archive files', 'safe'); \
             INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES ('tar', 'tar', 'archive files');",
        )
        .unwrap();
        let res = search_commands(&conn, "tar", None, 5).unwrap();
        assert!(!res.is_empty());
        assert_eq!(res[0].cmd_path, "tar");
        // Without a provenance column every row is treated as inferred, not verified.
        assert!(!res[0].verified);
    }

    #[test]
    fn test_probe_verified_outranks_inferred_twin() {
        let conn = fresh_db();
        let twins = [
            ("org.inferred.tool", "tooli", "inferred"),
            ("org.probed.tool", "toolp", "probe"),
        ];
        for (app, name, prov) in twins {
            conn.execute(
                "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, '{}')",
                (app, name),
            )
            .unwrap();
            let path = format!("{name}.image.prune");
            conn.execute(
                "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, \
                 risk_level, provenance) VALUES (?, ?, 'prune', 'sub', \
                 'Remove unused container images to free disk space', 'dangerous', ?)",
                (&path, app, prov),
            )
            .unwrap();
            conn.execute(
                "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, \
                 'Remove unused container images to free disk space')",
                (&path, name),
            )
            .unwrap();
        }
        let res = search_cascading(&conn, "clear unused images", None, 5, false).unwrap();
        assert!(res.len() >= 2, "expected both twins, got {}", res.len());
        assert!(res[0].verified, "probe-verified twin must rank first");
        assert!(res[0].cmd_path.starts_with("toolp"));
        assert!(!res[1].verified);
    }
}