/// Table of `(typo, correction)` pairs for commonly mistyped commands.
///
/// Entries are tried in the order listed. Two details worth knowing:
/// - `"git pul"` is a prefix of its own correction `"git pull"`, so any consumer
///   doing plain substring matching will also match already-correct text.
/// - `"carg "` includes a trailing space, so it only matches `carg` followed by a
///   space (presumably so that it can never match inside `cargo`).
const CORRECTIONS: &[(&str, &str)] = &[
    // git subcommands
    ("git stauts", "git status"),
    ("git comit", "git commit"),
    ("git psuh", "git push"),
    ("git pul", "git pull"),
    // tool names
    ("dokcer", "docker"),
    ("kubeclt", "kubectl"),
    ("carg ", "cargo "),
    ("pytohn", "python"),
    ("ndoe", "node"),
    ("yran", "yarn"),
];
/// Command names that are recognised as well-known CLI tools.
///
/// A suggestion whose first word appears in this list receives the higher base
/// score in `suggestion_quality_score`.
const KNOWN_TOOLS: &[&str] = &[
    // version control, build tools, containers and package managers
    "git",
    "cargo",
    "docker",
    "kubectl",
    "npm",
    "yarn",
    "make",
    // cloud provider CLIs
    "aws",
    "gcloud",
    "az",
    // infrastructure and deployment
    "terraform",
    "ansible",
    "helm",
    // interpreters and compilers
    "python",
    "node",
    "go",
    "rustc",
    "gcc",
    "clang",
];
/// Validates a completion `suggestion` against the `prefix` typed so far.
///
/// A suggestion is accepted only when it
/// - starts with `prefix` verbatim (checked before whitespace is collapsed),
/// - is strictly longer (in bytes) than the trimmed prefix once normalised,
/// - does not end in a dangling shell continuation character (`\`, `|` or `&`), and
/// - contains no control characters.
///
/// # Returns
///
/// `Some` with the suggestion trimmed and every run of whitespace collapsed to a
/// single space, or `None` if any of the checks above fails.
pub fn validate_suggestion(suggestion: &str, prefix: &str) -> Option<String> {
    if !suggestion.starts_with(prefix) {
        return None;
    }

    // Collapse all whitespace runs (spaces, tabs, newlines) into single spaces.
    let normalized = suggestion.split_whitespace().collect::<Vec<_>>().join(" ");

    // The suggestion must add something beyond what the user already typed.
    if normalized.len() <= prefix.trim().len() {
        return None;
    }

    // A trailing continuation/pipe/background operator means the command is unfinished.
    if normalized.ends_with(&['\\', '|', '&'][..]) {
        return None;
    }

    // No separate whitespace check is needed here: after normalisation the only
    // whitespace left is single spaces between words, and rejecting those would
    // reject every multi-word command. For the same reason tabs need no
    // exemption below; they were already folded into spaces.
    if normalized.chars().any(char::is_control) {
        return None;
    }

    Some(normalized)
}
/// Rewrites well-known command typos in `suggestion` using the `CORRECTIONS` table.
///
/// Each table entry is applied in order to the result of the previous one. A typo
/// is only replaced when it stands on its own: an occurrence glued to a longer word
/// is left alone. This keeps already-correct input intact (`git pull` contains the
/// typo `git pul`) and avoids rewriting unrelated words (`tyranny` contains `yran`).
///
/// # Returns
///
/// A new `String` with all stand-alone typos replaced; the input is never modified.
pub fn apply_typo_corrections(suggestion: &str) -> String {
    let mut corrected = suggestion.to_string();
    for &(typo, fix) in CORRECTIONS {
        // Cheap pre-check so the common "no typo" case does not rebuild the string.
        if corrected.contains(typo) {
            corrected = replace_whole_words(&corrected, typo, fix);
        }
    }
    corrected
}

/// Replaces occurrences of `typo` in `text` with `fix`, skipping those embedded in a longer word.
///
/// An occurrence is skipped when the character before it is alphanumeric or `_`, or
/// when the character after it is a letter or `_`. A trailing digit is allowed so
/// that versioned names such as `pytohn3` are still corrected. A boundary is only
/// enforced on an edge where the typo itself begins/ends with a word character, so
/// entries that carry their own delimiter (e.g. a trailing space) match as before.
fn replace_whole_words(text: &str, typo: &str, fix: &str) -> String {
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';
    let guard_start = typo.starts_with(is_word_char);
    let guard_end = typo.ends_with(is_word_char);

    let mut result = String::with_capacity(text.len() + fix.len());
    let mut copied_up_to = 0;
    for (start, matched) in text.match_indices(typo) {
        let end = start + matched.len();
        let glued_before =
            guard_start && text[..start].chars().next_back().map_or(false, is_word_char);
        let glued_after = guard_end
            && text[end..]
                .chars()
                .next()
                .map_or(false, |c| c.is_alphabetic() || c == '_');
        if glued_before || glued_after {
            continue;
        }
        result.push_str(&text[copied_up_to..start]);
        result.push_str(fix);
        copied_up_to = end;
    }
    result.push_str(&text[copied_up_to..]);
    result
}
/// Heuristically rates how plausible `suggestion` looks as a command, in `[0.0, 1.0]`.
///
/// The score is the product of four factors:
/// - **base**: `1.0` when the first whitespace-separated word is in `KNOWN_TOOLS`,
///   otherwise `0.8`;
/// - **length**: `0.5` when the suggestion is shorter than 5 bytes;
/// - **noise**: `1.0` minus `0.1` per character that is neither alphanumeric nor one
///   of space, `-`, `_`, `.`, `/`, `=`, `:`, `@`, with the deduction capped at `0.5`;
/// - **tail**: `0.7` when the suggestion ends with a space or `-`.
pub fn suggestion_quality_score(suggestion: &str) -> f32 {
    let base = match suggestion.split_whitespace().next() {
        Some(tool) if KNOWN_TOOLS.contains(&tool) => 1.0_f32,
        _ => 0.8_f32,
    };

    let length_factor = if suggestion.len() < 5 { 0.5_f32 } else { 1.0_f32 };

    let is_expected = |c: char| {
        c.is_alphanumeric() || matches!(c, ' ' | '-' | '_' | '.' | '/' | '=' | ':' | '@')
    };
    let noise = suggestion.chars().filter(|&c| !is_expected(c)).count();
    let noise_factor = 1.0 - (noise as f32 * 0.1).min(0.5);

    // A trailing space or dash suggests the command was cut off mid-argument.
    let tail_factor = if suggestion.ends_with(&[' ', '-'][..]) {
        0.7_f32
    } else {
        1.0_f32
    };

    (base * length_factor * noise_factor * tail_factor).clamp(0.0, 1.0)
}
/// Runs every `(suggestion, model_score)` pair through validation, typo correction
/// and quality scoring, keeping only the ones that pass.
///
/// For each input pair:
/// 1. `validate_suggestion` must accept the text for the given `prefix`;
/// 2. the validated text is passed through `apply_typo_corrections`;
/// 3. `suggestion_quality_score` of the corrected text must not be below `min_quality`.
///
/// # Returns
///
/// The surviving suggestions in their original order, each paired with
/// `model_score * quality`.
pub fn filter_quality_suggestions(
    suggestions: Vec<(String, f32)>,
    prefix: &str,
    min_quality: f32,
) -> Vec<(String, f32)> {
    let mut accepted = Vec::new();
    for (raw, model_score) in suggestions {
        let cleaned = match validate_suggestion(&raw, prefix) {
            Some(valid) => apply_typo_corrections(&valid),
            None => continue,
        };

        let quality = suggestion_quality_score(&cleaned);
        if quality < min_quality {
            continue;
        }

        accepted.push((cleaned, model_score * quality));
    }
    accepted
}
// Unit tests for suggestion validation, typo correction, scoring and filtering.
#[cfg(test)]
mod tests {
    use super::*;

    // --- validate_suggestion -------------------------------------------------

    #[test]
    fn test_valid_suggestion_accepted() {
        assert!(validate_suggestion("git status", "git").is_some());
        assert!(validate_suggestion("cargo build --release", "cargo").is_some());
    }

    #[test]
    fn test_suggestion_not_longer_than_prefix() {
        assert!(validate_suggestion("git", "git").is_none());
        assert!(validate_suggestion("git", "git ").is_none());
    }

    #[test]
    fn test_suggestion_must_start_with_prefix() {
        assert!(validate_suggestion("cargo build", "git").is_none());
    }

    #[test]
    fn test_trailing_backslash_rejected() {
        assert!(validate_suggestion("git status \\", "git").is_none());
    }

    #[test]
    fn test_trailing_pipe_rejected() {
        assert!(validate_suggestion("git status |", "git").is_none());
    }

    #[test]
    fn test_trailing_ampersand_rejected() {
        assert!(validate_suggestion("git status &", "git").is_none());
    }

    #[test]
    fn test_whitespace_normalized() {
        // The input deliberately mixes repeated spaces and a tab so the test
        // fails if whitespace runs stop being collapsed to single spaces.
        let result = validate_suggestion("git   status \t -v", "git").unwrap();
        assert_eq!(result, "git status -v");
    }

    #[test]
    fn test_control_chars_rejected() {
        assert!(validate_suggestion("git\x07status", "git").is_none());
    }

    // --- apply_typo_corrections ----------------------------------------------

    #[test]
    fn test_git_typos_corrected() {
        assert_eq!(apply_typo_corrections("git stauts"), "git status");
        assert_eq!(
            apply_typo_corrections("git comit -m 'test'"),
            "git commit -m 'test'"
        );
        assert_eq!(apply_typo_corrections("git psuh origin"), "git push origin");
    }

    #[test]
    fn test_tool_typos_corrected() {
        assert_eq!(apply_typo_corrections("dokcer ps"), "docker ps");
        assert_eq!(
            apply_typo_corrections("kubeclt get pods"),
            "kubectl get pods"
        );
    }

    #[test]
    fn test_no_false_corrections() {
        assert_eq!(apply_typo_corrections("git status"), "git status");
        assert_eq!(apply_typo_corrections("docker run"), "docker run");
    }

    // --- suggestion_quality_score --------------------------------------------

    #[test]
    fn test_known_tool_higher_score() {
        let git_score = suggestion_quality_score("git status");
        let unknown_score = suggestion_quality_score("xyz status");
        assert!(git_score > unknown_score);
    }

    #[test]
    fn test_short_suggestion_lower_score() {
        let short = suggestion_quality_score("git");
        let long = suggestion_quality_score("git status --verbose");
        assert!(short < long);
    }

    #[test]
    fn test_unusual_chars_lower_score() {
        let normal = suggestion_quality_score("git status");
        let unusual = suggestion_quality_score("git !@#$%");
        assert!(normal > unusual);
    }

    #[test]
    fn test_incomplete_lower_score() {
        let complete = suggestion_quality_score("git status");
        let incomplete = suggestion_quality_score("git status ");
        assert!(complete > incomplete);
    }

    #[test]
    fn test_score_bounded_zero_to_one() {
        assert!(suggestion_quality_score("git status") <= 1.0);
        assert!(suggestion_quality_score("git status") >= 0.0);
        assert!(suggestion_quality_score("!@#$%^&*") <= 1.0);
        assert!(suggestion_quality_score("!@#$%^&*") >= 0.0);
    }

    // --- filter_quality_suggestions ------------------------------------------

    #[test]
    fn test_filter_quality_suggestions() {
        let suggestions = vec![
            ("git status".to_string(), 0.9),
            // Typo: expected to be rewritten to "git status".
            ("git stauts".to_string(), 0.8),
            // Adds nothing beyond the prefix: expected to be dropped.
            ("git".to_string(), 0.7),
            ("git commit".to_string(), 0.6),
        ];
        let filtered = filter_quality_suggestions(suggestions, "git", 0.3);
        assert!(filtered.iter().any(|(s, _)| s == "git status"));
        assert!(filtered.iter().any(|(s, _)| s == "git commit"));
        assert!(!filtered.iter().any(|(s, _)| s == "git"));
        // The misspelled form must never reach the caller uncorrected.
        assert!(!filtered.iter().any(|(s, _)| s == "git stauts"));
    }
}