Skip to main content

cmdhub_cli/
db.rs

1use crate::config::get_data_dir;
2use anyhow::{Context, Result};
3use cmdhub_shared::{
4    AciCommandContract, DbAciRecord, CREATE_APPS_FTS_TABLE, CREATE_APPS_TABLE,
5    CREATE_ARGUMENTS_TABLE, CREATE_COMMANDS_VEC_TABLE,
6};
7use rusqlite::Connection;
8use std::path::PathBuf;
9
10pub fn resolve_db_path() -> PathBuf {
11    get_data_dir().join("cmdhub.db")
12}
13
14pub fn open_db() -> Result<Connection> {
15    let db_path = resolve_db_path();
16    if let Some(parent) = db_path.parent() {
17        std::fs::create_dir_all(parent).context("Failed to create database parent directories")?;
18    }
19
20    unsafe {
21        type SqliteVecInitFn = unsafe extern "C" fn();
22        let init_fn: SqliteVecInitFn = sqlite_vec::sqlite3_vec_init;
23        #[allow(clippy::missing_transmute_annotations)]
24        let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(init_fn)));
25    }
26
27    let conn = Connection::open(&db_path).context("Failed to open SQLite database file")?;
28    Ok(conn)
29}
30
31pub fn init_db(conn: &Connection) -> Result<()> {
32    conn.execute(CREATE_APPS_TABLE, [])
33        .context("Failed to create apps table")?;
34    conn.execute(CREATE_ARGUMENTS_TABLE, [])
35        .context("Failed to create arguments table")?;
36    conn.execute(CREATE_APPS_FTS_TABLE, [])
37        .context("Failed to create apps_fts table")?;
38
39    // Commands vector table may fail to create if sqlite-vec is not fully supported or active
40    if let Err(e) = conn.execute(CREATE_COMMANDS_VEC_TABLE, []) {
41        eprintln!("Warning: Failed to initialize sqlite-vec commands_vec table: {}. Falling back to FTS5 search.", e);
42    }
43    Ok(())
44}
45
/// Turns free-form user input into an FTS5 `MATCH` expression.
///
/// The input is split on every character that is not alphanumeric, `-` or
/// `_`; tokens are lowercased, stop words are removed, and the remaining terms
/// are emitted as quoted prefix queries joined with `OR`, e.g.
/// `"use"* OR "git-lfs"*`.
///
/// Each term is wrapped in double quotes because `-` is not a valid FTS5
/// bareword character: an unquoted `git-lfs*` is parsed as operators/column
/// syntax and makes the whole `MATCH` fail. Terms can never contain `"`
/// themselves, since it is a split character.
///
/// Returns `"*"` when no searchable term is left.
fn preprocess_query(query: &str) -> String {
    const STOP_WORDS: &[&str] = &[
        "how", "to", "a", "the", "on", "in", "of", "for", "with", "an", "is", "at", "by", "and",
        "or", "from", "my", "your", "our", "me", "us",
    ];

    let terms: Vec<String> = query
        .split(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
        // Tokens made only of `-`/`_` (and empty tokens) carry nothing to search for.
        .filter(|token| token.chars().any(char::is_alphanumeric))
        .map(str::to_lowercase)
        .filter(|token| !STOP_WORDS.contains(&token.as_str()))
        .map(|token| format!("\"{}\"*", token))
        .collect();

    if terms.is_empty() {
        // NOTE(review): a lone `*` does not look like a valid FTS5 expression, so
        // queries consisting only of stop words probably make the MATCH fail in
        // the caller — confirm and decide whether the caller should skip FTS.
        "*".to_string()
    } else {
        terms.join(" OR ")
    }
}
69
70pub fn search_commands(
71    conn: &Connection,
72    query: &str,
73    query_vector: Option<&[f32]>,
74    limit: usize,
75) -> Result<Vec<AciCommandContract>> {
76    let processed_query = preprocess_query(query);
77
78    // 1. Fast exact-match short-circuiting check (LOWER check for path/name)
79    let trimmed_query = query.trim().to_lowercase();
80    let mut exact_stmt = conn.prepare(
81        "SELECT \
82            arg.app_id, \
83            app.name, \
84            arg.cmd_path, \
85            arg.node_type, \
86            arg.description, \
87            arg.risk_level, \
88            arg.example_template, \
89            app.install_instructions \
90        FROM arguments arg \
91        JOIN apps app ON arg.app_id = app.app_id \
92        WHERE LOWER(arg.cmd_path) = :query OR LOWER(app.name) = :query \
93        LIMIT :limit_num",
94    )?;
95
96    let exact_rows = exact_stmt.query_map(
97        rusqlite::named_params! {
98            ":query": trimmed_query,
99            ":limit_num": limit,
100        },
101        |row| {
102            Ok(DbAciRecord {
103                app_id: row.get(0)?,
104                name: row.get(1)?,
105                cmd_path: row.get(2)?,
106                node_type: row.get(3)?,
107                description: row.get(4)?,
108                risk_level: row.get(5)?,
109                example_template: row.get(6)?,
110                install_instructions: row.get(7)?,
111            })
112        },
113    )?;
114
115    let mut exact_results = Vec::new();
116    for record in exact_rows.flatten() {
117        if let Ok(contract) = AciCommandContract::try_from(record) {
118            exact_results.push(contract);
119        }
120    }
121
122    // Check if commands_vec table exists and has any data
123    let mut has_vector_db = false;
124    if query_vector.is_some() {
125        if let Ok(count) = conn.query_row::<u64, _, _>(
126            "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='commands_vec'",
127            [],
128            |row| row.get(0),
129        ) {
130            if count > 0 {
131                if let Ok(vec_count) =
132                    conn.query_row::<u64, _, _>("SELECT count(*) FROM commands_vec", [], |row| {
133                        row.get(0)
134                    })
135                {
136                    if vec_count > 0 {
137                        has_vector_db = true;
138                    }
139                }
140            }
141        }
142    }
143
144    if has_vector_db {
145        let q_vec = query_vector.unwrap();
146        let mut vec_bytes = Vec::with_capacity(q_vec.len() * 4);
147        for &val in q_vec {
148            vec_bytes.extend_from_slice(&val.to_ne_bytes());
149        }
150        let mut stmt = conn.prepare(
151            "WITH fts_rank AS ( \
152                SELECT cmd_path, row_number() OVER (ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC) as fts_pos \
153                FROM apps_fts WHERE apps_fts MATCH :query \
154                LIMIT 100 \
155            ), \
156            vec_rank AS ( \
157                SELECT cmd_path, row_number() OVER (ORDER BY vec_distance_cosine(embedding, :query_vector) ASC) as vec_pos \
158                FROM commands_vec \
159                LIMIT 100 \
160            ) \
161            SELECT \
162                arg.app_id, \
163                app.name, \
164                arg.cmd_path, \
165                arg.node_type, \
166                arg.description, \
167                arg.risk_level, \
168                arg.example_template, \
169                app.install_instructions \
170            FROM arguments arg \
171            JOIN apps app ON arg.app_id = app.app_id \
172            LEFT JOIN fts_rank fts ON arg.cmd_path = fts.cmd_path \
173            LEFT JOIN vec_rank vec ON arg.cmd_path = vec.cmd_path \
174            WHERE fts.cmd_path IS NOT NULL OR vec.cmd_path IS NOT NULL \
175            ORDER BY COALESCE(1.0 / (60.0 + fts.fts_pos), 0.0) + COALESCE(1.0 / (60.0 + vec.vec_pos), 0.0) DESC \
176            LIMIT :limit_num"
177        )?;
178
179        let rows = stmt.query_map(
180            rusqlite::named_params! {
181                ":query": processed_query,
182                ":query_vector": vec_bytes,
183                ":limit_num": limit,
184            },
185            |row| {
186                Ok(DbAciRecord {
187                    app_id: row.get(0)?,
188                    name: row.get(1)?,
189                    cmd_path: row.get(2)?,
190                    node_type: row.get(3)?,
191                    description: row.get(4)?,
192                    risk_level: row.get(5)?,
193                    example_template: row.get(6)?,
194                    install_instructions: row.get(7)?,
195                })
196            },
197        )?;
198
199        let mut results = Vec::new();
200        for r in rows {
201            let record = r?;
202            if let Ok(contract) = AciCommandContract::try_from(record) {
203                results.push(contract);
204            }
205        }
206        let mut final_results = exact_results.clone();
207        final_results.append(&mut results);
208        Ok(final_results)
209    } else {
210        // Fallback to pure FTS5 MATCH BM25 search
211        let mut stmt = conn.prepare(
212            "SELECT \
213                arg.app_id, \
214                app.name, \
215                arg.cmd_path, \
216                arg.node_type, \
217                arg.description, \
218                arg.risk_level, \
219                arg.example_template, \
220                app.install_instructions \
221            FROM arguments arg \
222            JOIN apps app ON arg.app_id = app.app_id \
223            JOIN apps_fts fts ON arg.cmd_path = fts.cmd_path \
224            WHERE apps_fts MATCH :query \
225            ORDER BY bm25(apps_fts, 0.0, 10.0, 1.0) ASC \
226            LIMIT :limit_num",
227        )?;
228
229        let rows = stmt.query_map(
230            rusqlite::named_params! {
231                ":query": processed_query,
232                ":limit_num": limit,
233            },
234            |row| {
235                Ok(DbAciRecord {
236                    app_id: row.get(0)?,
237                    name: row.get(1)?,
238                    cmd_path: row.get(2)?,
239                    node_type: row.get(3)?,
240                    description: row.get(4)?,
241                    risk_level: row.get(5)?,
242                    example_template: row.get(6)?,
243                    install_instructions: row.get(7)?,
244                })
245            },
246        )?;
247
248        let mut results = Vec::new();
249        for r in rows {
250            let record = r?;
251            if let Ok(contract) = AciCommandContract::try_from(record) {
252                results.push(contract);
253            }
254        }
255        let mut final_results = exact_results.clone();
256        final_results.append(&mut results);
257        Ok(final_results)
258    }
259}
260
261pub fn search_all(
262    conn: &Connection,
263    query: &str,
264    query_vector: Option<&[f32]>,
265    limit: usize,
266) -> Result<Vec<AciCommandContract>> {
267    let mut results = search_commands(conn, query, query_vector, limit)?;
268
269    let config_dir = crate::config::get_config_dir();
270    let skills_dir = config_dir.join("skills");
271    let local_skill = cmdhub_skills::LocalFileSkill::new(skills_dir);
272
273    let mut registry = cmdhub_skills::SkillRegistry::new();
274    registry.register(Box::new(local_skill));
275
276    if let Ok(mut skill_results) = registry.resolve(query) {
277        results.append(&mut skill_results);
278    }
279
280    let mut seen = std::collections::HashSet::new();
281    results.retain(|item| seen.insert(item.cmd_path.clone()));
282    results.truncate(limit);
283
284    Ok(results)
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts one root-level command into `apps`, `arguments` and `apps_fts`,
    /// using `name` as command path, node name and example template.
    fn seed_root_command(conn: &Connection, app_id: &str, name: &str, description: &str) {
        conn.execute(
            "INSERT INTO apps (app_id, name, install_instructions) VALUES (?, ?, ?)",
            (app_id, name, "{}"),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO arguments (cmd_path, app_id, node_name, node_type, description, risk_level, example_template) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
            (name, app_id, name, "root", description, "safe", name),
        )
        .unwrap();

        conn.execute(
            "INSERT INTO apps_fts (cmd_path, name, capabilities) VALUES (?, ?, ?)",
            (name, name, description),
        )
        .unwrap();
    }

    /// An exact name match must be listed before prefix matches on the same term.
    #[test]
    fn test_exact_match_priority() {
        let conn = Connection::open_in_memory().unwrap();
        init_db(&conn).unwrap();

        seed_root_command(&conn, "org.test.git", "git", "Git version control");
        seed_root_command(&conn, "org.test.gitleaks", "gitleaks", "Detect secrets in git");

        let hits = search_commands(&conn, "git", None, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].cmd_path, "git");
    }
}
338}