use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
/// Size in bytes of the fixed `APRN` header; the metadata section starts right after it.
const HEADER_SIZE: usize = 32;
/// N-gram shell-command autocompleter built from an `APRN` model file.
#[derive(Debug)]
pub struct ShellAutocomplete {
    /// N-gram order; suggestion lookups use the last `n - 1` tokens as context.
    n: usize,
    /// Context (tokens joined by single spaces) -> next token -> count.
    ngrams: HashMap<String, HashMap<String, u32>>,
    /// Full command string -> frequency; its keys are also the contents of `trie`.
    command_freq: HashMap<String, u32>,
    /// Prefix index over the keys of `command_freq`.
    trie: Trie,
    /// Denominator used to turn `command_freq` counts into scores.
    total_commands: usize,
}
/// Character trie over full command strings, used for prefix lookup.
///
/// Children live in a `HashMap`, so traversal visits matches in an unspecified order.
#[derive(Debug, Default)]
struct Trie {
    /// Child nodes keyed by the next character.
    children: HashMap<char, Trie>,
    /// Set on the node where an inserted string ends.
    is_end: bool,
    /// The complete inserted string, stored on its final node.
    command: Option<String>,
}
impl Trie {
fn new() -> Self {
Self::default()
}
fn insert(&mut self, word: &str) {
let mut node = self;
for c in word.chars() {
node = node.children.entry(c).or_default();
}
node.is_end = true;
node.command = Some(word.to_string());
}
fn find_prefix(&self, prefix: &str, limit: usize) -> Vec<String> {
let mut results = Vec::new();
let mut node = self;
for c in prefix.chars() {
match node.children.get(&c) {
Some(child) => node = child,
None => return results,
}
}
Self::collect_commands_recursive(node, &mut results, limit);
results
}
fn collect_commands_recursive(node: &Trie, results: &mut Vec<String>, limit: usize) {
if results.len() >= limit {
return;
}
if let Some(ref cmd) = node.command {
results.push(cmd.clone());
}
for child in node.children.values() {
Self::collect_commands_recursive(child, results, limit);
if results.len() >= limit {
return;
}
}
}
}
/// Model payload stored in an `APRN` file, decoded with `bincode::deserialize` by
/// `ShellAutocomplete::load_from_bytes`.
///
/// NOTE(review): bincode encodes struct fields positionally (no field names), so the field
/// order below is part of the file format — do not reorder fields.
#[derive(Debug, Serialize, Deserialize)]
struct MarkovModelData {
    /// N-gram order; copied into `ShellAutocomplete::n`.
    n: usize,
    /// Context -> next token -> count; copied into `ShellAutocomplete::ngrams`.
    ngrams: HashMap<String, HashMap<String, u32>>,
    /// Command -> frequency; its keys also populate the prefix trie.
    command_freq: HashMap<String, u32>,
    /// Frequency denominator; copied into `ShellAutocomplete::total_commands`.
    total_commands: usize,
    /// Not used by the loader in this file (dropped after deserialization).
    // NOTE(review): with bincode's positional decoding, `#[serde(default)]` likely does not
    // let payloads that omit this trailing field load successfully — confirm before relying
    // on it for backward compatibility.
    #[serde(default)]
    last_trained_pos: usize,
}
/// Default model, embedded at compile time (path is relative to this source file) and
/// loaded by `ShellAutocomplete::new`.
const SHELL_MODEL_BYTES: &[u8] = include_bytes!("../../assets/aprender-shell-base.apr");
impl ShellAutocomplete {
    /// Loads the model embedded at compile time (`SHELL_MODEL_BYTES`).
    ///
    /// # Errors
    ///
    /// Fails for the same reasons as [`ShellAutocomplete::load_from_bytes`].
    pub fn new() -> Result<Self, String> {
        Self::load_from_bytes(SHELL_MODEL_BYTES)
    }

    /// Parses an `APRN` model file and builds the in-memory model.
    ///
    /// Layout read here (integers little-endian): bytes `0..4` magic `APRN`, `8..12`
    /// metadata size, `12..16` payload size, byte `20` compression (`0x00` = stored as-is,
    /// `0x01`/`0x02` = decoded with zstd). The 32-byte header is followed by the metadata
    /// section, which is skipped, and then the payload: a bincode-encoded `MarkovModelData`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the input is shorter than the header, the
    /// magic is wrong, the declared section sizes overflow or run past the end of `bytes`,
    /// the compression type is unknown or unsupported in this build, decompression fails,
    /// or the payload does not deserialize.
    pub fn load_from_bytes(bytes: &[u8]) -> Result<Self, String> {
        if bytes.len() < HEADER_SIZE {
            return Err("Model file too small".to_string());
        }
        if &bytes[0..4] != b"APRN" {
            return Err(format!("Invalid magic bytes: {:?}", &bytes[0..4]));
        }
        // In bounds: `bytes.len() >= HEADER_SIZE` (32) was checked above.
        let read_u32 = |at: usize| {
            u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]]) as usize
        };
        let metadata_size = read_u32(8);
        let payload_size = read_u32(12);
        let compression = bytes[20];

        // Checked arithmetic: on wasm32 `usize` is 32 bits, so sizes taken from an untrusted
        // header could wrap a plain `+`, slip past the bounds check and then panic on the
        // slice below (or panic on overflow in debug builds).
        let payload_start = HEADER_SIZE
            .checked_add(metadata_size)
            .ok_or_else(|| "Model header sizes overflow".to_string())?;
        let payload_end = payload_start
            .checked_add(payload_size)
            .ok_or_else(|| "Model header sizes overflow".to_string())?;
        if payload_end > bytes.len() {
            return Err(format!(
                "Payload extends beyond file: {} > {}",
                payload_end,
                bytes.len()
            ));
        }
        // payload_start <= payload_end <= bytes.len(), so this slice cannot panic.
        let payload_compressed = &bytes[payload_start..payload_end];

        // Borrow uncompressed payloads instead of copying them.
        let payload: std::borrow::Cow<'_, [u8]> = match compression {
            0x00 => std::borrow::Cow::Borrowed(payload_compressed),
            #[cfg(feature = "shell-autocomplete")]
            0x01 | 0x02 => std::borrow::Cow::Owned(
                zstd::decode_all(payload_compressed)
                    .map_err(|e| format!("Failed to decompress: {}", e))?,
            ),
            #[cfg(not(feature = "shell-autocomplete"))]
            0x01 | 0x02 => {
                return Err(
                    "Zstd compression requires the 'shell-autocomplete' feature".to_string()
                );
            }
            _ => return Err(format!("Unknown compression type: 0x{:02X}", compression)),
        };

        let model_data: MarkovModelData = bincode::deserialize(&payload)
            .map_err(|e| format!("Failed to deserialize model: {}", e))?;

        let mut trie = Trie::new();
        for cmd in model_data.command_freq.keys() {
            trie.insert(cmd);
        }
        Ok(Self {
            n: model_data.n,
            ngrams: model_data.ngrams,
            command_freq: model_data.command_freq,
            trie,
            total_commands: model_data.total_commands,
        })
    }

    /// Returns up to `count` completions for `prefix`, best first, as `(text, score)` pairs.
    ///
    /// Candidates come from three sources. When a text is produced more than once, its
    /// highest score is kept.
    /// * Every known command starting with `prefix` (leading whitespace ignored), scored by
    ///   its frequency divided by `total_commands`. Commands rejected by
    ///   `is_corrupted_command` are skipped.
    /// * When `prefix` ends in whitespace: n-gram continuations of its last `n - 1` tokens,
    ///   weighted by 0.8.
    /// * When `prefix` ends mid-word after at least one complete token: n-gram continuations
    ///   whose next token starts with that partial word, weighted by 0.9.
    ///
    /// An empty `prefix` therefore ranks the whole vocabulary by frequency. Equal scores are
    /// ordered alphabetically, so the output is deterministic.
    pub fn suggest(&self, prefix: &str, count: usize) -> Vec<(String, f32)> {
        // Leading whitespace carries no meaning, but trailing whitespace does: "git " means
        // the word `git` is finished and the next token is wanted. This must be detected
        // before trimming; trimming first made the next-token branch below unreachable.
        let input = prefix.trim_start();
        let ends_with_space = input.is_empty() || input.ends_with(char::is_whitespace);
        let prefix = input.trim_end();
        let tokens: Vec<&str> = prefix.split_whitespace().collect();
        let total_commands = self.total_commands.max(1) as f32;

        let mut scores: HashMap<String, f32> = HashMap::new();

        // Every known command extending the input. The walk is deliberately uncapped: a
        // capped walk returns an arbitrary (HashMap-order) subset and can drop the most
        // frequent matches before they are ever scored.
        for cmd in self.trie.find_prefix(input, usize::MAX) {
            if Self::is_corrupted_command(&cmd) {
                continue;
            }
            let freq = self.command_freq.get(&cmd).copied().unwrap_or(1);
            // Trie strings are unique and this is the first source, so a plain insert is fine.
            scores.insert(cmd, freq as f32 / total_commands);
        }

        if ends_with_space && !tokens.is_empty() {
            // A completed word: predict the next token from the trailing context.
            let context = self.context_key(&tokens);
            self.offer_continuations(&mut scores, &context, prefix, "", 0.8);
        }

        if !ends_with_space {
            // A word in progress: complete it from the context of the words before it.
            if let Some((&partial, context_tokens)) = tokens.split_last() {
                if !context_tokens.is_empty() {
                    let context = self.context_key(context_tokens);
                    let head = context_tokens.join(" ");
                    self.offer_continuations(&mut scores, &context, &head, partial, 0.9);
                }
            }
        }

        let mut ranked: Vec<(String, f32)> = scores.into_iter().collect();
        // Highest score first, alphabetical on ties. `total_cmp` is a genuine total order,
        // which the sorting routines require.
        let by_rank = |a: &(String, f32), b: &(String, f32)| {
            b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))
        };
        if count < ranked.len() {
            // Partition out the best `count` in linear time, then sort only those.
            ranked.select_nth_unstable_by(count, by_rank);
            ranked.truncate(count);
        }
        ranked.sort_unstable_by(by_rank);
        ranked
    }

    /// Joins the last `n - 1` of `tokens` with single spaces, the form used to look up
    /// `ngrams`. Saturates, so a degenerate `n == 0` model yields an empty context
    /// instead of underflowing.
    fn context_key(&self, tokens: &[&str]) -> String {
        let window = self.n.saturating_sub(1);
        tokens[tokens.len().saturating_sub(window)..].join(" ")
    }

    /// Offers `"<head> <token>"` for every n-gram continuation of `context` whose token
    /// starts with `partial` and is not corrupted. Each is scored by its share of the
    /// context's total count, multiplied by `weight`.
    fn offer_continuations(
        &self,
        scores: &mut HashMap<String, f32>,
        context: &str,
        head: &str,
        partial: &str,
        weight: f32,
    ) {
        let next_tokens = match self.ngrams.get(context) {
            Some(next_tokens) => next_tokens,
            None => return,
        };
        // u64 so large counts cannot overflow the sum; skip all-zero rows (0/0 = NaN).
        let total: u64 = next_tokens.values().map(|&c| u64::from(c)).sum();
        if total == 0 {
            return;
        }
        // One buffer reused across candidates; only new entries are cloned into `scores`.
        let mut completion = String::with_capacity(head.len() + 32);
        for (token, &ngram_count) in next_tokens {
            if !token.starts_with(partial) || Self::is_corrupted_token(token) {
                continue;
            }
            completion.clear();
            completion.push_str(head);
            completion.push(' ');
            completion.push_str(token);
            let score = ngram_count as f32 / total as f32 * weight;
            Self::offer(scores, &completion, score);
        }
    }

    /// Records `score` for `text`, keeping the higher score if `text` is already present.
    fn offer(scores: &mut HashMap<String, f32>, text: &str, score: f32) {
        match scores.get_mut(text) {
            Some(best) => {
                if score > *best {
                    *best = score;
                }
            }
            None => {
                scores.insert(text.to_owned(), score);
            }
        }
    }

    /// Heuristic filter for history entries that look mangled.
    ///
    /// Rejects commands that contain two consecutive spaces, end with a line-continuation
    /// backslash, or contain a token flagged by `is_corrupted_token`. Ordinary
    /// single-spaced commands such as `git commit -m` pass.
    fn is_corrupted_command(cmd: &str) -> bool {
        // Byte-wise check for "two spaces in a row". A check for a single space would
        // reject every multi-word command.
        let has_double_space = cmd
            .as_bytes()
            .windows(2)
            .any(|pair| pair[0] == b' ' && pair[1] == b' ');
        has_double_space
            || cmd.trim_end().ends_with('\\')
            || cmd.split_whitespace().any(Self::is_corrupted_token)
    }

    /// Detects a known subcommand glued to a short flag or option, e.g. `commit-m`.
    ///
    /// The token is split at its first `-`. It is flagged when the part before the dash is
    /// one of the listed subcommands and the part after is at most 2 bytes long or starts
    /// with another `-`. Ordinary hyphenated words such as `feature-branch` pass.
    fn is_corrupted_token(token: &str) -> bool {
        const SUBCOMMANDS: [&str; 13] = [
            "commit", "checkout", "clone", "push", "pull", "merge", "rebase", "status", "add",
            "build", "run", "test", "install",
        ];
        match token.split_once('-') {
            Some((before, after)) if !before.is_empty() && !after.is_empty() => {
                SUBCOMMANDS.contains(&before) && (after.len() <= 2 || after.starts_with('-'))
            }
            _ => false,
        }
    }

    /// Serializes `suggest(prefix, count)` as
    /// `{"suggestions":[{"text":"…","score":0.1234},…]}`, with scores at 4 decimals.
    ///
    /// Text is escaped per JSON rules: quotes, backslashes and control characters. This
    /// matters because shell commands routinely contain `\`.
    pub fn suggest_json(&self, prefix: &str, count: usize) -> String {
        let mut json = String::from(r#"{"suggestions":["#);
        for (i, (text, score)) in self.suggest(prefix, count).into_iter().enumerate() {
            if i > 0 {
                json.push(',');
            }
            json.push_str(r#"{"text":""#);
            Self::push_json_escaped(&mut json, &text);
            // JSON has no NaN/Infinity literals. Scores are finite by construction, but never
            // emit an unparseable document.
            let score = if score.is_finite() { score } else { 0.0 };
            json.push_str(&format!(r#"","score":{:.4}}}"#, score));
        }
        json.push_str("]}");
        json
    }

    /// Appends `text` to `out` as the body of a JSON string literal (RFC 8259 escaping).
    fn push_json_escaped(out: &mut String, text: &str) {
        for ch in text.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
                c => out.push(c),
            }
        }
    }

    /// Returns a JSON object with the model name, type, vocabulary size, n-gram order,
    /// n-gram entry count and total command count.
    pub fn model_info_json(&self) -> String {
        format!(
            r#"{{"model_name":"aprender-shell-base","model_type":"ngram_lm","vocab_size":{},"ngram_size":{},"ngram_count":{},"total_commands":{}}}"#,
            self.vocab_size(),
            self.n,
            self.ngram_count(),
            self.total_commands
        )
    }

    /// Number of distinct commands in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.command_freq.len()
    }

    /// Total number of (context, next token) entries across all contexts.
    pub fn ngram_count(&self) -> usize {
        self.ngrams.values().map(HashMap::len).sum()
    }

    /// The model's n-gram order.
    pub fn ngram_size(&self) -> usize {
        self.n
    }

    /// Rough memory estimate: bytes of all string keys plus 4 bytes per stored count, plus
    /// the struct itself. Hash-table overhead and the trie are not counted.
    pub fn estimated_memory_bytes(&self) -> usize {
        let ngram_size: usize = self
            .ngrams
            .iter()
            .map(|(k, v)| k.len() + v.keys().map(|k2| k2.len() + 4).sum::<usize>())
            .sum();
        let vocab_size: usize = self.command_freq.keys().map(|k| k.len() + 4).sum();
        ngram_size + vocab_size + std::mem::size_of::<Self>()
    }
}
/// JavaScript-facing wrapper around `ShellAutocomplete`, compiled only for `wasm32`.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
pub struct ShellAutocompleteDemo {
    inner: ShellAutocomplete,
}

#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl ShellAutocompleteDemo {
    /// Loads a model from caller-supplied bytes and logs its size to the browser console.
    /// Load errors are returned to JavaScript as string values.
    #[wasm_bindgen(js_name = "fromBytes")]
    pub fn from_bytes(bytes: &[u8]) -> Result<ShellAutocompleteDemo, JsValue> {
        console_error_panic_hook::set_once();
        match ShellAutocomplete::load_from_bytes(bytes) {
            Ok(inner) => {
                let summary = format!(
                    "ShellAutocomplete loaded from bytes: {} commands, {} n-grams",
                    inner.vocab_size(),
                    inner.ngram_count()
                );
                web_sys::console::log_1(&JsValue::from_str(&summary));
                Ok(ShellAutocompleteDemo { inner })
            }
            Err(message) => Err(JsValue::from_str(&message)),
        }
    }

    /// Loads the model embedded in the binary and logs its size to the browser console.
    /// Load errors are returned to JavaScript as string values.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Result<ShellAutocompleteDemo, JsValue> {
        console_error_panic_hook::set_once();
        match ShellAutocomplete::new() {
            Ok(inner) => {
                let summary = format!(
                    "ShellAutocomplete loaded (embedded): {} commands, {} n-grams",
                    inner.vocab_size(),
                    inner.ngram_count()
                );
                web_sys::console::log_1(&JsValue::from_str(&summary));
                Ok(ShellAutocompleteDemo { inner })
            }
            Err(message) => Err(JsValue::from_str(&message)),
        }
    }

    /// Suggestions for `prefix` as a JSON document (see `ShellAutocomplete::suggest_json`).
    #[wasm_bindgen]
    pub fn suggest(&self, prefix: &str, count: usize) -> String {
        let inner = &self.inner;
        inner.suggest_json(prefix, count)
    }

    /// Model summary as a JSON document (see `ShellAutocomplete::model_info_json`).
    #[wasm_bindgen]
    pub fn model_info(&self) -> String {
        let inner = &self.inner;
        inner.model_info_json()
    }

    /// Number of distinct commands in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        ShellAutocomplete::vocab_size(&self.inner)
    }

    /// Total number of n-gram entries.
    pub fn ngram_count(&self) -> usize {
        ShellAutocomplete::ngram_count(&self.inner)
    }

    /// The model's n-gram order.
    pub fn ngram_size(&self) -> usize {
        ShellAutocomplete::ngram_size(&self.inner)
    }

    /// Rough memory estimate for the loaded model, in bytes.
    pub fn memory_bytes(&self) -> usize {
        ShellAutocomplete::estimated_memory_bytes(&self.inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A prefix lookup returns only the strings under that prefix.
    #[test]
    fn test_trie_basic() {
        let mut trie = Trie::new();
        for cmd in ["git status", "git commit", "cargo build"] {
            trie.insert(cmd);
        }
        assert_eq!(trie.find_prefix("git", 10).len(), 2);
    }

    /// Glued subcommand flags are flagged; normal spacing and hyphenated names are not.
    #[test]
    fn test_corrupted_detection() {
        let cases = [
            ("git commit-m", true),
            ("git commit -m", false),
            ("git checkout feature-branch", false),
        ];
        for (cmd, expected) in cases {
            assert_eq!(
                ShellAutocomplete::is_corrupted_command(cmd),
                expected,
                "command: {:?}",
                cmd
            );
        }
    }
}