1use crate::config::get_data_dir;
2use anyhow::{Context, Result};
3use cmdhub_shared::{
4 AciCommandContract, DbAciRecord, CREATE_APPS_FTS_TABLE, CREATE_APPS_TABLE,
5 CREATE_ARGUMENTS_TABLE, CREATE_COMMANDS_VEC_TABLE,
6};
7use rusqlite::Connection;
8use std::path::PathBuf;
9
10pub fn resolve_db_path() -> PathBuf {
11 get_data_dir().join("cmdhub.db")
12}
13
14pub fn open_db() -> Result<Connection> {
15 let db_path = resolve_db_path();
16 if let Some(parent) = db_path.parent() {
17 std::fs::create_dir_all(parent).context("Failed to create database parent directories")?;
18 }
19
20 unsafe {
21 type SqliteVecInitFn = unsafe extern "C" fn();
22 let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
23 #[allow(clippy::missing_transmute_annotations)]
24 let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
25 }
26
27 let conn = Connection::open(&db_path).context("Failed to open SQLite database file")?;
28 let _ = conn.execute("PRAGMA journal_mode = WAL;", []);
29 let _ = conn.execute("PRAGMA synchronous = NORMAL;", []);
30 Ok(conn)
31}
32
33pub fn init_db(conn: &Connection) -> Result<()> {
34 conn.execute(CREATE_APPS_TABLE, [])
35 .context("Failed to create apps table")?;
36 conn.execute(CREATE_ARGUMENTS_TABLE, [])
37 .context("Failed to create arguments table")?;
38 conn.execute(CREATE_APPS_FTS_TABLE, [])
39 .context("Failed to create apps_fts table")?;
40
41 if let Err(e) = conn.execute(CREATE_COMMANDS_VEC_TABLE, []) {
43 eprintln!("Warning: Failed to initialize sqlite-vec commands_vec table: {}. Falling back to FTS5 search.", e);
44 }
45 Ok(())
46}
47
/// Converts a free-text search phrase into an SQLite FTS5 `MATCH` expression.
///
/// The phrase is split on every character that is not alphanumeric, `-` or `_`.
/// Tokens without any alphanumeric character are dropped. The rest are lowercased,
/// common English stop words are removed, and each remaining token becomes a quoted
/// prefix term such as `"docker-compose"*`.
///
/// Quoting is required because `-` is not a legal FTS5 bareword character: an unquoted
/// `docker-compose*` is an FTS5 syntax error. Inside a quoted string, the tokenizer
/// handles the punctuation. Tokens can never contain `"`, because the split removes it,
/// so no escaping is needed.
///
/// With `use_and` the terms are separated by spaces (FTS5's implicit AND); otherwise
/// they are joined with `OR`. When no terms remain, the sentinel `"*"` is returned.
/// It is not a valid FTS5 query, so callers must treat it as "no lexical terms".
fn preprocess_query(query: &str, use_and: bool) -> String {
    const STOP_WORDS: &[&str] = &[
        "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
        "or", "from", "my", "your", "our", "me", "us",
    ];

    let terms: Vec<String> = query
        .split(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
        // Also discards empty pieces and punctuation-only tokens such as `--`,
        // which would otherwise become empty FTS5 phrases.
        .filter(|w| w.chars().any(char::is_alphanumeric))
        .map(str::to_lowercase)
        .filter(|w| !STOP_WORDS.contains(&w.as_str()))
        .map(|w| format!("\"{}\"*", w))
        .collect();

    if terms.is_empty() {
        "*".to_string()
    } else if use_and {
        terms.join(" ")
    } else {
        terms.join(" OR ")
    }
}
73
74pub fn search_commands(
75 conn: &Connection,
76 query: &str,
77 query_vector: Option<&[f32]>,
78 limit: usize,
79) -> Result<Vec<AciCommandContract>> {
80 let and_query = preprocess_query(query, true);
81 let or_query = preprocess_query(query, false);
82
83 let mut and_match = false;
84 if and_query != "*" {
85 if let Ok(count) = conn.query_row::<u64, _, _>(
86 "SELECT count(*) FROM apps_fts WHERE apps_fts MATCH :query",
87 rusqlite::named_params! { ":query": &and_query },
88 |row| row.get(0),
89 ) {
90 if count > 0 {
91 and_match = true;
92 }
93 }
94 }
95
96 let processed_query = if and_match { and_query } else { or_query };
97
98 let trimmed_query = query.trim().to_lowercase();
100 let mut exact_stmt = conn.prepare(
101 "SELECT \
102 arg.app_id, \
103 app.name, \
104 arg.cmd_path, \
105 arg.node_type, \
106 arg.description, \
107 arg.risk_level, \
108 arg.example_template, \
109 app.install_instructions, \
110 arg.docker_image, \
111 arg.script_url, \
112 arg.source_url \
113 FROM arguments arg \
114 JOIN apps app ON arg.app_id = app.app_id \
115 WHERE LOWER(arg.cmd_path) = :query OR LOWER(app.name) = :query \
116 LIMIT :limit_num",
117 )?;
118
119 let exact_rows = exact_stmt.query_map(
120 rusqlite::named_params! {
121 ":query": trimmed_query,
122 ":limit_num": limit,
123 },
124 |row| {
125 Ok(DbAciRecord {
126 app_id: row.get(0)?,
127 name: row.get(1)?,
128 cmd_path: row.get(2)?,
129 node_type: row.get(3)?,
130 description: row.get(4)?,
131 risk_level: row.get(5)?,
132 example_template: row.get(6)?,
133 install_instructions: row.get(7)?,
134 docker_image: row.get(8)?,
135 script_url: row.get(9)?,
136 source_url: row.get(10)?,
137 })
138 },
139 )?;
140
141 let mut exact_results = Vec::new();
142 for record in exact_rows.flatten() {
143 if let Ok(contract) = AciCommandContract::try_from(record) {
144 exact_results.push(contract);
145 }
146 }
147
148 let mut has_vector_db = false;
150 if query_vector.is_some() {
151 if let Ok(count) = conn.query_row::<u64, _, _>(
152 "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='commands_vec'",
153 [],
154 |row| row.get(0),
155 ) {
156 if count > 0 {
157 if let Ok(vec_count) =
158 conn.query_row::<u64, _, _>("SELECT count(*) FROM commands_vec", [], |row| {
159 row.get(0)
160 })
161 {
162 if vec_count > 0 {
163 has_vector_db = true;
164 }
165 }
166 }
167 }
168 }
169
170 if has_vector_db {
171 let q_vec = query_vector.unwrap();
172 let mut vec_bytes = Vec::with_capacity(q_vec.len() * 4);
173 for &val in q_vec {
174 vec_bytes.extend_from_slice(&val.to_ne_bytes());
175 }
176 let mut stmt = conn.prepare(
177 "WITH fts_rank AS ( \
178 SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC) as fts_pos \
179 FROM apps_fts WHERE apps_fts MATCH :query \
180 LIMIT 100 \
181 ), \
182 vec_rank AS ( \
183 SELECT cmd_path, row_number() OVER (ORDER BY distance ASC) as vec_pos \
184 FROM commands_vec \
185 WHERE embedding MATCH :query_vector AND k = 100 \
186 ) \
187 SELECT \
188 arg.app_id, \
189 app.name, \
190 arg.cmd_path, \
191 arg.node_type, \
192 arg.description, \
193 arg.risk_level, \
194 arg.example_template, \
195 app.install_instructions, \
196 arg.docker_image, \
197 arg.script_url, \
198 arg.source_url \
199 FROM arguments arg \
200 JOIN apps app ON arg.app_id = app.app_id \
201 LEFT JOIN fts_rank fts ON arg.cmd_path = fts.cmd_path \
202 LEFT JOIN vec_rank vec ON arg.cmd_path = vec.cmd_path \
203 WHERE fts.cmd_path IS NOT NULL OR vec.cmd_path IS NOT NULL \
204 ORDER BY COALESCE(1.0 / (60.0 + fts.fts_pos), 0.0) + COALESCE(1.0 / (60.0 + vec.vec_pos), 0.0) DESC \
205 LIMIT :limit_num"
206 )?;
207
208 let rows = stmt.query_map(
209 rusqlite::named_params! {
210 ":query": processed_query,
211 ":query_vector": vec_bytes,
212 ":limit_num": limit,
213 },
214 |row| {
215 Ok(DbAciRecord {
216 app_id: row.get(0)?,
217 name: row.get(1)?,
218 cmd_path: row.get(2)?,
219 node_type: row.get(3)?,
220 description: row.get(4)?,
221 risk_level: row.get(5)?,
222 example_template: row.get(6)?,
223 install_instructions: row.get(7)?,
224 docker_image: row.get(8)?,
225 script_url: row.get(9)?,
226 source_url: row.get(10)?,
227 })
228 },
229 )?;
230
231 let mut results = Vec::new();
232 for r in rows {
233 let record = r?;
234 if let Ok(contract) = AciCommandContract::try_from(record) {
235 results.push(contract);
236 }
237 }
238 let mut final_results = exact_results.clone();
239 final_results.append(&mut results);
240 Ok(final_results)
241 } else {
242 let mut stmt = conn.prepare(
244 "SELECT \
245 arg.app_id, \
246 app.name, \
247 arg.cmd_path, \
248 arg.node_type, \
249 arg.description, \
250 arg.risk_level, \
251 arg.example_template, \
252 app.install_instructions, \
253 arg.docker_image, \
254 arg.script_url, \
255 arg.source_url \
256 FROM arguments arg \
257 JOIN apps app ON arg.app_id = app.app_id \
258 JOIN apps_fts fts ON arg.cmd_path = fts.cmd_path \
259 WHERE apps_fts MATCH :query \
260 ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC \
261 LIMIT :limit_num",
262 )?;
263
264 let rows = stmt.query_map(
265 rusqlite::named_params! {
266 ":query": processed_query,
267 ":limit_num": limit,
268 },
269 |row| {
270 Ok(DbAciRecord {
271 app_id: row.get(0)?,
272 name: row.get(1)?,
273 cmd_path: row.get(2)?,
274 node_type: row.get(3)?,
275 description: row.get(4)?,
276 risk_level: row.get(5)?,
277 example_template: row.get(6)?,
278 install_instructions: row.get(7)?,
279 docker_image: row.get(8)?,
280 script_url: row.get(9)?,
281 source_url: row.get(10)?,
282 })
283 },
284 )?;
285
286 let mut results = Vec::new();
287 for r in rows {
288 let record = r?;
289 if let Ok(contract) = AciCommandContract::try_from(record) {
290 results.push(contract);
291 }
292 }
293 let mut final_results = exact_results.clone();
294 final_results.append(&mut results);
295 Ok(final_results)
296 }
297}
298
299pub fn search_all(
300 conn: &Connection,
301 query: &str,
302 query_vector: Option<&[f32]>,
303 limit: usize,
304) -> Result<Vec<AciCommandContract>> {
305 let mut results = search_commands(conn, query, query_vector, limit)?;
306
307 let config_dir = crate::config::get_config_dir();
308 let skills_dir = config_dir.join("skills");
309 let local_skill = cmdhub_skills::LocalFileSkill::new(skills_dir);
310
311 let mut registry = cmdhub_skills::SkillRegistry::new();
312 registry.register(Box::new(local_skill));
313
314 if let Ok(mut skill_results) = registry.resolve(query) {
315 results.append(&mut skill_results);
316 }
317
318 let mut seen = std::collections::HashSet::new();
319 results.retain(|item| seen.insert(item.cmd_path.clone()));
320 results.truncate(limit);
321
322 Ok(results)
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database with the CommandHub schema applied.
    fn memory_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        init_db(&conn).unwrap();
        conn
    }

    /// Inserts one root-level command into `apps`, `arguments` and `apps_fts`.
    /// `name` is used as the app name, command path, node name and example
    /// template; `description` is used as both description and FTS capabilities.
    fn seed_root_command(conn: &Connection, app_id: &str, name: &str, description: &str) {
        conn.execute(
            "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, ?)",
            (app_id, name, "{}"),
        )
        .unwrap();
        conn.execute(
            "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level, example_template) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, app_id, name, "root", description, "safe", name),
        )
        .unwrap();
        conn.execute(
            "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, ?)",
            (name, name, description),
        )
        .unwrap();
    }

    /// Registers sqlite-vec as an auto-extension for connections opened afterwards.
    fn enable_sqlite_vec() {
        type SqliteVecInitFn = unsafe extern "C" fn();
        let entry_point: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
        // SAFETY: `sqlite3_vec_init` is the sqlite-vec extension entry point, declared
        // by the crate as `fn()`. It is transmuted back to the entry-point signature
        // expected by `sqlite3_auto_extension`, mirroring the registration in the
        // non-test code.
        unsafe {
            // Target type is inferred from rusqlite's FFI signature.
            #[allow(clippy::missing_transmute_annotations)]
            let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(entry_point)));
        }
    }

    /// Packs `values` into a BLOB of native-endian `f32` bytes.
    fn f32_blob(values: &[f32]) -> Vec<u8> {
        let mut blob = Vec::with_capacity(values.len() * 4);
        for value in values {
            blob.extend_from_slice(&value.to_ne_bytes());
        }
        blob
    }

    #[test]
    fn test_exact_match_priority() {
        let conn = memory_db();
        seed_root_command(&conn, "org.test.git", "git", "Git version control");
        seed_root_command(&conn, "org.test.gitleaks", "gitleaks", "Detect secrets in git");

        let hits = search_commands(&conn, "git", None, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "git");
    }

    #[test]
    fn test_fts_fallback_and_or() {
        let conn = memory_db();
        seed_root_command(&conn, "org.test.rm", "rm", "delete local files");

        // A full AND match, a phrase containing a stop word, and a phrase with an
        // unknown term that forces the OR fallback.
        let phrases = ["delete local files", "delete my local files", "delete missing files"];
        for phrase in phrases.iter() {
            let hits = search_commands(&conn, phrase, None, 10).unwrap();
            assert!(!hits.is_empty());
            assert_eq!(hits[0].cmd_path, "rm");
        }
    }

    #[test]
    fn test_hybrid_search_knn_match() {
        enable_sqlite_vec();

        let conn = memory_db();
        seed_root_command(&conn, "org.test.knn", "knn", "vector search helper");

        let stored = vec![0.1f32; 512];
        conn.execute(
            "INSERT INTO commands_vec (cmd_path, embedding) VALUES (?, ?)",
            ("knn", f32_blob(&stored)),
        )
        .unwrap();

        // "missing_term" has no lexical match, so the hit must come from KNN ranking.
        let query_vec = vec![0.12f32; 512];
        let hits = search_commands(&conn, "missing_term", Some(&query_vec), 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "knn");
    }
}