use ahash::AHashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
pub const COMPLETION_BUDGET_MS: u64 = 50;
#[derive(Debug, Clone, PartialEq)]
pub struct CompletionItem {
pub text: String,
pub display: String,
pub kind: CompletionKind,
pub score: f32,
pub description: Option<String>,
}
impl CompletionItem {
#[must_use]
pub fn new(text: impl Into<String>, kind: CompletionKind) -> Self {
let text = text.into();
Self {
display: text.clone(),
text,
kind,
score: 1.0,
description: None,
}
}
#[must_use]
pub fn with_display(mut self, display: impl Into<String>) -> Self {
self.display = display.into();
self
}
#[must_use]
pub const fn with_score(mut self, score: f32) -> Self {
self.score = score;
self
}
#[must_use]
pub fn with_description(mut self, desc: impl Into<String>) -> Self {
self.description = Some(desc.into());
self
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompletionKind {
Command,
File,
Directory,
Alias,
Variable,
Flag,
History,
Predicted,
Other,
}
#[derive(Debug, Clone)]
pub struct CompletionContext {
pub line: String,
pub cursor: usize,
pub word: String,
pub word_start: usize,
pub previous_words: Vec<String>,
pub cwd: PathBuf,
}
impl CompletionContext {
#[must_use]
pub fn from_line(line: &str, cursor: usize) -> Self {
let line = line.to_string();
let cursor = cursor.min(line.len());
let before_cursor = &line[..cursor];
let word_start = before_cursor
.rfind(|c: char| c.is_whitespace())
.map_or(0, |i| i + 1);
let word = before_cursor[word_start..].to_string();
let previous_words: Vec<String> = before_cursor[..word_start]
.split_whitespace()
.map(String::from)
.collect();
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("/"));
Self {
line,
cursor,
word,
word_start,
previous_words,
cwd,
}
}
#[must_use]
pub fn is_command_position(&self) -> bool {
self.previous_words.is_empty()
}
}
pub trait CompletionProvider: Send + Sync {
fn complete(&self, ctx: &CompletionContext) -> Vec<CompletionItem>;
fn name(&self) -> &str;
fn priority(&self) -> i32 {
0
}
}
pub trait MlCompletionProvider: Send + Sync {
fn predict(&self, ctx: &CompletionContext) -> Vec<CompletionItem>;
fn is_ready(&self) -> bool;
fn model_name(&self) -> &str;
}
pub struct AliasCompleter {
aliases: Arc<AHashMap<String, String>>,
}
impl AliasCompleter {
#[must_use]
pub fn new(aliases: Arc<AHashMap<String, String>>) -> Self {
Self { aliases }
}
}
impl CompletionProvider for AliasCompleter {
fn complete(&self, ctx: &CompletionContext) -> Vec<CompletionItem> {
if !ctx.is_command_position() {
return Vec::new();
}
self.aliases
.iter()
.filter(|(name, _)| name.starts_with(&ctx.word))
.map(|(name, expansion)| {
CompletionItem::new(name.clone(), CompletionKind::Alias)
.with_description(format!("→ {expansion}"))
})
.collect()
}
fn name(&self) -> &str {
"aliases"
}
fn priority(&self) -> i32 {
10 }
}
pub struct EnvCompleter;
impl CompletionProvider for EnvCompleter {
fn complete(&self, ctx: &CompletionContext) -> Vec<CompletionItem> {
if !ctx.word.starts_with('$') {
return Vec::new();
}
let prefix = ctx.word.trim_start_matches('$').trim_start_matches('{');
std::env::vars()
.filter(|(name, _)| name.starts_with(prefix))
.take(50) .map(|(name, value)| {
let truncated = if value.len() > 30 {
format!("{}...", &value[..27])
} else {
value
};
CompletionItem::new(format!("${name}"), CompletionKind::Variable)
.with_description(truncated)
})
.collect()
}
fn name(&self) -> &str {
"env"
}
}
pub struct PathCompleter;
impl CompletionProvider for PathCompleter {
fn complete(&self, ctx: &CompletionContext) -> Vec<CompletionItem> {
let path = if ctx.word.is_empty() {
ctx.cwd.clone()
} else if ctx.word.starts_with('/') {
PathBuf::from(&ctx.word)
} else if ctx.word.starts_with('~') {
dirs::home_dir()
.map(|h| h.join(ctx.word.trim_start_matches("~/")))
.unwrap_or_else(|| PathBuf::from(&ctx.word))
} else {
ctx.cwd.join(&ctx.word)
};
let (dir, prefix) = if path.is_dir() {
(path.clone(), String::new())
} else {
let parent = path.parent().unwrap_or(&ctx.cwd);
let prefix = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("")
.to_string();
(parent.to_path_buf(), prefix)
};
let entries = match std::fs::read_dir(&dir) {
Ok(entries) => entries,
Err(_) => return Vec::new(),
};
entries
.filter_map(|e| e.ok())
.filter(|e| {
e.file_name()
.to_str()
.is_some_and(|n| prefix.is_empty() || n.starts_with(&prefix))
})
.take(100) .map(|entry| {
let name = entry.file_name().to_string_lossy().to_string();
let is_dir = entry.file_type().map(|t| t.is_dir()).unwrap_or(false);
let kind = if is_dir {
CompletionKind::Directory
} else {
CompletionKind::File
};
let display = if is_dir {
format!("{name}/")
} else {
name.clone()
};
CompletionItem::new(display.clone(), kind).with_display(display)
})
.collect()
}
fn name(&self) -> &str {
"paths"
}
}
pub struct HistoryCompleter {
history: Vec<String>,
}
impl HistoryCompleter {
#[must_use]
pub fn new(history: Vec<String>) -> Self {
Self { history }
}
pub fn add(&mut self, entry: String) {
self.history.push(entry);
}
}
impl CompletionProvider for HistoryCompleter {
fn complete(&self, ctx: &CompletionContext) -> Vec<CompletionItem> {
if ctx.word.is_empty() {
return Vec::new();
}
self.history
.iter()
.rev() .filter(|entry| entry.starts_with(&ctx.word))
.take(10)
.map(|entry| CompletionItem::new(entry.clone(), CompletionKind::History))
.collect()
}
fn name(&self) -> &str {
"history"
}
fn priority(&self) -> i32 {
5 }
}
pub struct AprenderShellCompleter {
model_path: Option<PathBuf>,
ready: bool,
cache: AHashMap<String, Vec<(String, f32)>>,
}
impl AprenderShellCompleter {
#[must_use]
pub fn new() -> Self {
Self {
model_path: None,
ready: false,
cache: AHashMap::new(),
}
}
pub fn load_model(&mut self, path: PathBuf) -> Result<(), String> {
if !path.exists() {
return Err(format!("Model not found: {}", path.display()));
}
if path.extension().and_then(|e| e.to_str()) != Some("apr") {
return Err(format!(
"Invalid model format: expected .apr, got {}",
path.display()
));
}
self.model_path = Some(path);
self.ready = true;
self.cache.clear();
Ok(())
}
#[must_use]
pub fn model_path(&self) -> Option<&PathBuf> {
self.model_path.as_ref()
}
pub fn clear_cache(&mut self) {
self.cache.clear();
}
}
impl Default for AprenderShellCompleter {
fn default() -> Self {
Self::new()
}
}
impl MlCompletionProvider for AprenderShellCompleter {
fn predict(&self, ctx: &CompletionContext) -> Vec<CompletionItem> {
if !self.ready {
return Vec::new();
}
let _prefix = if ctx.previous_words.is_empty() {
ctx.word.clone()
} else {
format!("{} {}", ctx.previous_words.join(" "), ctx.word)
};
let mut predictions = Vec::new();
if ctx.is_command_position() {
let common = [
("git", 0.95),
("cd", 0.90),
("ls", 0.88),
("cargo", 0.85),
("docker", 0.82),
("npm", 0.80),
("grep", 0.78),
("cat", 0.75),
("find", 0.72),
("make", 0.70),
];
for (cmd, base_score) in &common {
if cmd.starts_with(&ctx.word) {
predictions.push(
CompletionItem::new((*cmd).to_string(), CompletionKind::Command)
.with_description("aprender-shell")
.with_score(*base_score),
);
}
}
} else if !ctx.previous_words.is_empty() {
let last_cmd = &ctx.previous_words[0];
let suggestions: Vec<(&str, f32)> = match last_cmd.as_str() {
"git" => vec![
("status", 0.92),
("commit", 0.90),
("push", 0.88),
("pull", 0.86),
("checkout", 0.84),
("add", 0.82),
("branch", 0.80),
("log", 0.78),
("diff", 0.76),
("stash", 0.74),
],
"docker" => vec![
("ps", 0.90),
("images", 0.88),
("run", 0.86),
("build", 0.84),
("compose", 0.82),
("exec", 0.80),
("stop", 0.78),
],
"cargo" => vec![
("build", 0.92),
("test", 0.90),
("run", 0.88),
("clippy", 0.86),
("fmt", 0.84),
("check", 0.82),
("doc", 0.80),
],
"npm" => vec![
("install", 0.90),
("run", 0.88),
("test", 0.86),
("start", 0.84),
("build", 0.82),
],
"kubectl" => vec![
("get", 0.92),
("apply", 0.90),
("describe", 0.88),
("logs", 0.86),
("exec", 0.84),
("delete", 0.82),
],
_ => vec![],
};
for (sug, score) in suggestions {
if sug.starts_with(&ctx.word) {
predictions.push(
CompletionItem::new(sug.to_string(), CompletionKind::Predicted)
.with_description("aprender-shell")
.with_score(score),
);
}
}
}
predictions
}
fn is_ready(&self) -> bool {
self.ready
}
fn model_name(&self) -> &str {
"aprender-shell"
}
}
pub struct CompletionEngine {
providers: Vec<Box<dyn CompletionProvider>>,
ml_provider: Option<Box<dyn MlCompletionProvider>>,
#[allow(dead_code)] cache: AHashMap<String, Vec<CompletionItem>>,
#[allow(dead_code)] cache_timeout: Duration,
}
impl CompletionEngine {
#[must_use]
pub fn new() -> Self {
Self {
providers: Vec::new(),
ml_provider: None,
cache: AHashMap::new(),
cache_timeout: Duration::from_millis(100),
}
}
pub fn add_provider(&mut self, provider: impl CompletionProvider + 'static) {
self.providers.push(Box::new(provider));
self.providers
.sort_by(|a, b| b.priority().cmp(&a.priority()));
}
pub fn set_ml_provider(&mut self, provider: impl MlCompletionProvider + 'static) {
self.ml_provider = Some(Box::new(provider));
}
#[must_use]
pub fn complete(&self, line: &str, cursor: usize) -> Vec<CompletionItem> {
let start = Instant::now();
let ctx = CompletionContext::from_line(line, cursor);
let mut results = Vec::new();
for provider in &self.providers {
if start.elapsed() > Duration::from_millis(COMPLETION_BUDGET_MS) {
break; }
results.extend(provider.complete(&ctx));
}
if let Some(ml) = &self.ml_provider {
if ml.is_ready() && start.elapsed() < Duration::from_millis(COMPLETION_BUDGET_MS) {
let predictions = ml.predict(&ctx);
results.extend(predictions);
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut seen = std::collections::HashSet::new();
results.retain(|item| seen.insert(item.text.clone()));
results
}
pub fn clear_cache(&mut self) {
self.cache.clear();
}
}
impl Default for CompletionEngine {
fn default() -> Self {
Self::new()
}
}
#[must_use]
pub fn default_engine(aliases: Arc<AHashMap<String, String>>) -> CompletionEngine {
let mut engine = CompletionEngine::new();
engine.add_provider(AliasCompleter::new(aliases));
engine.add_provider(EnvCompleter);
engine.add_provider(PathCompleter);
engine
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_completion_context_parsing() {
let ctx = CompletionContext::from_line("git sta", 7);
assert_eq!(ctx.word, "sta");
assert_eq!(ctx.previous_words, vec!["git"]);
assert!(!ctx.is_command_position());
}
#[test]
fn test_completion_context_command_position() {
let ctx = CompletionContext::from_line("gi", 2);
assert_eq!(ctx.word, "gi");
assert!(ctx.previous_words.is_empty());
assert!(ctx.is_command_position());
}
#[test]
fn test_alias_completer() {
let mut aliases = AHashMap::new();
aliases.insert("gs".to_string(), "git status".to_string());
aliases.insert("gp".to_string(), "git push".to_string());
aliases.insert("ll".to_string(), "ls -la".to_string());
let completer = AliasCompleter::new(Arc::new(aliases));
let ctx = CompletionContext::from_line("g", 1);
let results = completer.complete(&ctx);
assert_eq!(results.len(), 2); assert!(results.iter().any(|r| r.text == "gs"));
assert!(results.iter().any(|r| r.text == "gp"));
}
#[test]
fn test_alias_completer_not_command_position() {
let mut aliases = AHashMap::new();
aliases.insert("gs".to_string(), "git status".to_string());
let completer = AliasCompleter::new(Arc::new(aliases));
let ctx = CompletionContext::from_line("echo g", 6);
let results = completer.complete(&ctx);
assert!(results.is_empty()); }
#[test]
fn test_env_completer() {
let completer = EnvCompleter;
let ctx = CompletionContext::from_line("echo $PAT", 9);
let results = completer.complete(&ctx);
assert!(results.iter().any(|r| r.text == "$PATH"));
}
#[test]
fn test_env_completer_no_dollar() {
let completer = EnvCompleter;
let ctx = CompletionContext::from_line("echo PATH", 9);
let results = completer.complete(&ctx);
assert!(results.is_empty());
}
#[test]
fn test_completion_item_builder() {
let item = CompletionItem::new("test", CompletionKind::Command)
.with_display("Test Command")
.with_score(0.9)
.with_description("A test command");
assert_eq!(item.text, "test");
assert_eq!(item.display, "Test Command");
assert!((item.score - 0.9).abs() < f32::EPSILON);
assert_eq!(item.description, Some("A test command".to_string()));
}
#[test]
fn test_completion_engine() {
let mut aliases = AHashMap::new();
aliases.insert("gs".to_string(), "git status".to_string());
let engine = default_engine(Arc::new(aliases));
let results = engine.complete("g", 1);
assert!(results.iter().any(|r| r.text == "gs"));
}
#[test]
fn test_aprender_shell_stub() {
let completer = AprenderShellCompleter::new();
assert!(!completer.is_ready());
assert_eq!(completer.model_name(), "aprender-shell");
let ctx = CompletionContext::from_line("git", 3);
let results = completer.predict(&ctx);
assert!(results.is_empty()); }
#[test]
fn test_completion_performance() {
let mut aliases = AHashMap::new();
for i in 0..1000 {
aliases.insert(format!("alias{i}"), format!("command{i}"));
}
let engine = default_engine(Arc::new(aliases));
let start = Instant::now();
for _ in 0..100 {
let _ = engine.complete("alias", 5);
}
let elapsed = start.elapsed();
assert!(
elapsed < Duration::from_millis(500),
"Completion too slow: {:?}",
elapsed
);
}
#[test]
fn test_history_completer() {
let history = vec![
"git status".to_string(),
"git push".to_string(),
"cargo build".to_string(),
];
let completer = HistoryCompleter::new(history);
let ctx = CompletionContext::from_line("git", 3);
let results = completer.complete(&ctx);
assert_eq!(results.len(), 2);
}
#[test]
fn test_completion_context_from_line_basic() {
let ctx = CompletionContext::from_line("ls -la", 6);
assert_eq!(ctx.line, "ls -la");
assert_eq!(ctx.cursor, 6);
assert_eq!(ctx.word, "-la");
assert_eq!(ctx.word_start, 3);
assert!(!ctx.previous_words.is_empty());
}
#[test]
fn test_completion_context_is_command_position() {
let ctx = CompletionContext::from_line("git", 3);
assert!(ctx.is_command_position());
let ctx2 = CompletionContext::from_line("git status", 10);
assert!(!ctx2.is_command_position());
}
#[test]
fn test_completion_context_cursor_beyond_line() {
let ctx = CompletionContext::from_line("hello", 100);
assert_eq!(ctx.cursor, 5); }
#[test]
fn test_completion_context_empty_line() {
let ctx = CompletionContext::from_line("", 0);
assert!(ctx.word.is_empty());
assert!(ctx.previous_words.is_empty());
assert!(ctx.is_command_position());
}
#[test]
fn test_completion_kind_debug() {
let kinds = [
CompletionKind::Command,
CompletionKind::File,
CompletionKind::Directory,
CompletionKind::Alias,
CompletionKind::Variable,
CompletionKind::Flag,
CompletionKind::History,
CompletionKind::Predicted,
CompletionKind::Other,
];
for kind in kinds {
let debug = format!("{:?}", kind);
assert!(!debug.is_empty());
}
}
#[test]
fn test_completion_kind_equality() {
assert_eq!(CompletionKind::Command, CompletionKind::Command);
assert_ne!(CompletionKind::Command, CompletionKind::File);
}
#[test]
fn test_completion_item_default_score() {
let item = CompletionItem::new("test", CompletionKind::Command);
assert!((item.score - 1.0).abs() < f32::EPSILON);
}
#[test]
fn test_completion_item_clone() {
let item = CompletionItem::new("test", CompletionKind::Alias).with_description("desc");
let cloned = item.clone();
assert_eq!(item.text, cloned.text);
assert_eq!(item.kind, cloned.kind);
}
#[test]
fn test_alias_completer_name() {
let aliases = AHashMap::new();
let completer = AliasCompleter::new(Arc::new(aliases));
assert_eq!(completer.name(), "aliases");
}
#[test]
fn test_alias_completer_priority() {
let aliases = AHashMap::new();
let completer = AliasCompleter::new(Arc::new(aliases));
assert_eq!(completer.priority(), 10);
}
#[test]
fn test_env_completer_name() {
let completer = EnvCompleter;
assert_eq!(completer.name(), "env");
}
#[test]
fn test_path_completer_name() {
let completer = PathCompleter;
assert_eq!(completer.name(), "paths");
}
#[test]
fn test_history_completer_name() {
let completer = HistoryCompleter::new(vec![]);
assert_eq!(completer.name(), "history");
}
#[test]
fn test_history_completer_priority() {
let completer = HistoryCompleter::new(vec![]);
assert_eq!(completer.priority(), 5);
}
#[test]
fn test_aprender_shell_load_model_not_found() {
let mut completer = AprenderShellCompleter::new();
let result = completer.load_model(PathBuf::from("/nonexistent/model.apr"));
assert!(result.is_err());
assert!(result.unwrap_err().contains("not found"));
assert!(!completer.is_ready());
}
#[test]
fn test_aprender_shell_load_model_wrong_extension() {
let mut completer = AprenderShellCompleter::new();
let result = completer.load_model(PathBuf::from("Cargo.toml"));
assert!(result.is_err());
assert!(result.unwrap_err().contains("Invalid model format"));
}
#[test]
fn test_aprender_shell_default() {
let completer = AprenderShellCompleter::default();
assert!(!completer.is_ready());
assert!(completer.model_path().is_none());
}
#[test]
fn test_aprender_shell_predict_command_position() {
let mut completer = AprenderShellCompleter::new();
completer.ready = true;
let ctx = CompletionContext::from_line("gi", 2);
let predictions = completer.predict(&ctx);
assert!(!predictions.is_empty());
assert!(predictions.iter().any(|p| p.text == "git"));
}
#[test]
fn test_aprender_shell_predict_git_subcommands() {
let mut completer = AprenderShellCompleter::new();
completer.ready = true;
let ctx = CompletionContext::from_line("git st", 6);
let predictions = completer.predict(&ctx);
assert!(!predictions.is_empty());
assert!(
predictions
.iter()
.any(|p| p.text == "status" || p.text == "stash")
);
}
#[test]
fn test_aprender_shell_predict_cargo_subcommands() {
let mut completer = AprenderShellCompleter::new();
completer.ready = true;
let ctx = CompletionContext::from_line("cargo b", 7);
let predictions = completer.predict(&ctx);
assert!(!predictions.is_empty());
assert!(predictions.iter().any(|p| p.text == "build"));
}
#[test]
fn test_aprender_shell_predict_docker_subcommands() {
let mut completer = AprenderShellCompleter::new();
completer.ready = true;
let ctx = CompletionContext::from_line("docker p", 8);
let predictions = completer.predict(&ctx);
assert!(!predictions.is_empty());
assert!(predictions.iter().any(|p| p.text == "ps"));
}
#[test]
fn test_aprender_shell_predict_kubectl_subcommands() {
let mut completer = AprenderShellCompleter::new();
completer.ready = true;
let ctx = CompletionContext::from_line("kubectl g", 9);
let predictions = completer.predict(&ctx);
assert!(!predictions.is_empty());
assert!(predictions.iter().any(|p| p.text == "get"));
}
#[test]
fn test_aprender_shell_not_ready_returns_empty() {
let completer = AprenderShellCompleter::new();
assert!(!completer.is_ready());
let ctx = CompletionContext::from_line("git ", 4);
let predictions = completer.predict(&ctx);
assert!(predictions.is_empty());
}
#[test]
fn test_aprender_shell_clear_cache() {
let mut completer = AprenderShellCompleter::new();
completer.cache.insert("test".to_string(), vec![]);
assert!(!completer.cache.is_empty());
completer.clear_cache();
assert!(completer.cache.is_empty());
}
#[test]
fn test_aprender_shell_predictions_have_scores() {
let mut completer = AprenderShellCompleter::new();
completer.ready = true;
let ctx = CompletionContext::from_line("git ", 4);
let predictions = completer.predict(&ctx);
for pred in predictions {
assert!(pred.score > 0.0);
assert!(pred.score <= 1.0);
assert_eq!(pred.description.as_deref(), Some("aprender-shell"));
}
}
#[test]
fn test_completion_engine_add_provider() {
let mut engine = CompletionEngine::new();
let aliases = AHashMap::new();
engine.add_provider(AliasCompleter::new(Arc::new(aliases)));
}
#[test]
fn test_completion_engine_set_ml_provider() {
let mut engine = CompletionEngine::new();
engine.set_ml_provider(AprenderShellCompleter::new());
}
#[test]
fn test_path_completer_empty_word() {
let completer = PathCompleter;
let ctx = CompletionContext::from_line("cd ", 3);
let results = completer.complete(&ctx);
assert!(results.len() >= 0);
}
#[test]
fn test_path_completer_absolute_path() {
let completer = PathCompleter;
let ctx = CompletionContext::from_line("cd /", 4);
let results = completer.complete(&ctx);
assert!(!results.is_empty());
}
#[test]
fn test_path_completer_home_path() {
let completer = PathCompleter;
let ctx = CompletionContext::from_line("cd ~/", 5);
let results = completer.complete(&ctx);
assert!(results.len() >= 0);
}
#[test]
fn test_env_completer_with_brace() {
let completer = EnvCompleter;
let ctx = CompletionContext::from_line("echo ${PAT", 10);
let results = completer.complete(&ctx);
assert!(results.iter().any(|r| r.text.contains("PATH")));
}
#[test]
fn test_history_completer_empty_history() {
let completer = HistoryCompleter::new(vec![]);
let ctx = CompletionContext::from_line("git", 3);
let results = completer.complete(&ctx);
assert!(results.is_empty());
}
#[test]
fn test_history_completer_no_match() {
let history = vec!["cargo build".to_string()];
let completer = HistoryCompleter::new(history);
let ctx = CompletionContext::from_line("xyz", 3);
let results = completer.complete(&ctx);
assert!(results.is_empty());
}
#[test]
fn test_completion_context_debug() {
let ctx = CompletionContext::from_line("test", 4);
let debug = format!("{:?}", ctx);
assert!(debug.contains("test"));
}
#[test]
fn test_completion_context_clone() {
let ctx = CompletionContext::from_line("test", 4);
let cloned = ctx.clone();
assert_eq!(ctx.line, cloned.line);
assert_eq!(ctx.cursor, cloned.cursor);
}
}