Skip to main content

cmdhub_cli/
db.rs

1use crate::config::get_data_dir;
2use anyhow::{Context, Result};
3use cmdhub_shared::{
4    AciCommandContract, DbAciRecord, CREATE_APPS_FTS_TABLE, CREATE_APPS_TABLE,
5    CREATE_ARGUMENTS_TABLE, CREATE_COMMANDS_VEC_TABLE,
6};
7use rusqlite::Connection;
8use std::path::PathBuf;
9
10pub fn resolve_db_path() -> PathBuf {
11    get_data_dir().join("cmdhub.db")
12}
13
14pub fn open_db() -> Result<Connection> {
15    let db_path = resolve_db_path();
16    if let Some(parent) = db_path.parent() {
17        std::fs::create_dir_all(parent).context("Failed to create database parent directories")?;
18    }
19
20    unsafe {
21        type SqliteVecInitFn = unsafe extern "C" fn();
22        let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
23        #[allow(clippy::missing_transmute_annotations)]
24        let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
25    }
26
27    let conn = Connection::open(&db_path).context("Failed to open SQLite database file")?;
28    let _ = conn.execute("PRAGMA journal_mode = WAL;", []);
29    let _ = conn.execute("PRAGMA synchronous = NORMAL;", []);
30    Ok(conn)
31}
32
33pub fn init_db(conn: &Connection) -> Result<()> {
34    conn.execute(CREATE_APPS_TABLE, [])
35        .context("Failed to create apps table")?;
36    conn.execute(CREATE_ARGUMENTS_TABLE, [])
37        .context("Failed to create arguments table")?;
38    conn.execute(CREATE_APPS_FTS_TABLE, [])
39        .context("Failed to create apps_fts table")?;
40
41    // Commands vector table may fail to create if sqlite-vec is not fully supported or active
42    if let Err(e) = conn.execute(CREATE_COMMANDS_VEC_TABLE, []) {
43        eprintln!("Warning: Failed to initialize sqlite-vec commands_vec table: {}. Falling back to FTS5 search.", e);
44    }
45    Ok(())
46}
47
/// Turns free-form user text into an FTS5 MATCH expression.
///
/// The input is split on every character that is not alphanumeric, `-` or
/// `_`. Each token is lowercased, stop words are dropped, and every surviving
/// term becomes a quoted prefix query (`"term"*`). Quoting is required
/// because `-` is not a valid FTS5 bareword character: an unquoted
/// `docker-compose*` would make the whole MATCH expression a syntax error.
/// Tokens with no alphanumeric character (for example a lone `-`) carry
/// nothing searchable and are discarded.
///
/// Terms are joined with a space (implicit AND) when `use_and` is true and
/// with ` OR ` otherwise. If no term survives, the sentinel `"*"` is
/// returned so callers can detect the "nothing to search for" case.
fn preprocess_query(query: &str, use_and: bool) -> String {
    const STOP_WORDS: &[&str] = &[
        "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
        "or", "from", "my", "your", "our", "me", "us",
    ];

    let terms: Vec<String> = query
        .split(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
        .filter(|token| token.chars().any(char::is_alphanumeric))
        .map(str::to_lowercase)
        .filter(|token| !STOP_WORDS.contains(&token.as_str()))
        // A token can never contain `"` (it is a split character), so no
        // escaping is needed inside the quotes.
        .map(|token| format!("\"{}\"*", token))
        .collect();

    if terms.is_empty() {
        "*".to_string()
    } else if use_and {
        terms.join(" ")
    } else {
        terms.join(" OR ")
    }
}
73
74pub fn search_commands(
75    conn: &Connection,
76    query: &str,
77    query_vector: Option<&[f32]>,
78    limit: usize,
79) -> Result<Vec<AciCommandContract>> {
80    let and_query = preprocess_query(query, true);
81    let or_query = preprocess_query(query, false);
82
83    let mut and_match = false;
84    if and_query != "*" {
85        if let Ok(count) = conn.query_row::<u64, _, _>(
86            "SELECT count(*) FROM apps_fts WHERE apps_fts MATCH :query",
87            rusqlite::named_params! { ":query": &and_query },
88            |row| row.get(0),
89        ) {
90            if count > 0 {
91                and_match = true;
92            }
93        }
94    }
95
96    let processed_query = if and_match { and_query } else { or_query };
97
98    // 1. Fast exact-match short-circuiting check (LOWER check for path/name)
99    let trimmed_query = query.trim().to_lowercase();
100    let mut exact_stmt = conn.prepare(
101        "SELECT \
102            arg.app_id, \
103            app.name, \
104            arg.cmd_path, \
105            arg.node_type, \
106            arg.description, \
107            arg.risk_level, \
108            arg.example_template, \
109            app.install_instructions, \
110            arg.docker_image, \
111            arg.script_url, \
112            arg.source_url \
113        FROM arguments arg \
114        JOIN apps app ON arg.app_id = app.app_id \
115        WHERE LOWER(arg.cmd_path) = :query OR LOWER(app.name) = :query \
116        LIMIT :limit_num",
117    )?;
118
119    let exact_rows = exact_stmt.query_map(
120        rusqlite::named_params! {
121            ":query": trimmed_query,
122            ":limit_num": limit,
123        },
124        |row| {
125            Ok(DbAciRecord {
126                app_id: row.get(0)?,
127                name: row.get(1)?,
128                cmd_path: row.get(2)?,
129                node_type: row.get(3)?,
130                description: row.get(4)?,
131                risk_level: row.get(5)?,
132                example_template: row.get(6)?,
133                install_instructions: row.get(7)?,
134                docker_image: row.get(8)?,
135                script_url: row.get(9)?,
136                source_url: row.get(10)?,
137            })
138        },
139    )?;
140
141    let mut exact_results = Vec::new();
142    for record in exact_rows.flatten() {
143        if let Ok(contract) = AciCommandContract::try_from(record) {
144            exact_results.push(contract);
145        }
146    }
147
148    // Check if commands_vec table exists and has any data
149    let mut has_vector_db = false;
150    if query_vector.is_some() {
151        if let Ok(count) = conn.query_row::<u64, _, _>(
152            "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='commands_vec'",
153            [],
154            |row| row.get(0),
155        ) {
156            if count > 0 {
157                if let Ok(vec_count) =
158                    conn.query_row::<u64, _, _>("SELECT count(*) FROM commands_vec", [], |row| {
159                        row.get(0)
160                    })
161                {
162                    if vec_count > 0 {
163                        has_vector_db = true;
164                    }
165                }
166            }
167        }
168    }
169
170    if has_vector_db {
171        let q_vec = query_vector.unwrap();
172        let mut vec_bytes = Vec::with_capacity(q_vec.len() * 4);
173        for &val in q_vec {
174            vec_bytes.extend_from_slice(&val.to_ne_bytes());
175        }
176        let mut stmt = conn.prepare(
177            "WITH fts_rank AS ( \
178                SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC) as fts_pos \
179                FROM apps_fts WHERE apps_fts MATCH :query \
180                LIMIT 100 \
181            ), \
182            vec_rank AS ( \
183                SELECT cmd_path, row_number() OVER (ORDER BY distance ASC) as vec_pos \
184                FROM commands_vec \
185                WHERE embedding MATCH :query_vector AND k = 100 \
186            ) \
187            SELECT \
188                arg.app_id, \
189                app.name, \
190                arg.cmd_path, \
191                arg.node_type, \
192                arg.description, \
193                arg.risk_level, \
194                arg.example_template, \
195                app.install_instructions, \
196                arg.docker_image, \
197                arg.script_url, \
198                arg.source_url \
199            FROM arguments arg \
200            JOIN apps app ON arg.app_id = app.app_id \
201            LEFT JOIN fts_rank fts ON arg.cmd_path = fts.cmd_path \
202            LEFT JOIN vec_rank vec ON arg.cmd_path = vec.cmd_path \
203            WHERE fts.cmd_path IS NOT NULL OR vec.cmd_path IS NOT NULL \
204            ORDER BY COALESCE(1.0 / (60.0 + fts.fts_pos), 0.0) + COALESCE(1.0 / (60.0 + vec.vec_pos), 0.0) DESC \
205            LIMIT :limit_num"
206        )?;
207
208        let rows = stmt.query_map(
209            rusqlite::named_params! {
210                ":query": processed_query,
211                ":query_vector": vec_bytes,
212                ":limit_num": limit,
213            },
214            |row| {
215                Ok(DbAciRecord {
216                    app_id: row.get(0)?,
217                    name: row.get(1)?,
218                    cmd_path: row.get(2)?,
219                    node_type: row.get(3)?,
220                    description: row.get(4)?,
221                    risk_level: row.get(5)?,
222                    example_template: row.get(6)?,
223                    install_instructions: row.get(7)?,
224                    docker_image: row.get(8)?,
225                    script_url: row.get(9)?,
226                    source_url: row.get(10)?,
227                })
228            },
229        )?;
230
231        let mut results = Vec::new();
232        for r in rows {
233            let record = r?;
234            if let Ok(contract) = AciCommandContract::try_from(record) {
235                results.push(contract);
236            }
237        }
238        let mut final_results = exact_results.clone();
239        final_results.append(&mut results);
240        Ok(final_results)
241    } else {
242        // Fallback to pure FTS5 MATCH BM25 search
243        let mut stmt = conn.prepare(
244            "SELECT \
245                arg.app_id, \
246                app.name, \
247                arg.cmd_path, \
248                arg.node_type, \
249                arg.description, \
250                arg.risk_level, \
251                arg.example_template, \
252                app.install_instructions, \
253                arg.docker_image, \
254                arg.script_url, \
255                arg.source_url \
256            FROM arguments arg \
257            JOIN apps app ON arg.app_id = app.app_id \
258            JOIN apps_fts fts ON arg.cmd_path = fts.cmd_path \
259            WHERE apps_fts MATCH :query \
260            ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC \
261            LIMIT :limit_num",
262        )?;
263
264        let rows = stmt.query_map(
265            rusqlite::named_params! {
266                ":query": processed_query,
267                ":limit_num": limit,
268            },
269            |row| {
270                Ok(DbAciRecord {
271                    app_id: row.get(0)?,
272                    name: row.get(1)?,
273                    cmd_path: row.get(2)?,
274                    node_type: row.get(3)?,
275                    description: row.get(4)?,
276                    risk_level: row.get(5)?,
277                    example_template: row.get(6)?,
278                    install_instructions: row.get(7)?,
279                    docker_image: row.get(8)?,
280                    script_url: row.get(9)?,
281                    source_url: row.get(10)?,
282                })
283            },
284        )?;
285
286        let mut results = Vec::new();
287        for r in rows {
288            let record = r?;
289            if let Ok(contract) = AciCommandContract::try_from(record) {
290                results.push(contract);
291            }
292        }
293        let mut final_results = exact_results.clone();
294        final_results.append(&mut results);
295        Ok(final_results)
296    }
297}
298
299pub fn search_all(
300    conn: &Connection,
301    query: &str,
302    query_vector: Option<&[f32]>,
303    limit: usize,
304) -> Result<Vec<AciCommandContract>> {
305    let mut results = search_commands(conn, query, query_vector, limit)?;
306
307    let config_dir = crate::config::get_config_dir();
308    let skills_dir = config_dir.join("skills");
309    let local_skill = cmdhub_skills::LocalFileSkill::new(skills_dir);
310
311    let mut registry = cmdhub_skills::SkillRegistry::new();
312    registry.register(Box::new(local_skill));
313
314    if let Ok(mut skill_results) = registry.resolve(query) {
315        results.append(&mut skill_results);
316    }
317
318    let mut seen = std::collections::HashSet::new();
319    results.retain(|item| seen.insert(item.cmd_path.clone()));
320    results.truncate(limit);
321
322    Ok(results)
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database and creates the schema on it.
    fn fresh_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        init_db(&conn).unwrap();
        conn
    }

    /// Inserts one root command whose `cmd_path`, node name and example
    /// template all equal `name`, together with its app row and FTS row.
    fn seed_root_command(conn: &Connection, app_id: &str, name: &str, description: &str) {
        conn.execute(
            "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, ?)",
            (app_id, name, "{}"),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level, example_template) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, app_id, name, "root", description, "safe", name),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, ?)",
            (name, name, description),
        )
        .unwrap();
    }

    #[test]
    fn test_exact_match_priority() {
        let conn = fresh_db();
        seed_root_command(&conn, "org.test.git", "git", "Git version control");
        seed_root_command(&conn, "org.test.gitleaks", "gitleaks", "Detect secrets in git");

        // Both commands match the prefix query; the exact name must come first.
        let hits = search_commands(&conn, "git", None, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "git");
    }

    #[test]
    fn test_fts_fallback_and_or() {
        let conn = fresh_db();
        seed_root_command(&conn, "org.test.rm", "rm", "delete local files");

        // 1. Every term is present, so the AND form matches.
        let hits = search_commands(&conn, "delete local files", None, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "rm");

        // 2. The stop word "my" is removed before matching, so AND still matches.
        let hits = search_commands(&conn, "delete my local files", None, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "rm");

        // 3. "missing" appears nowhere, so the AND form finds nothing and the
        // search must fall back to OR, which still matches on "delete"/"files".
        let hits = search_commands(&conn, "delete missing files", None, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "rm");
    }

    #[test]
    fn test_hybrid_search_knn_match() {
        // Register sqlite-vec before the connection is opened so the
        // commands_vec table can be created on it.
        unsafe {
            type SqliteVecInitFn = unsafe extern "C" fn();
            let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
            #[allow(clippy::missing_transmute_annotations)]
            let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
        }

        let conn = fresh_db();
        seed_root_command(&conn, "org.test.knn", "knn", "vector search helper");

        // Store a 512-dimensional embedding as raw native-endian f32 bytes.
        let stored = vec![0.1f32; 512];
        let mut stored_bytes = Vec::with_capacity(512 * 4);
        for component in &stored {
            stored_bytes.extend_from_slice(&component.to_ne_bytes());
        }

        conn.execute(
            "INSERT INTO commands_vec (cmd_path, embedding) VALUES (?, ?)",
            ("knn", stored_bytes),
        )
        .unwrap();

        // The text matches nothing, so the hit has to come from the vector side.
        let probe = vec![0.12f32; 512];
        let hits = search_commands(&conn, "missing_term", Some(&probe), 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "knn");
    }
}