Skip to main content

aprender_present_lib/browser/
shell_autocomplete.rs

1//! Shell Command Autocomplete Demo
2//!
3//! Real WASM implementation using the trained aprender-shell-base.apr model.
4//! Uses N-gram Markov model for command prediction.
5//!
6//! Spec: docs/specifications/showcase-demo-aprender-shell-apr.md
7
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[cfg(target_arch = "wasm32")]
12use wasm_bindgen::prelude::*;
13
/// Length in bytes of the fixed APR header that precedes the metadata and
/// payload sections (magic, version, model type, section sizes, flags).
const HEADER_SIZE: usize = 0x20;
16
/// Shell command autocomplete backed by an N-gram Markov model.
///
/// Built from a serialized `.apr` model by `ShellAutocomplete::load_from_bytes`.
/// Suggestions combine whole-command prefix matches (via the trie) with
/// n-gram next-token predictions.
#[derive(Debug)]
pub struct ShellAutocomplete {
    /// N-gram order; lookups use the last `n - 1` tokens as context.
    /// (Typically 3 — NOTE(review): not enforced anywhere in this type.)
    n: usize,
    /// N-gram counts: space-joined context -> (next token -> count).
    ngrams: HashMap<String, HashMap<String, u32>>,
    /// Occurrence count per full command line, used for ranking and to
    /// build `trie`.
    command_freq: HashMap<String, u32>,
    /// Prefix index over the keys of `command_freq`; rebuilt at load time
    /// rather than deserialized.
    trie: Trie,
    /// Total commands in the training data; denominator of frequency scores.
    total_commands: usize,
}
31
/// Character-level prefix tree over full command strings.
///
/// A node where an inserted string ends stores that whole string in
/// `command`, so walking a subtree yields complete commands directly.
#[derive(Debug, Default)]
struct Trie {
    /// Children keyed by the next character. Traversal follows `HashMap`
    /// iteration order, so result order (and which matches survive a limit)
    /// is unspecified.
    children: HashMap<char, Trie>,
    /// Set when an inserted string ends at this node.
    /// NOTE(review): never read in this file — `command.is_some()` carries the
    /// same information; candidate for removal.
    is_end: bool,
    /// The full inserted string that ends at this node, if any.
    command: Option<String>,
}
39
40impl Trie {
41    fn new() -> Self {
42        Self::default()
43    }
44
45    fn insert(&mut self, word: &str) {
46        let mut node = self;
47        for c in word.chars() {
48            node = node.children.entry(c).or_default();
49        }
50        node.is_end = true;
51        node.command = Some(word.to_string());
52    }
53
54    fn find_prefix(&self, prefix: &str, limit: usize) -> Vec<String> {
55        let mut results = Vec::new();
56        let mut node = self;
57
58        // Navigate to prefix node
59        for c in prefix.chars() {
60            match node.children.get(&c) {
61                Some(child) => node = child,
62                None => return results,
63            }
64        }
65
66        // Collect all commands under this prefix
67        Self::collect_commands_recursive(node, &mut results, limit);
68        results
69    }
70
71    fn collect_commands_recursive(node: &Trie, results: &mut Vec<String>, limit: usize) {
72        if results.len() >= limit {
73            return;
74        }
75        if let Some(ref cmd) = node.command {
76            results.push(cmd.clone());
77        }
78        for child in node.children.values() {
79            Self::collect_commands_recursive(child, results, limit);
80            if results.len() >= limit {
81                return;
82            }
83        }
84    }
85}
86
/// Serialized model payload, decoded with `bincode::deserialize` (after
/// optional zstd decompression) by `ShellAutocomplete::load_from_bytes`.
///
/// NOTE(review): bincode's default encoding is positional rather than
/// self-describing, so the field order here presumably must match the model
/// writer exactly — confirm before reordering or inserting fields.
#[derive(Debug, Serialize, Deserialize)]
struct MarkovModelData {
    /// N-gram order.
    n: usize,
    /// Space-joined context -> (next token -> count).
    ngrams: HashMap<String, HashMap<String, u32>>,
    /// Full command line -> occurrence count.
    command_freq: HashMap<String, u32>,
    /// Total commands in the training data.
    total_commands: usize,
    /// Training bookkeeping; not read by `load_from_bytes`.
    /// NOTE(review): `serde(default)` is presumably meant to accept files that
    /// lack this field — verify that actually works with bincode.
    #[serde(default)]
    last_trained_pos: usize,
}
97
/// Model bytes embedded at compile time, used by `ShellAutocomplete::new`
/// for tests and demos.
///
/// `include_bytes!` resolves the path relative to this source file, so the
/// asset must exist when the crate is built.
const SHELL_MODEL_BYTES: &[u8] = include_bytes!("../../assets/aprender-shell-base.apr");
100
101impl ShellAutocomplete {
102    /// Create a new ShellAutocomplete with the embedded model.
103    ///
104    /// This is a convenience method for testing and demos that loads
105    /// the model compiled into the binary.
106    pub fn new() -> Result<Self, String> {
107        Self::load_from_bytes(SHELL_MODEL_BYTES)
108    }
109
110    /// Load ShellAutocomplete from raw .apr bytes.
111    /// This is the primary method for loading the model.
112    pub fn load_from_bytes(bytes: &[u8]) -> Result<Self, String> {
113        // Verify magic bytes and minimum size
114        if bytes.len() < HEADER_SIZE {
115            return Err("Model file too small".to_string());
116        }
117        if &bytes[0..4] != b"APRN" {
118            return Err(format!("Invalid magic bytes: {:?}", &bytes[0..4]));
119        }
120
121        // Parse 32-byte APR header
122        // Offset 0-3: Magic "APRN"
123        // Offset 4-5: Version (major, minor)
124        // Offset 6-7: Model type (u16 LE)
125        // Offset 8-11: Metadata size (u32 LE)
126        // Offset 12-15: Payload size (u32 LE)
127        // Offset 16-19: Uncompressed size (u32 LE)
128        // Offset 20: Compression type
129        // Offset 21: Flags
130        // Offset 22-31: Reserved
131
132        let metadata_size = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
133        let payload_size =
134            u32::from_le_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]) as usize;
135        let compression = bytes[20];
136
137        // Calculate offsets
138        let metadata_start = HEADER_SIZE;
139        let metadata_end = metadata_start + metadata_size;
140        let payload_start = metadata_end;
141        let payload_end = payload_start + payload_size;
142
143        if payload_end > bytes.len() {
144            return Err(format!(
145                "Payload extends beyond file: {} > {}",
146                payload_end,
147                bytes.len()
148            ));
149        }
150
151        let payload_compressed = &bytes[payload_start..payload_end];
152
153        // Decompress payload if needed
154        let payload_decompressed: Vec<u8> = match compression {
155            0x00 => payload_compressed.to_vec(), // No compression
156            #[cfg(feature = "shell-autocomplete")]
157            0x01 | 0x02 => {
158                // Zstd compression
159                zstd::decode_all(payload_compressed)
160                    .map_err(|e| format!("Failed to decompress: {}", e))?
161            }
162            #[cfg(not(feature = "shell-autocomplete"))]
163            0x01 | 0x02 => {
164                return Err(
165                    "Zstd compression requires the 'shell-autocomplete' feature".to_string()
166                );
167            }
168            _ => return Err(format!("Unknown compression type: 0x{:02X}", compression)),
169        };
170
171        // Deserialize the model data with bincode
172        let model_data: MarkovModelData = bincode::deserialize(&payload_decompressed)
173            .map_err(|e| format!("Failed to deserialize model: {}", e))?;
174
175        // Build trie from commands
176        let mut trie = Trie::new();
177        for cmd in model_data.command_freq.keys() {
178            trie.insert(cmd);
179        }
180
181        Ok(Self {
182            n: model_data.n,
183            ngrams: model_data.ngrams,
184            command_freq: model_data.command_freq,
185            trie,
186            total_commands: model_data.total_commands,
187        })
188    }
189
190    /// Suggest completions for a prefix
191    pub fn suggest(&self, prefix: &str, count: usize) -> Vec<(String, f32)> {
192        let prefix = prefix.trim();
193        let tokens: Vec<&str> = prefix.split_whitespace().collect();
194        let ends_with_space = prefix.is_empty() || prefix.ends_with(' ');
195
196        let capacity = count * 4;
197        let mut suggestions = Vec::with_capacity(capacity);
198        let mut seen = std::collections::HashSet::with_capacity(capacity);
199
200        // Strategy 1: Trie prefix match for exact commands
201        for cmd in self.trie.find_prefix(prefix, capacity) {
202            if Self::is_corrupted_command(&cmd) {
203                continue;
204            }
205            let freq = self.command_freq.get(&cmd).copied().unwrap_or(1);
206            let score = freq as f32 / self.total_commands.max(1) as f32;
207            seen.insert(cmd.clone());
208            suggestions.push((cmd, score));
209        }
210
211        // Strategy 2: N-gram prediction for next token (only when prefix ends with space)
212        if !tokens.is_empty() && ends_with_space {
213            let context_start = tokens.len().saturating_sub(self.n - 1);
214            let context = tokens[context_start..].join(" ");
215            let prefix_trimmed = prefix.trim();
216
217            if let Some(next_tokens) = self.ngrams.get(&context) {
218                let total: u32 = next_tokens.values().sum();
219                let mut completion = String::with_capacity(prefix_trimmed.len() + 32);
220
221                for (token, ngram_count) in next_tokens {
222                    completion.clear();
223                    completion.push_str(prefix_trimmed);
224                    completion.push(' ');
225                    completion.push_str(token);
226
227                    let score = *ngram_count as f32 / total as f32;
228
229                    if !seen.contains(&completion) {
230                        seen.insert(completion.clone());
231                        suggestions.push((completion.clone(), score * 0.8));
232                    }
233                }
234            }
235        }
236
237        // Strategy 3: N-gram prediction with partial token filter
238        if !tokens.is_empty() && !ends_with_space && tokens.len() >= 2 {
239            let partial_token = tokens.last().unwrap_or(&"");
240            let context_tokens = &tokens[..tokens.len() - 1];
241            let context_start = context_tokens.len().saturating_sub(self.n - 1);
242            let context = context_tokens[context_start..].join(" ");
243            let context_prefix = context_tokens.join(" ");
244
245            if let Some(next_tokens) = self.ngrams.get(&context) {
246                let total: u32 = next_tokens.values().sum();
247                let mut completion = String::with_capacity(context_prefix.len() + 32);
248
249                for (token, ngram_count) in next_tokens {
250                    if token.starts_with(partial_token) && !Self::is_corrupted_token(token) {
251                        completion.clear();
252                        completion.push_str(&context_prefix);
253                        completion.push(' ');
254                        completion.push_str(token);
255
256                        let score = *ngram_count as f32 / total as f32;
257
258                        if !seen.contains(&completion) {
259                            seen.insert(completion.clone());
260                            suggestions.push((completion.clone(), score * 0.9));
261                        }
262                    }
263                }
264            }
265        }
266
267        // If no prefix and no suggestions, return top commands
268        if prefix.is_empty() && suggestions.is_empty() {
269            let mut top_cmds: Vec<_> = self
270                .command_freq
271                .iter()
272                .map(|(k, v)| (k.clone(), *v as f32 / self.total_commands.max(1) as f32))
273                .collect();
274            top_cmds.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
275            suggestions = top_cmds;
276        }
277
278        // Sort by score and truncate
279        suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
280        suggestions.truncate(count);
281
282        suggestions
283    }
284
285    /// Detect corrupted commands
286    fn is_corrupted_command(cmd: &str) -> bool {
287        if cmd.contains("  ") {
288            return true;
289        }
290        if cmd.trim_end().ends_with('\\') {
291            return true;
292        }
293        cmd.split_whitespace().any(Self::is_corrupted_token)
294    }
295
296    /// Detect corrupted tokens
297    fn is_corrupted_token(token: &str) -> bool {
298        if let Some(dash_pos) = token.find('-') {
299            if dash_pos > 0 && dash_pos < token.len() - 1 {
300                let before = &token[..dash_pos];
301                let after = &token[dash_pos + 1..];
302                let subcommands = [
303                    "commit", "checkout", "clone", "push", "pull", "merge", "rebase", "status",
304                    "add", "build", "run", "test", "install",
305                ];
306                if subcommands.contains(&before) && (after.len() <= 2 || after.starts_with('-')) {
307                    return true;
308                }
309            }
310        }
311        false
312    }
313
314    /// Get JSON-formatted suggestions (for WASM interop)
315    pub fn suggest_json(&self, prefix: &str, count: usize) -> String {
316        let suggestions = self.suggest(prefix, count);
317        let items: Vec<_> = suggestions
318            .iter()
319            .map(|(text, score)| {
320                format!(
321                    r#"{{"text":"{}","score":{:.4}}}"#,
322                    text.replace('"', "\\\""),
323                    score
324                )
325            })
326            .collect();
327        format!(r#"{{"suggestions":[{}]}}"#, items.join(","))
328    }
329
330    /// Get model info as JSON
331    pub fn model_info_json(&self) -> String {
332        format!(
333            r#"{{"model_name":"aprender-shell-base","model_type":"ngram_lm","vocab_size":{},"ngram_size":{},"ngram_count":{},"total_commands":{}}}"#,
334            self.vocab_size(),
335            self.n,
336            self.ngram_count(),
337            self.total_commands
338        )
339    }
340
341    /// Vocabulary size (unique commands)
342    pub fn vocab_size(&self) -> usize {
343        self.command_freq.len()
344    }
345
346    /// N-gram count
347    pub fn ngram_count(&self) -> usize {
348        self.ngrams.values().map(HashMap::len).sum()
349    }
350
351    /// N-gram size
352    pub fn ngram_size(&self) -> usize {
353        self.n
354    }
355
356    /// Estimated memory usage
357    pub fn estimated_memory_bytes(&self) -> usize {
358        let ngram_size: usize = self
359            .ngrams
360            .iter()
361            .map(|(k, v)| k.len() + v.keys().map(|k2| k2.len() + 4).sum::<usize>())
362            .sum();
363        let vocab_size: usize = self.command_freq.keys().map(|k| k.len() + 4).sum();
364        ngram_size + vocab_size + std::mem::size_of::<Self>()
365    }
366}
367
368// ============================================================================ //
369// WASM EXPORTS - Browser-accessible API //
370// ============================================================================ //
371
/// WASM-exported wrapper around `ShellAutocomplete` for use from JavaScript.
///
/// Only compiled for `wasm32` targets.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
pub struct ShellAutocompleteDemo {
    /// The native model that every exported method delegates to.
    inner: ShellAutocomplete,
}
378
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl ShellAutocompleteDemo {
    /// Create a `ShellAutocompleteDemo` from `.apr` bytes fetched by JavaScript.
    ///
    /// This is the preferred constructor for dynamic model loading. It is
    /// exported to JavaScript as `fromBytes`:
    ///
    /// ```js
    /// const response = await fetch('./models/shell.apr');
    /// const bytes = new Uint8Array(await response.arrayBuffer());
    /// const demo = ShellAutocompleteDemo.fromBytes(bytes);
    /// ```
    ///
    /// # Errors
    ///
    /// Returns the loader's error message as a JS string if `bytes` is not a
    /// valid model.
    #[wasm_bindgen(js_name = "fromBytes")]
    pub fn from_bytes(bytes: &[u8]) -> Result<ShellAutocompleteDemo, JsValue> {
        console_error_panic_hook::set_once();

        let inner = ShellAutocomplete::load_from_bytes(bytes).map_err(|e| JsValue::from_str(&e))?;
        log_loaded("from bytes", &inner);
        Ok(Self { inner })
    }

    /// Create a demo from the model compiled into the WASM binary via
    /// `include_bytes!`.
    ///
    /// Exported as the JavaScript constructor (`new ShellAutocompleteDemo()`).
    /// Intended for tests and quick demos.
    ///
    /// # Errors
    ///
    /// Returns the loader's error message as a JS string if the embedded model
    /// is invalid.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Result<ShellAutocompleteDemo, JsValue> {
        console_error_panic_hook::set_once();

        let inner = ShellAutocomplete::new().map_err(|e| JsValue::from_str(&e))?;
        log_loaded("(embedded)", &inner);
        Ok(Self { inner })
    }

    /// Suggestions for `prefix` as a JSON string:
    /// `{"suggestions":[{"text":...,"score":...}]}`.
    pub fn suggest(&self, prefix: &str, count: usize) -> String {
        self.inner.suggest_json(prefix, count)
    }

    /// Model metadata as a JSON string.
    pub fn model_info(&self) -> String {
        self.inner.model_info_json()
    }

    /// Number of unique commands in the model.
    pub fn vocab_size(&self) -> usize {
        self.inner.vocab_size()
    }

    /// Number of distinct (context, next token) pairs.
    pub fn ngram_count(&self) -> usize {
        self.inner.ngram_count()
    }

    /// N-gram order (`n`).
    pub fn ngram_size(&self) -> usize {
        self.inner.ngram_size()
    }

    /// Rough estimate of the model's memory usage in bytes.
    pub fn memory_bytes(&self) -> usize {
        self.inner.estimated_memory_bytes()
    }
}

/// Log a one-line summary of a freshly loaded model to the browser console.
///
/// `source` is spliced in after "loaded", e.g. "from bytes" or "(embedded)".
#[cfg(target_arch = "wasm32")]
fn log_loaded(source: &str, model: &ShellAutocomplete) {
    web_sys::console::log_1(
        &format!(
            "ShellAutocomplete loaded {}: {} commands, {} n-grams",
            source,
            model.vocab_size(),
            model.ngram_count()
        )
        .into(),
    );
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Sort results so assertions do not depend on `HashMap` iteration order.
    fn sorted(mut items: Vec<String>) -> Vec<String> {
        items.sort();
        items
    }

    #[test]
    fn test_trie_basic() {
        let mut trie = Trie::new();
        trie.insert("git status");
        trie.insert("git commit");
        trie.insert("cargo build");

        let results = trie.find_prefix("git", 10);
        assert_eq!(results.len(), 2);
        // Check which commands matched, not just how many.
        assert_eq!(
            sorted(results),
            vec!["git commit".to_string(), "git status".to_string()]
        );
    }

    #[test]
    fn test_trie_missing_prefix_and_limit() {
        let mut trie = Trie::new();
        trie.insert("git status");
        trie.insert("git commit");

        assert!(trie.find_prefix("docker", 10).is_empty());
        assert_eq!(trie.find_prefix("git", 1).len(), 1);
        assert!(trie.find_prefix("git", 0).is_empty());
    }

    #[test]
    fn test_corrupted_detection() {
        assert!(ShellAutocomplete::is_corrupted_command("git commit-m"));
        assert!(!ShellAutocomplete::is_corrupted_command("git commit -m"));
        assert!(!ShellAutocomplete::is_corrupted_command(
            "git checkout feature-branch"
        ));
        // Double spaces and trailing line continuations are treated as corrupt.
        assert!(ShellAutocomplete::is_corrupted_command("git  status"));
        assert!(ShellAutocomplete::is_corrupted_command("ls -la \\"));
        // The dash heuristic only fires after known subcommands.
        assert!(!ShellAutocomplete::is_corrupted_command("docker-compose up"));
    }
}