aprender_present_lib/browser/
shell_autocomplete.rs1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[cfg(target_arch = "wasm32")]
12use wasm_bindgen::prelude::*;
13
/// Size in bytes of the fixed header at the start of a model file.
///
/// The loader in this file reads the `APRN` magic from bytes 0..4, two little-endian `u32`
/// section sizes from bytes 8..16 and a compression tag from byte 20; the metadata section
/// begins immediately after the header.
const HEADER_SIZE: usize = 32;
16
17#[derive(Debug)]
19pub struct ShellAutocomplete {
20 n: usize,
22 ngrams: HashMap<String, HashMap<String, u32>>,
24 command_freq: HashMap<String, u32>,
26 trie: Trie,
28 total_commands: usize,
30}
31
/// Character-level prefix tree over complete command lines.
#[derive(Debug, Default)]
struct Trie {
    /// Child nodes keyed by the next character.
    children: HashMap<char, Trie>,
    /// Set when an inserted string ends at this node.
    is_end: bool,
    /// The full inserted string, stored on its terminal node.
    command: Option<String>,
}

impl Trie {
    /// Creates an empty trie.
    fn new() -> Self {
        Self::default()
    }

    /// Inserts `word`, creating one node per character as needed.
    fn insert(&mut self, word: &str) {
        // Walk (and extend) the path for `word`; the fold yields its terminal node.
        let terminal = word
            .chars()
            .fold(self, |node, ch| node.children.entry(ch).or_default());
        terminal.command = Some(word.to_owned());
        terminal.is_end = true;
    }

    /// Returns at most `limit` inserted strings that start with `prefix`.
    ///
    /// The order of the results follows `HashMap` iteration order and is therefore
    /// unspecified. An unknown prefix yields an empty vector.
    fn find_prefix(&self, prefix: &str, limit: usize) -> Vec<String> {
        let mut matches = Vec::new();
        // Descend one character at a time; `None` as soon as the path breaks.
        let subtree = prefix
            .chars()
            .try_fold(self, |node, ch| node.children.get(&ch));
        if let Some(root) = subtree {
            Self::collect_commands_recursive(root, &mut matches, limit);
        }
        matches
    }

    /// Depth-first walk that appends stored commands to `results` until `limit` is reached.
    fn collect_commands_recursive(node: &Trie, results: &mut Vec<String>, limit: usize) {
        if results.len() >= limit {
            return;
        }
        results.extend(node.command.iter().cloned());
        for child in node.children.values() {
            if results.len() >= limit {
                break;
            }
            Self::collect_commands_recursive(child, results, limit);
        }
    }
}
86
87#[derive(Debug, Serialize, Deserialize)]
89struct MarkovModelData {
90 n: usize,
91 ngrams: HashMap<String, HashMap<String, u32>>,
92 command_freq: HashMap<String, u32>,
93 total_commands: usize,
94 #[serde(default)]
95 last_trained_pos: usize,
96}
97
98const SHELL_MODEL_BYTES: &[u8] = include_bytes!("../../assets/aprender-shell-base.apr");
100
101impl ShellAutocomplete {
102 pub fn new() -> Result<Self, String> {
107 Self::load_from_bytes(SHELL_MODEL_BYTES)
108 }
109
110 pub fn load_from_bytes(bytes: &[u8]) -> Result<Self, String> {
113 if bytes.len() < HEADER_SIZE {
115 return Err("Model file too small".to_string());
116 }
117 if &bytes[0..4] != b"APRN" {
118 return Err(format!("Invalid magic bytes: {:?}", &bytes[0..4]));
119 }
120
121 let metadata_size = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
133 let payload_size =
134 u32::from_le_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]) as usize;
135 let compression = bytes[20];
136
137 let metadata_start = HEADER_SIZE;
139 let metadata_end = metadata_start + metadata_size;
140 let payload_start = metadata_end;
141 let payload_end = payload_start + payload_size;
142
143 if payload_end > bytes.len() {
144 return Err(format!(
145 "Payload extends beyond file: {} > {}",
146 payload_end,
147 bytes.len()
148 ));
149 }
150
151 let payload_compressed = &bytes[payload_start..payload_end];
152
153 let payload_decompressed: Vec<u8> = match compression {
155 0x00 => payload_compressed.to_vec(), #[cfg(feature = "shell-autocomplete")]
157 0x01 | 0x02 => {
158 zstd::decode_all(payload_compressed)
160 .map_err(|e| format!("Failed to decompress: {}", e))?
161 }
162 #[cfg(not(feature = "shell-autocomplete"))]
163 0x01 | 0x02 => {
164 return Err(
165 "Zstd compression requires the 'shell-autocomplete' feature".to_string()
166 );
167 }
168 _ => return Err(format!("Unknown compression type: 0x{:02X}", compression)),
169 };
170
171 let model_data: MarkovModelData = bincode::deserialize(&payload_decompressed)
173 .map_err(|e| format!("Failed to deserialize model: {}", e))?;
174
175 let mut trie = Trie::new();
177 for cmd in model_data.command_freq.keys() {
178 trie.insert(cmd);
179 }
180
181 Ok(Self {
182 n: model_data.n,
183 ngrams: model_data.ngrams,
184 command_freq: model_data.command_freq,
185 trie,
186 total_commands: model_data.total_commands,
187 })
188 }
189
190 pub fn suggest(&self, prefix: &str, count: usize) -> Vec<(String, f32)> {
192 let prefix = prefix.trim();
193 let tokens: Vec<&str> = prefix.split_whitespace().collect();
194 let ends_with_space = prefix.is_empty() || prefix.ends_with(' ');
195
196 let capacity = count * 4;
197 let mut suggestions = Vec::with_capacity(capacity);
198 let mut seen = std::collections::HashSet::with_capacity(capacity);
199
200 for cmd in self.trie.find_prefix(prefix, capacity) {
202 if Self::is_corrupted_command(&cmd) {
203 continue;
204 }
205 let freq = self.command_freq.get(&cmd).copied().unwrap_or(1);
206 let score = freq as f32 / self.total_commands.max(1) as f32;
207 seen.insert(cmd.clone());
208 suggestions.push((cmd, score));
209 }
210
211 if !tokens.is_empty() && ends_with_space {
213 let context_start = tokens.len().saturating_sub(self.n - 1);
214 let context = tokens[context_start..].join(" ");
215 let prefix_trimmed = prefix.trim();
216
217 if let Some(next_tokens) = self.ngrams.get(&context) {
218 let total: u32 = next_tokens.values().sum();
219 let mut completion = String::with_capacity(prefix_trimmed.len() + 32);
220
221 for (token, ngram_count) in next_tokens {
222 completion.clear();
223 completion.push_str(prefix_trimmed);
224 completion.push(' ');
225 completion.push_str(token);
226
227 let score = *ngram_count as f32 / total as f32;
228
229 if !seen.contains(&completion) {
230 seen.insert(completion.clone());
231 suggestions.push((completion.clone(), score * 0.8));
232 }
233 }
234 }
235 }
236
237 if !tokens.is_empty() && !ends_with_space && tokens.len() >= 2 {
239 let partial_token = tokens.last().unwrap_or(&"");
240 let context_tokens = &tokens[..tokens.len() - 1];
241 let context_start = context_tokens.len().saturating_sub(self.n - 1);
242 let context = context_tokens[context_start..].join(" ");
243 let context_prefix = context_tokens.join(" ");
244
245 if let Some(next_tokens) = self.ngrams.get(&context) {
246 let total: u32 = next_tokens.values().sum();
247 let mut completion = String::with_capacity(context_prefix.len() + 32);
248
249 for (token, ngram_count) in next_tokens {
250 if token.starts_with(partial_token) && !Self::is_corrupted_token(token) {
251 completion.clear();
252 completion.push_str(&context_prefix);
253 completion.push(' ');
254 completion.push_str(token);
255
256 let score = *ngram_count as f32 / total as f32;
257
258 if !seen.contains(&completion) {
259 seen.insert(completion.clone());
260 suggestions.push((completion.clone(), score * 0.9));
261 }
262 }
263 }
264 }
265 }
266
267 if prefix.is_empty() && suggestions.is_empty() {
269 let mut top_cmds: Vec<_> = self
270 .command_freq
271 .iter()
272 .map(|(k, v)| (k.clone(), *v as f32 / self.total_commands.max(1) as f32))
273 .collect();
274 top_cmds.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
275 suggestions = top_cmds;
276 }
277
278 suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
280 suggestions.truncate(count);
281
282 suggestions
283 }
284
285 fn is_corrupted_command(cmd: &str) -> bool {
287 if cmd.contains(" ") {
288 return true;
289 }
290 if cmd.trim_end().ends_with('\\') {
291 return true;
292 }
293 cmd.split_whitespace().any(Self::is_corrupted_token)
294 }
295
/// Reports whether `token` looks like a subcommand fused with a flag, e.g. `commit-m`.
///
/// The token is split at its first `-`. It is rejected when both sides are non-empty, the
/// left side is one of a fixed list of subcommand names, and the right side is at most two
/// bytes long or itself starts with `-`.
fn is_corrupted_token(token: &str) -> bool {
    match token.split_once('-') {
        Some((before, after)) if !before.is_empty() && !after.is_empty() => {
            let is_known_subcommand = matches!(
                before,
                "commit"
                    | "checkout"
                    | "clone"
                    | "push"
                    | "pull"
                    | "merge"
                    | "rebase"
                    | "status"
                    | "add"
                    | "build"
                    | "run"
                    | "test"
                    | "install"
            );
            is_known_subcommand && (after.len() <= 2 || after.starts_with('-'))
        }
        _ => false,
    }
}
313
314 pub fn suggest_json(&self, prefix: &str, count: usize) -> String {
316 let suggestions = self.suggest(prefix, count);
317 let items: Vec<_> = suggestions
318 .iter()
319 .map(|(text, score)| {
320 format!(
321 r#"{{"text":"{}","score":{:.4}}}"#,
322 text.replace('"', "\\\""),
323 score
324 )
325 })
326 .collect();
327 format!(r#"{{"suggestions":[{}]}}"#, items.join(","))
328 }
329
330 pub fn model_info_json(&self) -> String {
332 format!(
333 r#"{{"model_name":"aprender-shell-base","model_type":"ngram_lm","vocab_size":{},"ngram_size":{},"ngram_count":{},"total_commands":{}}}"#,
334 self.vocab_size(),
335 self.n,
336 self.ngram_count(),
337 self.total_commands
338 )
339 }
340
341 pub fn vocab_size(&self) -> usize {
343 self.command_freq.len()
344 }
345
346 pub fn ngram_count(&self) -> usize {
348 self.ngrams.values().map(HashMap::len).sum()
349 }
350
351 pub fn ngram_size(&self) -> usize {
353 self.n
354 }
355
356 pub fn estimated_memory_bytes(&self) -> usize {
358 let ngram_size: usize = self
359 .ngrams
360 .iter()
361 .map(|(k, v)| k.len() + v.keys().map(|k2| k2.len() + 4).sum::<usize>())
362 .sum();
363 let vocab_size: usize = self.command_freq.keys().map(|k| k.len() + 4).sum();
364 ngram_size + vocab_size + std::mem::size_of::<Self>()
365 }
366}
367
/// JavaScript-facing wrapper around [`ShellAutocomplete`], compiled only for `wasm32`.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
pub struct ShellAutocompleteDemo {
    inner: ShellAutocomplete,
}

#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl ShellAutocompleteDemo {
    /// Builds the demo from model bytes supplied by the caller (exported as `fromBytes`).
    ///
    /// Installs the console panic hook, logs a one-line summary to the browser console and
    /// returns the loader's error message as a `JsValue` on failure.
    #[wasm_bindgen(js_name = "fromBytes")]
    pub fn from_bytes(bytes: &[u8]) -> Result<ShellAutocompleteDemo, JsValue> {
        console_error_panic_hook::set_once();

        let inner = match ShellAutocomplete::load_from_bytes(bytes) {
            Ok(model) => model,
            Err(message) => return Err(JsValue::from_str(&message)),
        };

        let summary = format!(
            "ShellAutocomplete loaded from bytes: {} commands, {} n-grams",
            inner.vocab_size(),
            inner.ngram_count()
        );
        web_sys::console::log_1(&JsValue::from(summary));

        Ok(ShellAutocompleteDemo { inner })
    }

    /// Builds the demo from the model embedded in the binary (exported as the constructor).
    ///
    /// Installs the console panic hook, logs a one-line summary to the browser console and
    /// returns the loader's error message as a `JsValue` on failure.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Result<ShellAutocompleteDemo, JsValue> {
        console_error_panic_hook::set_once();

        let inner = match ShellAutocomplete::new() {
            Ok(model) => model,
            Err(message) => return Err(JsValue::from_str(&message)),
        };

        let summary = format!(
            "ShellAutocomplete loaded (embedded): {} commands, {} n-grams",
            inner.vocab_size(),
            inner.ngram_count()
        );
        web_sys::console::log_1(&JsValue::from(summary));

        Ok(ShellAutocompleteDemo { inner })
    }

    /// Returns the suggestions for `prefix` as the JSON produced by
    /// `ShellAutocomplete::suggest_json`.
    #[wasm_bindgen]
    pub fn suggest(&self, prefix: &str, count: usize) -> String {
        self.inner.suggest_json(prefix, count)
    }

    /// Returns the JSON produced by `ShellAutocomplete::model_info_json`.
    #[wasm_bindgen]
    pub fn model_info(&self) -> String {
        self.inner.model_info_json()
    }

    /// Number of distinct commands in the wrapped model.
    pub fn vocab_size(&self) -> usize {
        self.inner.vocab_size()
    }

    /// Number of (context, following token) pairs in the wrapped model.
    pub fn ngram_count(&self) -> usize {
        self.inner.ngram_count()
    }

    /// N-gram order of the wrapped model.
    pub fn ngram_size(&self) -> usize {
        self.inner.ngram_size()
    }

    /// Estimated memory use of the wrapped model in bytes.
    pub fn memory_bytes(&self) -> usize {
        self.inner.estimated_memory_bytes()
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Only commands stored under the requested prefix are returned.
    #[test]
    fn test_trie_basic() {
        let mut trie = Trie::new();
        for cmd in &["git status", "git commit", "cargo build"] {
            trie.insert(cmd);
        }

        let git_commands = trie.find_prefix("git", 10);
        assert_eq!(git_commands.len(), 2);
    }

    /// A flag fused onto a subcommand is rejected; ordinary hyphenated words are not.
    #[test]
    fn test_corrupted_detection() {
        let fused = ShellAutocomplete::is_corrupted_command("git commit-m");
        let spaced = ShellAutocomplete::is_corrupted_command("git commit -m");
        let hyphenated = ShellAutocomplete::is_corrupted_command("git checkout feature-branch");

        assert!(fused);
        assert!(!spaced);
        assert!(!hyphenated);
    }
}