Skip to main content

zsh/
completion.rs

//! SQLite-backed completion engine for zshrs
//!
//! Features:
//! - FTS5 full-text index over names and descriptions for fast token-prefix matching
//! - Frequency tracking from command history
//! - Persistent index survives shell restarts
//! - Sub-millisecond queries on 40k+ completions
8
9use rusqlite::{params, Connection};
10use std::path::PathBuf;
11
12pub struct CompletionEngine {
13    conn: Connection,
14}
15
16#[derive(Debug, Clone)]
17pub struct Completion {
18    pub name: String,
19    pub kind: CompletionKind,
20    pub description: Option<String>,
21    pub frequency: u32,
22}
23
/// Category of a completion candidate.
///
/// Persisted in the `kind` column as the lowercase label from [`CompletionKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompletionKind {
    Command,
    Builtin,
    Function,
    Alias,
    File,
    Directory,
    Variable,
    Option,
}

impl CompletionKind {
    /// Every variant, so a stored label can be mapped back by reverse lookup.
    const ALL: [Self; 8] = [
        Self::Command,
        Self::Builtin,
        Self::Function,
        Self::Alias,
        Self::File,
        Self::Directory,
        Self::Variable,
        Self::Option,
    ];

    /// Returns the lowercase label stored in the database for this kind.
    pub fn as_str(&self) -> &'static str {
        match *self {
            CompletionKind::Command => "command",
            CompletionKind::Builtin => "builtin",
            CompletionKind::Function => "function",
            CompletionKind::Alias => "alias",
            CompletionKind::File => "file",
            CompletionKind::Directory => "directory",
            CompletionKind::Variable => "variable",
            CompletionKind::Option => "option",
        }
    }

    /// Inverse of [`CompletionKind::as_str`].
    ///
    /// Matching is exact (case-sensitive). Any unrecognised label falls back to
    /// [`CompletionKind::Command`] instead of failing.
    fn from_str(s: &str) -> Self {
        Self::ALL
            .iter()
            .copied()
            .find(|kind| kind.as_str() == s)
            .unwrap_or(Self::Command)
    }
}
64
65impl CompletionEngine {
66    pub fn new() -> rusqlite::Result<Self> {
67        let db_path = Self::db_path();
68        std::fs::create_dir_all(db_path.parent().unwrap()).ok();
69        let conn = Connection::open(&db_path)?;
70
71        let engine = Self { conn };
72        engine.init_schema()?;
73        Ok(engine)
74    }
75
76    pub fn in_memory() -> rusqlite::Result<Self> {
77        let conn = Connection::open_in_memory()?;
78        let engine = Self { conn };
79        engine.init_schema()?;
80        Ok(engine)
81    }
82
83    fn db_path() -> PathBuf {
84        dirs::cache_dir()
85            .unwrap_or_else(|| PathBuf::from("."))
86            .join("zshrs")
87            .join("completions.db")
88    }
89
90    fn init_schema(&self) -> rusqlite::Result<()> {
91        self.conn.execute_batch(
92            r#"
93            CREATE TABLE IF NOT EXISTS completions (
94                id INTEGER PRIMARY KEY,
95                name TEXT NOT NULL,
96                kind TEXT NOT NULL,
97                description TEXT,
98                frequency INTEGER DEFAULT 0,
99                UNIQUE(name, kind)
100            );
101
102            CREATE VIRTUAL TABLE IF NOT EXISTS completions_fts USING fts5(
103                name,
104                description,
105                content='completions',
106                content_rowid='id'
107            );
108
109            CREATE TRIGGER IF NOT EXISTS completions_ai AFTER INSERT ON completions BEGIN
110                INSERT INTO completions_fts(rowid, name, description)
111                VALUES (new.id, new.name, new.description);
112            END;
113
114            CREATE TRIGGER IF NOT EXISTS completions_ad AFTER DELETE ON completions BEGIN
115                INSERT INTO completions_fts(completions_fts, rowid, name, description)
116                VALUES ('delete', old.id, old.name, old.description);
117            END;
118
119            CREATE TRIGGER IF NOT EXISTS completions_au AFTER UPDATE ON completions BEGIN
120                INSERT INTO completions_fts(completions_fts, rowid, name, description)
121                VALUES ('delete', old.id, old.name, old.description);
122                INSERT INTO completions_fts(rowid, name, description)
123                VALUES (new.id, new.name, new.description);
124            END;
125
126            CREATE INDEX IF NOT EXISTS idx_completions_name ON completions(name);
127            CREATE INDEX IF NOT EXISTS idx_completions_kind ON completions(kind);
128            CREATE INDEX IF NOT EXISTS idx_completions_frequency ON completions(frequency DESC);
129        "#,
130        )?;
131        Ok(())
132    }
133
134    pub fn add_completion(
135        &self,
136        name: &str,
137        kind: CompletionKind,
138        description: Option<&str>,
139    ) -> rusqlite::Result<()> {
140        self.conn.execute(
141            "INSERT OR IGNORE INTO completions (name, kind, description) VALUES (?1, ?2, ?3)",
142            params![name, kind.as_str(), description],
143        )?;
144        Ok(())
145    }
146
147    pub fn add_completions(
148        &self,
149        completions: &[(String, CompletionKind, Option<String>)],
150    ) -> rusqlite::Result<()> {
151        let tx = self.conn.unchecked_transaction()?;
152        {
153            let mut stmt = self.conn.prepare(
154                "INSERT OR IGNORE INTO completions (name, kind, description) VALUES (?1, ?2, ?3)",
155            )?;
156            for (name, kind, desc) in completions {
157                stmt.execute(params![name, kind.as_str(), desc.as_deref()])?;
158            }
159        }
160        tx.commit()?;
161        Ok(())
162    }
163
164    pub fn increment_frequency(&self, name: &str) -> rusqlite::Result<()> {
165        self.conn.execute(
166            "UPDATE completions SET frequency = frequency + 1 WHERE name = ?1",
167            params![name],
168        )?;
169        Ok(())
170    }
171
172    pub fn search(&self, query: &str, limit: usize) -> rusqlite::Result<Vec<Completion>> {
173        if query.is_empty() {
174            return self.get_top_by_frequency(limit);
175        }
176
177        // Try prefix match first (faster)
178        let prefix_results = self.search_prefix(query, limit)?;
179        if prefix_results.len() >= limit {
180            return Ok(prefix_results);
181        }
182
183        // Fall back to FTS5 fuzzy search
184        self.search_fts(query, limit)
185    }
186
187    fn search_prefix(&self, prefix: &str, limit: usize) -> rusqlite::Result<Vec<Completion>> {
188        let mut stmt = self.conn.prepare(
189            "SELECT name, kind, description, frequency FROM completions 
190             WHERE name LIKE ?1 || '%'
191             ORDER BY frequency DESC, name ASC
192             LIMIT ?2",
193        )?;
194
195        let rows = stmt.query_map(params![prefix, limit as i64], |row| {
196            Ok(Completion {
197                name: row.get(0)?,
198                kind: CompletionKind::from_str(&row.get::<_, String>(1)?),
199                description: row.get(2)?,
200                frequency: row.get(3)?,
201            })
202        })?;
203
204        rows.collect()
205    }
206
207    fn search_fts(&self, query: &str, limit: usize) -> rusqlite::Result<Vec<Completion>> {
208        let fts_query = format!("{}*", query);
209        let mut stmt = self.conn.prepare(
210            "SELECT c.name, c.kind, c.description, c.frequency 
211             FROM completions c
212             JOIN completions_fts fts ON c.id = fts.rowid
213             WHERE completions_fts MATCH ?1
214             ORDER BY c.frequency DESC, rank
215             LIMIT ?2",
216        )?;
217
218        let rows = stmt.query_map(params![fts_query, limit as i64], |row| {
219            Ok(Completion {
220                name: row.get(0)?,
221                kind: CompletionKind::from_str(&row.get::<_, String>(1)?),
222                description: row.get(2)?,
223                frequency: row.get(3)?,
224            })
225        })?;
226
227        rows.collect()
228    }
229
230    fn get_top_by_frequency(&self, limit: usize) -> rusqlite::Result<Vec<Completion>> {
231        let mut stmt = self.conn.prepare(
232            "SELECT name, kind, description, frequency FROM completions 
233             ORDER BY frequency DESC, name ASC
234             LIMIT ?1",
235        )?;
236
237        let rows = stmt.query_map(params![limit as i64], |row| {
238            Ok(Completion {
239                name: row.get(0)?,
240                kind: CompletionKind::from_str(&row.get::<_, String>(1)?),
241                description: row.get(2)?,
242                frequency: row.get(3)?,
243            })
244        })?;
245
246        rows.collect()
247    }
248
249    pub fn count(&self) -> rusqlite::Result<usize> {
250        self.conn
251            .query_row("SELECT COUNT(*) FROM completions", [], |row| row.get(0))
252    }
253
254    pub fn index_system_commands(&self) -> rusqlite::Result<usize> {
255        let path = std::env::var("PATH").unwrap_or_default();
256        let mut completions = Vec::new();
257
258        for dir in path.split(':') {
259            if let Ok(entries) = std::fs::read_dir(dir) {
260                for entry in entries.flatten() {
261                    if let Ok(ft) = entry.file_type() {
262                        if ft.is_file() || ft.is_symlink() {
263                            if let Some(name) = entry.file_name().to_str() {
264                                completions.push((name.to_string(), CompletionKind::Command, None));
265                            }
266                        }
267                    }
268                }
269            }
270        }
271
272        let count = completions.len();
273        self.add_completions(&completions)?;
274        Ok(count)
275    }
276
277    pub fn index_shell_builtins(&self) -> rusqlite::Result<usize> {
278        let builtins = [
279            ("cd", "Change directory"),
280            ("pwd", "Print working directory"),
281            ("echo", "Print arguments"),
282            ("export", "Set environment variable"),
283            ("unset", "Unset environment variable"),
284            ("alias", "Define alias"),
285            ("unalias", "Remove alias"),
286            ("source", "Execute file in current shell"),
287            ("exit", "Exit the shell"),
288            ("jobs", "List background jobs"),
289            ("fg", "Bring job to foreground"),
290            ("bg", "Continue job in background"),
291            ("history", "Show command history"),
292            ("set", "Set shell options"),
293            ("unset", "Unset shell options"),
294            ("type", "Show command type"),
295            ("which", "Show command path"),
296            ("builtin", "Execute builtin command"),
297            ("command", "Execute external command"),
298            ("exec", "Replace shell with command"),
299            ("eval", "Evaluate arguments as command"),
300            ("read", "Read input"),
301            ("printf", "Formatted print"),
302            ("test", "Evaluate conditional expression"),
303            ("true", "Return success"),
304            ("false", "Return failure"),
305            (":", "Null command"),
306            ("return", "Return from function"),
307            ("break", "Break from loop"),
308            ("continue", "Continue loop"),
309            ("shift", "Shift positional parameters"),
310            ("wait", "Wait for background jobs"),
311            ("trap", "Set signal handler"),
312            ("umask", "Set file creation mask"),
313            ("ulimit", "Set resource limits"),
314            ("times", "Show shell times"),
315            ("kill", "Send signal to process"),
316            ("let", "Evaluate arithmetic expression"),
317            ("declare", "Declare variable"),
318            ("local", "Declare local variable"),
319            ("readonly", "Make variable readonly"),
320            ("typeset", "Declare variable type"),
321            ("hash", "Remember command path"),
322            ("dirs", "Show directory stack"),
323            ("pushd", "Push directory"),
324            ("popd", "Pop directory"),
325            ("getopts", "Parse options"),
326            ("enable", "Enable/disable builtins"),
327            ("logout", "Exit login shell"),
328            ("suspend", "Suspend shell"),
329            ("disown", "Remove job from table"),
330        ];
331
332        let completions: Vec<_> = builtins
333            .iter()
334            .map(|(name, desc)| {
335                (
336                    name.to_string(),
337                    CompletionKind::Builtin,
338                    Some(desc.to_string()),
339                )
340            })
341            .collect();
342
343        let count = completions.len();
344        self.add_completions(&completions)?;
345        Ok(count)
346    }
347}
348
349impl Default for CompletionEngine {
350    fn default() -> Self {
351        Self::new().expect("Failed to create completion engine")
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory engine pre-populated with the given commands.
    fn engine_with(commands: &[(&str, Option<&str>)]) -> CompletionEngine {
        let engine = CompletionEngine::in_memory().unwrap();
        for &(name, desc) in commands {
            engine
                .add_completion(name, CompletionKind::Command, desc)
                .unwrap();
        }
        engine
    }

    #[test]
    fn test_completion_engine() {
        let engine = engine_with(&[
            ("git", Some("Version control")),
            ("grep", Some("Search text")),
            ("gzip", Some("Compress files")),
        ]);

        let all_g = engine.search("g", 10).unwrap();
        assert_eq!(all_g.len(), 3);

        let only_git = engine.search("gi", 10).unwrap();
        assert_eq!(only_git.len(), 1);
        assert_eq!(only_git[0].name, "git");
    }

    #[test]
    fn test_frequency_ranking() {
        let engine = engine_with(&[("aaa", None), ("aab", None)]);

        // Bump `aab` so it outranks `aaa`, which would otherwise win on name order.
        (0..5).for_each(|_| engine.increment_frequency("aab").unwrap());

        let ranked = engine.search("aa", 10).unwrap();
        assert_eq!(ranked[0].name, "aab"); // Higher frequency first
    }
}