1use crate::config::get_data_dir;
2use anyhow::{Context, Result};
3use cmdhub_shared::{
4 AciCommandContract, DbAciRecord, CREATE_APPS_FTS_TABLE, CREATE_APPS_TABLE,
5 CREATE_ARGUMENTS_TABLE, CREATE_COMMANDS_VEC_TABLE,
6};
7use rusqlite::Connection;
8use std::path::PathBuf;
9
10pub fn resolve_db_path() -> PathBuf {
11 get_data_dir().join("cmdhub.db")
12}
13
14pub fn open_db() -> Result<Connection> {
15 let db_path = resolve_db_path();
16 if let Some(parent) = db_path.parent() {
17 std::fs::create_dir_all(parent).context("Failed to create database parent directories")?;
18 }
19
20 unsafe {
21 type SqliteVecInitFn = unsafe extern "C" fn();
22 let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
23 #[allow(clippy::missing_transmute_annotations)]
24 let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
25 }
26
27 let conn = Connection::open(&db_path).context("Failed to open SQLite database file")?;
28 Ok(conn)
29}
30
31pub fn init_db(conn: &Connection) -> Result<()> {
32 conn.execute(CREATE_APPS_TABLE, [])
33 .context("Failed to create apps table")?;
34 conn.execute(CREATE_ARGUMENTS_TABLE, [])
35 .context("Failed to create arguments table")?;
36 conn.execute(CREATE_APPS_FTS_TABLE, [])
37 .context("Failed to create apps_fts table")?;
38
39 if let Err(e) = conn.execute(CREATE_COMMANDS_VEC_TABLE, []) {
41 eprintln!("Warning: Failed to initialize sqlite-vec commands_vec table: {}. Falling back to FTS5 search.", e);
42 }
43 Ok(())
44}
45
/// Converts a free-form natural-language query into an FTS5 `MATCH` expression.
///
/// The query is split on every character that is not alphanumeric, `-` or `_`. Each word
/// that contains at least one alphanumeric character is lowercased, and common English stop
/// words are dropped. The remaining words are emitted as quoted prefix terms joined with `OR`.
/// For example, `"How to list files"` becomes `"list"* OR "files"*`.
///
/// Terms are quoted because FTS5 barewords may not contain `-`. Without quoting, a
/// hyphenated word such as `docker-compose` is a query syntax error. The `*` after the
/// closing quote keeps prefix matching on the last token of the term.
///
/// If no term survives (empty input, punctuation only, or only stop words), the empty phrase
/// `""` is returned. It is a valid FTS5 expression that matches no rows, so the caller's
/// other strategies (exact match, vector search) still run. A bare `*` would instead make
/// the whole statement fail with a syntax error.
fn preprocess_query(query: &str) -> String {
    const STOP_WORDS: &[&str] = &[
        "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
        "or", "from", "my", "your", "our", "me", "us",
    ];

    let terms: Vec<String> = query
        .split(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
        // Words made only of `-`/`_` carry no searchable token.
        .filter(|word| word.chars().any(char::is_alphanumeric))
        .map(str::to_lowercase)
        .filter(|word| !STOP_WORDS.contains(&word.as_str()))
        // Words cannot contain `"` (it is a split character), so no escaping is required.
        .map(|word| format!("\"{}\"*", word))
        .collect();

    if terms.is_empty() {
        "\"\"".to_string()
    } else {
        terms.join(" OR ")
    }
}
69
70pub fn search_commands(
71 conn: &Connection,
72 query: &str,
73 query_vector: Option<&[f32]>,
74 limit: usize,
75) -> Result<Vec<AciCommandContract>> {
76 let processed_query = preprocess_query(query);
77
78 let trimmed_query = query.trim().to_lowercase();
80 let mut exact_stmt = conn.prepare(
81 "SELECT \
82 arg.app_id, \
83 app.name, \
84 arg.cmd_path, \
85 arg.node_type, \
86 arg.description, \
87 arg.risk_level, \
88 arg.example_template, \
89 app.install_instructions \
90 FROM arguments arg \
91 JOIN apps app ON arg.app_id = app.app_id \
92 WHERE LOWER(arg.cmd_path) = :query OR LOWER(app.name) = :query \
93 LIMIT :limit_num",
94 )?;
95
96 let exact_rows = exact_stmt.query_map(
97 rusqlite::named_params! {
98 ":query": trimmed_query,
99 ":limit_num": limit,
100 },
101 |row| {
102 Ok(DbAciRecord {
103 app_id: row.get(0)?,
104 name: row.get(1)?,
105 cmd_path: row.get(2)?,
106 node_type: row.get(3)?,
107 description: row.get(4)?,
108 risk_level: row.get(5)?,
109 example_template: row.get(6)?,
110 install_instructions: row.get(7)?,
111 })
112 },
113 )?;
114
115 let mut exact_results = Vec::new();
116 for record in exact_rows.flatten() {
117 if let Ok(contract) = AciCommandContract::try_from(record) {
118 exact_results.push(contract);
119 }
120 }
121
122 let mut has_vector_db = false;
124 if query_vector.is_some() {
125 if let Ok(count) = conn.query_row::<u64, _, _>(
126 "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='commands_vec'",
127 [],
128 |row| row.get(0),
129 ) {
130 if count > 0 {
131 if let Ok(vec_count) =
132 conn.query_row::<u64, _, _>("SELECT count(*) FROM commands_vec", [], |row| {
133 row.get(0)
134 })
135 {
136 if vec_count > 0 {
137 has_vector_db = true;
138 }
139 }
140 }
141 }
142 }
143
144 if has_vector_db {
145 let q_vec = query_vector.unwrap();
146 let mut vec_bytes = Vec::with_capacity(q_vec.len() * 4);
147 for &val in q_vec {
148 vec_bytes.extend_from_slice(&val.to_ne_bytes());
149 }
150 let mut stmt = conn.prepare(
151 "WITH fts_rank AS ( \
152 SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC) as fts_pos \
153 FROM apps_fts WHERE apps_fts MATCH :query \
154 LIMIT 100 \
155 ), \
156 vec_rank AS ( \
157 SELECT cmd_path, row_number() OVER (ORDER BY vec_distance_cosine(embedding, :query_vector) ASC) as vec_pos \
158 FROM commands_vec \
159 LIMIT 100 \
160 ) \
161 SELECT \
162 arg.app_id, \
163 app.name, \
164 arg.cmd_path, \
165 arg.node_type, \
166 arg.description, \
167 arg.risk_level, \
168 arg.example_template, \
169 app.install_instructions \
170 FROM arguments arg \
171 JOIN apps app ON arg.app_id = app.app_id \
172 LEFT JOIN fts_rank fts ON arg.cmd_path = fts.cmd_path \
173 LEFT JOIN vec_rank vec ON arg.cmd_path = vec.cmd_path \
174 WHERE fts.cmd_path IS NOT NULL OR vec.cmd_path IS NOT NULL \
175 ORDER BY COALESCE(1.0 / (60.0 + fts.fts_pos), 0.0) + COALESCE(1.0 / (60.0 + vec.vec_pos), 0.0) DESC \
176 LIMIT :limit_num"
177 )?;
178
179 let rows = stmt.query_map(
180 rusqlite::named_params! {
181 ":query": processed_query,
182 ":query_vector": vec_bytes,
183 ":limit_num": limit,
184 },
185 |row| {
186 Ok(DbAciRecord {
187 app_id: row.get(0)?,
188 name: row.get(1)?,
189 cmd_path: row.get(2)?,
190 node_type: row.get(3)?,
191 description: row.get(4)?,
192 risk_level: row.get(5)?,
193 example_template: row.get(6)?,
194 install_instructions: row.get(7)?,
195 })
196 },
197 )?;
198
199 let mut results = Vec::new();
200 for r in rows {
201 let record = r?;
202 if let Ok(contract) = AciCommandContract::try_from(record) {
203 results.push(contract);
204 }
205 }
206 let mut final_results = exact_results.clone();
207 final_results.append(&mut results);
208 Ok(final_results)
209 } else {
210 let mut stmt = conn.prepare(
212 "SELECT \
213 arg.app_id, \
214 app.name, \
215 arg.cmd_path, \
216 arg.node_type, \
217 arg.description, \
218 arg.risk_level, \
219 arg.example_template, \
220 app.install_instructions \
221 FROM arguments arg \
222 JOIN apps app ON arg.app_id = app.app_id \
223 JOIN apps_fts fts ON arg.cmd_path = fts.cmd_path \
224 WHERE apps_fts MATCH :query \
225 ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC \
226 LIMIT :limit_num",
227 )?;
228
229 let rows = stmt.query_map(
230 rusqlite::named_params! {
231 ":query": processed_query,
232 ":limit_num": limit,
233 },
234 |row| {
235 Ok(DbAciRecord {
236 app_id: row.get(0)?,
237 name: row.get(1)?,
238 cmd_path: row.get(2)?,
239 node_type: row.get(3)?,
240 description: row.get(4)?,
241 risk_level: row.get(5)?,
242 example_template: row.get(6)?,
243 install_instructions: row.get(7)?,
244 })
245 },
246 )?;
247
248 let mut results = Vec::new();
249 for r in rows {
250 let record = r?;
251 if let Ok(contract) = AciCommandContract::try_from(record) {
252 results.push(contract);
253 }
254 }
255 let mut final_results = exact_results.clone();
256 final_results.append(&mut results);
257 Ok(final_results)
258 }
259}
260
261pub fn search_all(
262 conn: &Connection,
263 query: &str,
264 query_vector: Option<&[f32]>,
265 limit: usize,
266) -> Result<Vec<AciCommandContract>> {
267 let mut results = search_commands(conn, query, query_vector, limit)?;
268
269 let config_dir = crate::config::get_config_dir();
270 let skills_dir = config_dir.join("skills");
271 let local_skill = cmdhub_skills::LocalFileSkill::new(skills_dir);
272
273 let mut registry = cmdhub_skills::SkillRegistry::new();
274 registry.register(Box::new(local_skill));
275
276 if let Ok(mut skill_results) = registry.resolve(query) {
277 results.append(&mut skill_results);
278 }
279
280 let mut seen = std::collections::HashSet::new();
281 results.retain(|item| seen.insert(item.cmd_path.clone()));
282 results.truncate(limit);
283
284 Ok(results)
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Seeds one root command. It inserts an `apps` row (empty install instructions), a root
    /// `arguments` row whose cmd_path, node name and example template are all `name`, and
    /// the matching `apps_fts` entry.
    fn seed_root_command(conn: &Connection, app_id: &str, name: &str, description: &str) {
        conn.execute(
            "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, ?)",
            (app_id, name, "{}"),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level, example_template) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, app_id, name, "root", description, "safe", name),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, ?)",
            (name, name, description),
        )
        .unwrap();
    }

    /// An exact command-name hit must be ranked ahead of prefix/FTS hits such as `gitleaks`.
    #[test]
    fn test_exact_match_priority() {
        let conn = Connection::open_in_memory().unwrap();
        init_db(&conn).unwrap();

        seed_root_command(&conn, "org.test.git", "git", "Git version control");
        seed_root_command(&conn, "org.test.gitleaks", "gitleaks", "Detect secrets in git");

        let res = search_commands(&conn, "git", None, 10).unwrap();
        let first = res.first().expect("search for \"git\" returned no results");
        assert_eq!(first.cmd_path, "git");
    }
}
338}