#![allow(clippy::missing_const_for_fn)]
use crate::normalize::strip_wrapper_prefixes;
use regex::{Regex, escape as regex_escape};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Minimum Jaccard similarity for two commands to share a cluster.
const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.30;
/// Maximum number of variants spelled out as a regex alternation before
/// falling back to a conservative character class.
const MAX_ALTERNATION_COUNT: usize = 10;
/// Denial frequency at/above which a suggestion may be rated high confidence.
const HIGH_CONFIDENCE_MIN_FREQUENCY: usize = 10;
/// Denial frequency at/above which a suggestion is at least medium confidence.
const MEDIUM_CONFIDENCE_MIN_FREQUENCY: usize = 5;
/// Share of occurrences under one path prefix needed to call it "clustered".
const PATH_CLUSTER_THRESHOLD: f32 = 0.7;
/// How confident the heuristics are that a suggestion reflects real intent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConfidenceTier {
    High,
    Medium,
    Low,
}

impl ConfidenceTier {
    /// Stable lowercase identifier used for display and serialization.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            ConfidenceTier::High => "high",
            ConfidenceTier::Medium => "medium",
            ConfidenceTier::Low => "low",
        }
    }

    /// Base weight contributed to a suggestion's ranking score.
    #[must_use]
    pub const fn score(&self) -> f32 {
        match *self {
            ConfidenceTier::High => 1.0,
            ConfidenceTier::Medium => 0.6,
            ConfidenceTier::Low => 0.3,
        }
    }
}

impl std::fmt::Display for ConfidenceTier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Estimated danger of automatically allowing a command cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RiskLevel {
    Low,
    Medium,
    High,
}

impl RiskLevel {
    /// Stable lowercase identifier used for display and serialization.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            RiskLevel::Low => "low",
            RiskLevel::Medium => "medium",
            RiskLevel::High => "high",
        }
    }

    /// Penalty weight applied when scoring a suggestion.
    #[must_use]
    pub const fn score(&self) -> f32 {
        match *self {
            RiskLevel::Low => 0.2,
            RiskLevel::Medium => 0.5,
            RiskLevel::High => 0.9,
        }
    }
}

impl std::fmt::Display for RiskLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Why a suggestion was surfaced to the user.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SuggestionReason {
    HighFrequency,
    PathClustered,
    ManuallyBypassed,
    SafePatternMatch,
}

impl SuggestionReason {
    /// Stable snake_case identifier used for display and serialization.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            SuggestionReason::HighFrequency => "high_frequency",
            SuggestionReason::PathClustered => "path_clustered",
            SuggestionReason::ManuallyBypassed => "manually_bypassed",
            SuggestionReason::SafePatternMatch => "safe_pattern_match",
        }
    }

    /// Short human-readable explanation of the reason.
    #[must_use]
    pub const fn description(&self) -> &'static str {
        match *self {
            SuggestionReason::HighFrequency => "Blocked many times across sessions",
            SuggestionReason::PathClustered => "Consistently blocked in specific directories",
            SuggestionReason::ManuallyBypassed => "Blocked but then allowed manually",
            SuggestionReason::SafePatternMatch => "Matches a known safe command pattern",
        }
    }
}

impl std::fmt::Display for SuggestionReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// A directory prefix shared by many of the working directories in which a
/// command was denied.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PathPattern {
    /// The shared absolute path prefix (e.g. `/home/user/project`).
    pub pattern: String,
    /// How many observed working directories fall under this prefix.
    pub occurrence_count: usize,
    /// Heuristic guess that the prefix points at a project directory.
    pub is_project_dir: bool,
}
/// A candidate allowlist entry derived from repeatedly denied commands,
/// annotated with confidence, risk, and scoring metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllowlistSuggestion {
    /// The cluster of similar commands this suggestion covers.
    pub cluster: CommandCluster,
    /// How confident the heuristics are that allowing this is intended.
    pub confidence: ConfidenceTier,
    /// Estimated danger of auto-allowing the clustered commands.
    pub risk: RiskLevel,
    /// The single strongest reason for surfacing the suggestion.
    pub reason: SuggestionReason,
    /// Additional reasons that also applied.
    pub contributing_factors: Vec<SuggestionReason>,
    /// Directory prefixes where the denied commands were run.
    pub path_patterns: Vec<PathPattern>,
    /// Whether a path-scoped (rather than global) rule is recommended.
    pub suggest_path_specific: bool,
    /// Times the command was denied but then manually allowed.
    pub bypass_count: usize,
    /// Combined ranking score in `0.0..=1.0` (higher ranks first).
    pub score: f32,
}
impl AllowlistSuggestion {
    /// Builds a baseline suggestion from a cluster, deriving confidence,
    /// risk, primary reason, and score. Path and bypass evidence can be
    /// layered on with [`Self::with_path_analysis`] and
    /// [`Self::with_bypass_count`].
    #[must_use]
    pub fn from_cluster(cluster: CommandCluster) -> Self {
        let confidence = calculate_confidence_tier(cluster.frequency, cluster.unique_count);
        let risk = assess_risk_level(&cluster.commands);
        let reason = determine_primary_reason(cluster.frequency, false, &[]);
        let score = calculate_suggestion_score(confidence, risk);
        Self {
            cluster,
            confidence,
            risk,
            reason,
            contributing_factors: Vec::new(),
            path_patterns: Vec::new(),
            suggest_path_specific: false,
            bypass_count: 0,
            score,
        }
    }

    /// Incorporates working-directory evidence; strong path clustering may
    /// raise a Low confidence to Medium, and always refreshes the score.
    pub fn with_path_analysis(mut self, working_dirs: &[String]) -> Self {
        let (patterns, suggest_path_specific) = analyze_path_patterns(working_dirs);
        self.path_patterns = patterns;
        self.suggest_path_specific = suggest_path_specific;
        if suggest_path_specific
            && !self
                .contributing_factors
                .contains(&SuggestionReason::PathClustered)
        {
            self.contributing_factors
                .push(SuggestionReason::PathClustered);
            if self.confidence == ConfidenceTier::Low {
                self.confidence = ConfidenceTier::Medium;
            }
        }
        self.score = calculate_suggestion_score(self.confidence, self.risk);
        self
    }

    /// Records manual bypasses; any bypass promotes the suggestion to high
    /// confidence with `ManuallyBypassed` as the primary reason.
    pub fn with_bypass_count(mut self, count: usize) -> Self {
        self.bypass_count = count;
        if count > 0 {
            // Guard against duplicate factors when the builder is applied
            // more than once — mirrors the check in `with_path_analysis`
            // (the original pushed unconditionally).
            if !self
                .contributing_factors
                .contains(&SuggestionReason::ManuallyBypassed)
            {
                self.contributing_factors
                    .push(SuggestionReason::ManuallyBypassed);
            }
            self.confidence = ConfidenceTier::High;
            self.reason = SuggestionReason::ManuallyBypassed;
        }
        self.score = calculate_suggestion_score(self.confidence, self.risk);
        self
    }
}
/// Maps denial frequency (and variant consistency) onto a confidence tier.
///
/// Below the medium threshold everything is Low; below the high threshold,
/// Medium. High-frequency commands are High only when each distinct variant
/// repeats on average at least twice (frequency / variants >= 2).
#[must_use]
pub fn calculate_confidence_tier(frequency: usize, unique_variants: usize) -> ConfidenceTier {
    if frequency < MEDIUM_CONFIDENCE_MIN_FREQUENCY {
        return ConfidenceTier::Low;
    }
    if frequency < HIGH_CONFIDENCE_MIN_FREQUENCY {
        return ConfidenceTier::Medium;
    }
    #[allow(clippy::cast_precision_loss)]
    let consistency_ratio = if unique_variants > 0 {
        frequency as f32 / unique_variants as f32
    } else {
        0.0
    };
    if consistency_ratio >= 2.0 {
        ConfidenceTier::High
    } else {
        ConfidenceTier::Medium
    }
}
/// Estimates how dangerous it would be to auto-allow `commands`.
///
/// Returns the highest risk found across all commands: destructive
/// substrings (`rm -rf`, `--force`, …) are High; deletions, git history
/// manipulation, and `sudo` are Medium; everything else is Low.
///
/// Fixed: the original returned `Medium` as soon as the first medium-risk
/// command was seen, which masked a high-risk command appearing later in the
/// slice (e.g. `["rm x", "rm -rf y"]` incorrectly yielded `Medium`). The scan
/// now keeps going and only returns early on `High`, which is the maximum.
#[must_use]
pub fn assess_risk_level(commands: &[String]) -> RiskLevel {
    let mut highest = RiskLevel::Low;
    for cmd in commands {
        let lower = cmd.to_lowercase();
        if lower.contains("rm -rf")
            || lower.contains("rmdir")
            || lower.contains("drop ")
            || lower.contains("truncate ")
            || lower.contains("delete ")
            || lower.contains("--force")
            || lower.contains("-f ")
            || lower.contains("reset --hard")
            || lower.contains("clean -f")
        {
            // High is the maximum; no later command can change the outcome.
            return RiskLevel::High;
        }
        if lower.contains("rm ")
            || lower.contains("git reset")
            || lower.contains("git checkout")
            || lower.contains("git restore")
            || lower.contains("docker rm")
            || lower.contains("docker rmi")
            || lower.contains("kubectl delete")
            || lower.starts_with("sudo ")
        {
            // Remember medium risk but keep scanning for a high-risk command.
            highest = RiskLevel::Medium;
        }
    }
    highest
}
/// Picks the single strongest reason to surface a suggestion.
///
/// Priority: manual bypasses first, then concentration of occurrences in the
/// top path pattern (callers pass patterns sorted most-frequent-first), then
/// frequency as the default.
#[must_use]
pub fn determine_primary_reason(
    frequency: usize,
    has_bypasses: bool,
    path_patterns: &[PathPattern],
) -> SuggestionReason {
    if has_bypasses {
        return SuggestionReason::ManuallyBypassed;
    }
    if let Some(top_pattern) = path_patterns.first() {
        let total_occurrences: usize = path_patterns.iter().map(|p| p.occurrence_count).sum();
        #[allow(clippy::cast_precision_loss)]
        let concentration =
            top_pattern.occurrence_count as f32 / total_occurrences.max(1) as f32;
        if concentration >= PATH_CLUSTER_THRESHOLD {
            return SuggestionReason::PathClustered;
        }
    }
    // The original branched on `frequency >= MEDIUM_CONFIDENCE_MIN_FREQUENCY`
    // here, but both arms returned `HighFrequency`; the redundant branch is
    // collapsed. The parameter is kept for API compatibility.
    let _ = frequency;
    SuggestionReason::HighFrequency
}
/// Combines confidence and risk into a ranking score in `0.0..=1.0`.
///
/// Confidence sets the base; risk shaves off up to 40% of it.
#[must_use]
pub fn calculate_suggestion_score(confidence: ConfidenceTier, risk: RiskLevel) -> f32 {
    let penalty_factor = 1.0 - 0.4 * risk.score();
    (confidence.score() * penalty_factor).clamp(0.0, 1.0)
}
/// Finds directory prefixes shared by at least half of `working_dirs`.
///
/// Returns up to three non-overlapping prefixes (most frequent first, ties
/// broken by preferring the more specific prefix) and a flag saying whether a
/// path-specific rule is worth suggesting: the top prefix must cover at least
/// `PATH_CLUSTER_THRESHOLD` of the observations and look like a project dir.
///
/// Fixed: the original also built a `dir_counts` map of exact directories
/// that was never read; the dead code has been removed.
#[must_use]
pub fn analyze_path_patterns(working_dirs: &[String]) -> (Vec<PathPattern>, bool) {
    if working_dirs.is_empty() {
        return (Vec::new(), false);
    }
    // Count every path prefix up to 5 components deep.
    let mut prefix_counts: HashMap<String, usize> = HashMap::new();
    for dir in working_dirs {
        let components: Vec<&str> = dir.split('/').filter(|s| !s.is_empty()).collect();
        for len in 1..=components.len().min(5) {
            let prefix = format!("/{}", components[..len].join("/"));
            *prefix_counts.entry(prefix).or_insert(0) += 1;
        }
    }
    let total_count = working_dirs.len();
    // Keep prefixes covering at least half of the observed directories.
    let mut patterns: Vec<PathPattern> = prefix_counts
        .into_iter()
        .filter(|(_, count)| {
            #[allow(clippy::cast_precision_loss)]
            let coverage = *count as f32 / total_count as f32;
            coverage >= 0.5
        })
        .map(|(prefix, count)| {
            let is_project_dir = is_likely_project_dir(&prefix);
            PathPattern {
                pattern: prefix,
                occurrence_count: count,
                is_project_dir,
            }
        })
        .collect();
    // Most frequent first; ties prefer the longer (more specific) prefix.
    patterns.sort_by(|a, b| {
        b.occurrence_count
            .cmp(&a.occurrence_count)
            .then_with(|| b.pattern.len().cmp(&a.pattern.len()))
    });
    // Drop any prefix that nests inside (or contains) one already kept.
    let mut seen_prefixes: HashSet<String> = HashSet::new();
    patterns.retain(|p| {
        for seen in &seen_prefixes {
            if p.pattern.starts_with(seen.as_str()) || seen.starts_with(&p.pattern) {
                return false;
            }
        }
        seen_prefixes.insert(p.pattern.clone());
        true
    });
    patterns.truncate(3);
    #[allow(clippy::cast_precision_loss)]
    let suggest_path_specific = patterns.first().is_some_and(|p| {
        let coverage = p.occurrence_count as f32 / total_count as f32;
        coverage >= PATH_CLUSTER_THRESHOLD && p.is_project_dir
    });
    (patterns, suggest_path_specific)
}
/// Heuristically decides whether `path` looks like a user project directory
/// rather than a system location.
#[must_use]
fn is_likely_project_dir(path: &str) -> bool {
    // System locations are never project dirs, even if a later keyword
    // heuristic would otherwise match.
    const NON_PROJECT_PATHS: [&str; 8] = [
        "/tmp", "/var/tmp", "/etc", "/usr", "/bin", "/sbin", "/root", "/",
    ];
    if NON_PROJECT_PATHS
        .iter()
        .any(|root| path == *root || path.starts_with(&format!("{root}/")))
    {
        return false;
    }
    // Common locations where user code lives.
    const PROJECT_INDICATORS: [&str; 9] = [
        "/home/",
        "/Users/",
        "/data/projects/",
        "/workspace/",
        "/repo/",
        "/repos/",
        "/src/",
        "/code/",
        "/projects/",
    ];
    if PROJECT_INDICATORS
        .iter()
        .any(|ind| path.starts_with(ind) || path.contains(ind))
    {
        return true;
    }
    // Fall back to case-insensitive keyword matching.
    let lower = path.to_lowercase();
    ["project", "workspace", "repo", "/src/", "/code/"]
        .iter()
        .any(|kw| lower.contains(*kw))
}
/// One observed denial event, as fed into suggestion generation.
#[derive(Debug, Clone)]
pub struct CommandEntryInfo {
    /// The denied command line.
    pub command: String,
    /// Directory the command was run from.
    pub working_dir: String,
    /// Whether the user manually allowed the command after the denial.
    pub was_bypassed: bool,
}
/// Turns raw denial events into ranked allowlist suggestions.
///
/// Commands seen fewer than `min_frequency` times are dropped; the rest are
/// clustered by similarity, enriched with path and bypass evidence, and
/// sorted best-score-first.
#[must_use]
pub fn generate_enhanced_suggestions(
    entries: &[CommandEntryInfo],
    min_frequency: usize,
) -> Vec<AllowlistSuggestion> {
    if entries.is_empty() {
        return Vec::new();
    }
    // Aggregate per command: (frequency, working dirs, bypass count).
    let mut command_data: HashMap<String, (usize, Vec<String>, usize)> = HashMap::new();
    for entry in entries {
        let slot = command_data
            .entry(entry.command.clone())
            .or_insert((0, Vec::new(), 0));
        slot.0 += 1;
        slot.1.push(entry.working_dir.clone());
        if entry.was_bypassed {
            slot.2 += 1;
        }
    }
    // Only commands seen often enough participate in clustering.
    let frequent: Vec<(String, usize)> = command_data
        .iter()
        .filter(|(_, (freq, _, _))| *freq >= min_frequency)
        .map(|(cmd, (freq, _, _))| (cmd.clone(), *freq))
        .collect();
    if frequent.is_empty() {
        return Vec::new();
    }
    let mut suggestions: Vec<AllowlistSuggestion> = cluster_denied_commands(&frequent, 1)
        .into_iter()
        .map(|cluster| {
            let working_dirs: Vec<String> = cluster
                .commands
                .iter()
                .filter_map(|cmd| command_data.get(cmd))
                .flat_map(|(_, dirs, _)| dirs.clone())
                .collect();
            let bypass_count: usize = cluster
                .commands
                .iter()
                .filter_map(|cmd| command_data.get(cmd))
                .map(|(_, _, bypasses)| *bypasses)
                .sum();
            AllowlistSuggestion::from_cluster(cluster)
                .with_path_analysis(&working_dirs)
                .with_bypass_count(bypass_count)
        })
        .collect();
    // Best suggestions first; scores are finite so Equal is a safe fallback.
    suggestions.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    suggestions
}
/// Keeps only the suggestions whose confidence equals `tier`,
/// preserving their relative order.
#[must_use]
pub fn filter_by_confidence(
    suggestions: Vec<AllowlistSuggestion>,
    tier: ConfidenceTier,
) -> Vec<AllowlistSuggestion> {
    let mut kept = suggestions;
    kept.retain(|s| s.confidence == tier);
    kept
}
/// Keeps only the suggestions whose risk equals `level`,
/// preserving their relative order.
#[must_use]
pub fn filter_by_risk(
    suggestions: Vec<AllowlistSuggestion>,
    level: RiskLevel,
) -> Vec<AllowlistSuggestion> {
    let mut kept = suggestions;
    kept.retain(|s| s.risk == level);
    kept
}
/// A group of similar denied commands treated as one allowlist candidate.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CommandCluster {
    /// Original command strings, deduplicated, in first-seen order.
    pub commands: Vec<String>,
    /// Normalized forms (wrapper prefixes stripped, whitespace collapsed).
    pub normalized: Vec<String>,
    /// Anchored regex proposed to match the normalized commands.
    pub proposed_pattern: String,
    /// Total number of denials across all commands in the cluster.
    pub frequency: usize,
    /// Number of distinct normalized command forms.
    pub unique_count: usize,
}
/// Result of generalizing a command cluster into a single regex.
#[derive(Debug, Clone, PartialEq)]
pub struct GeneratedPattern {
    /// The generated (anchored) regular expression.
    pub regex: String,
    /// Heuristic 0.0..=1.0 estimate of how tightly the regex matches.
    pub specificity_score: f32,
    /// Whether the regex matched every command it was generated from.
    pub matches_all: bool,
    /// Up to three sample commands the regex matched.
    pub example_matches: Vec<String>,
}
/// Generalizes a cluster of commands into one anchored regex.
///
/// A single distinct command becomes an exact escaped match; otherwise the
/// commands are tokenized and split into a shared prefix, variable middle,
/// and shared suffix, and the regex is validated against its own inputs.
#[must_use]
pub fn generate_pattern_from_cluster(commands: &[String]) -> GeneratedPattern {
    // Nothing to build from: empty regex with a zero score.
    if commands.is_empty() {
        return GeneratedPattern {
            regex: String::new(),
            specificity_score: 0.0,
            matches_all: true,
            example_matches: Vec::new(),
        };
    }
    let unique = deduplicate_commands(commands);
    // One distinct command: exact, fully escaped match.
    if unique.len() == 1 {
        let regex = format!("^{}$", regex_escape(&unique[0]));
        return GeneratedPattern {
            regex,
            specificity_score: 1.0,
            matches_all: true,
            example_matches: unique,
        };
    }
    let tokenized: Vec<Vec<&str>> = unique
        .iter()
        .map(|cmd| cmd.split_whitespace().collect())
        .collect();
    let prefix_len = find_common_prefix_length(&tokenized);
    let suffix_len = find_common_suffix_length(&tokenized, prefix_len);
    let regex = build_segmented_pattern(&tokenized, prefix_len, suffix_len);
    let (matches_all, example_matches) = validate_pattern_against_commands(&regex, &unique);
    GeneratedPattern {
        specificity_score: calculate_pattern_specificity(&regex, unique.len()),
        regex,
        matches_all,
        example_matches,
    }
}
/// Removes duplicate commands while preserving first-seen order.
fn deduplicate_commands(commands: &[String]) -> Vec<String> {
    let mut seen: HashSet<&String> = HashSet::new();
    commands
        .iter()
        .filter(|cmd| seen.insert(*cmd))
        .cloned()
        .collect()
}
/// Number of leading tokens shared by every tokenized command.
fn find_common_prefix_length(tokenized: &[Vec<&str>]) -> usize {
    let Some(first) = tokenized.first() else {
        return 0;
    };
    // No prefix can be longer than the shortest command.
    let min_len = tokenized.iter().map(Vec::len).min().unwrap_or(0);
    (0..min_len)
        .find(|&i| !tokenized.iter().all(|t| t[i] == first[i]))
        .unwrap_or(min_len)
}
/// Number of trailing tokens shared by every tokenized command, never
/// overlapping the already-claimed `prefix_len` leading tokens.
fn find_common_suffix_length(tokenized: &[Vec<&str>], prefix_len: usize) -> usize {
    let Some(first) = tokenized.first() else {
        return 0;
    };
    let min_len = tokenized.iter().map(Vec::len).min().unwrap_or(0);
    // Suffix may only use tokens the prefix did not consume.
    let max_suffix = min_len.saturating_sub(prefix_len);
    for offset in 0..max_suffix {
        let token = first[first.len() - 1 - offset];
        if !tokenized.iter().all(|t| t[t.len() - 1 - offset] == token) {
            return offset;
        }
    }
    max_suffix
}
/// Builds an anchored regex from tokenized commands that share `prefix_len`
/// leading and `suffix_len` trailing tokens, expressing the variable middle
/// as either a literal, a sorted alternation, or a conservative class.
fn build_segmented_pattern(
    tokenized: &[Vec<&str>],
    prefix_len: usize,
    suffix_len: usize,
) -> String {
    if tokenized.is_empty() {
        return String::new();
    }
    let mut parts = Vec::new();
    // Shared leading tokens, escaped literally.
    if prefix_len > 0 {
        let prefix_tokens: Vec<&str> = tokenized[0][..prefix_len].to_vec();
        for token in prefix_tokens {
            parts.push(regex_escape(token));
        }
    }
    let first = &tokenized[0];
    let first_len = first.len();
    let middle_start = prefix_len;
    let middle_end = first_len.saturating_sub(suffix_len);
    if middle_start < middle_end {
        // Collect the distinct middle segments across all commands,
        // preserving first-seen order.
        let mut middle_variants: Vec<String> = Vec::new();
        let mut seen_middles = HashSet::new();
        for tokens in tokenized {
            let tokens_len = tokens.len();
            let var_end = tokens_len.saturating_sub(suffix_len);
            if middle_start < var_end {
                let middle: Vec<&str> = tokens[middle_start..var_end].to_vec();
                let middle_str = middle.join(" ");
                if seen_middles.insert(middle_str.clone()) {
                    middle_variants.push(middle_str);
                }
            }
        }
        if !middle_variants.is_empty() {
            if middle_variants.len() == 1 {
                // Single middle variant: emit it as an escaped literal.
                let escaped: Vec<String> = middle_variants[0]
                    .split_whitespace()
                    .map(regex_escape)
                    .collect();
                parts.extend(escaped);
            } else if middle_variants.len() <= MAX_ALTERNATION_COUNT {
                // Few variants: spell them out as a sorted alternation so the
                // output is deterministic regardless of input order.
                let alternatives: Vec<String> = middle_variants
                    .iter()
                    .map(|v| {
                        v.split_whitespace()
                            .map(regex_escape)
                            .collect::<Vec<_>>()
                            .join(r"\s+")
                    })
                    .collect();
                let mut sorted_alternatives = alternatives;
                sorted_alternatives.sort();
                let alternation = format!("(?:{})", sorted_alternatives.join("|"));
                parts.push(alternation);
            } else {
                // Too many variants: fall back to a bounded character class.
                let pattern = build_conservative_variable_pattern(&middle_variants);
                parts.push(pattern);
            }
        }
    }
    // Shared trailing tokens, escaped literally.
    if suffix_len > 0 {
        let suffix_start = first_len - suffix_len;
        let suffix_tokens: Vec<&str> = first[suffix_start..].to_vec();
        for token in suffix_tokens {
            parts.push(regex_escape(token));
        }
    }
    format!("^{}$", parts.join(r"\s+"))
}
/// Chooses the tightest character class that still covers every variant of a
/// variable command segment.
fn build_conservative_variable_pattern(variants: &[String]) -> String {
    // Multi-word variants get the loosest class (word chars plus whitespace).
    if variants.iter().any(|v| v.contains(' ')) {
        return r"[\w\s.-]+".to_string();
    }
    // Single-token variants: check classes from tightest to loosest.
    let is_uuid_like =
        |v: &String| v.len() >= 32 && v.chars().all(|c| c.is_ascii_hexdigit() || c == '-');
    let is_hexish = |v: &String| {
        v.chars()
            .all(|c| c.is_ascii_hexdigit() || c == '-' || c == '_')
    };
    if variants.iter().all(|v| v.chars().all(|c| c.is_ascii_digit())) {
        r"\d+".to_string()
    } else if variants.iter().all(is_uuid_like) {
        r"[0-9a-fA-F-]{32,}".to_string()
    } else if variants.iter().all(is_hexish) {
        r"[0-9a-fA-F_-]+".to_string()
    } else {
        r"[\w.-]+".to_string()
    }
}
/// Compiles `pattern` and checks it against the source commands.
///
/// Returns whether every command matched, plus up to three matched examples.
/// An uncompilable pattern matches nothing by definition.
fn validate_pattern_against_commands(pattern: &str, commands: &[String]) -> (bool, Vec<String>) {
    let Ok(regex) = Regex::new(pattern) else {
        return (false, Vec::new());
    };
    let mut all_matched = true;
    let mut examples = Vec::new();
    for cmd in commands {
        if !regex.is_match(cmd) {
            all_matched = false;
        } else if examples.len() < 3 {
            examples.push(cmd.clone());
        }
    }
    (all_matched, examples)
}
fn calculate_pattern_specificity(pattern: &str, command_count: usize) -> f32 {
let mut score = 1.0_f32;
if pattern.contains(".*") {
score -= 0.4;
}
if pattern.contains(".+") {
score -= 0.3;
}
if pattern.contains(r"[\w\s") {
score -= 0.2;
}
if pattern.contains(r"[\w.-]+") {
score -= 0.15;
}
if pattern.contains(r"\d+") {
score -= 0.1;
}
if pattern.starts_with('^') && pattern.ends_with('$') {
score += 0.1;
}
let alternation_count = pattern.matches('|').count();
if alternation_count > 0 && alternation_count <= MAX_ALTERNATION_COUNT {
score += 0.1;
} else if alternation_count > MAX_ALTERNATION_COUNT {
score -= 0.1;
}
if pattern.len() < 10 {
score -= 0.2;
}
if command_count <= 5 {
score += 0.1;
}
score.clamp(0.0, 1.0)
}
/// Per-command working state used during clustering.
#[derive(Debug, Clone)]
struct CommandRecord {
    /// Command exactly as observed.
    original: String,
    /// Normalized form used for similarity comparison.
    normalized: String,
    /// Lowercased whitespace tokens of the normalized form.
    tokens: Vec<String>,
    /// First token; clustering only compares commands with equal programs.
    program: String,
    /// Number of times this command was denied.
    count: usize,
}
/// Mutable accumulator used while greedily clustering commands.
#[derive(Debug, Clone)]
struct TempCluster {
    /// Records assigned to this cluster so far.
    records: Vec<CommandRecord>,
    /// Token list of the seed record; used as the similarity representative.
    rep_tokens: Vec<String>,
}

impl TempCluster {
    /// Starts a new cluster seeded with a single record.
    fn new(record: CommandRecord) -> Self {
        Self {
            rep_tokens: record.tokens.clone(),
            records: vec![record],
        }
    }

    /// Adds a record; the representative tokens stay those of the seed.
    fn add(&mut self, record: CommandRecord) {
        self.records.push(record);
    }

    /// Finalizes the cluster: dedups originals and normalized forms in
    /// first-seen order, sums denial counts (saturating), and derives the
    /// proposed allowlist pattern from the normalized forms.
    fn into_command_cluster(self) -> CommandCluster {
        let mut originals: Vec<String> = Vec::new();
        let mut normals: Vec<String> = Vec::new();
        let mut seen_original: HashSet<String> = HashSet::new();
        let mut seen_normal: HashSet<String> = HashSet::new();
        let mut total = 0_usize;
        for rec in &self.records {
            total = total.saturating_add(rec.count);
            if seen_original.insert(rec.original.clone()) {
                originals.push(rec.original.clone());
            }
            if seen_normal.insert(rec.normalized.clone()) {
                normals.push(rec.normalized.clone());
            }
        }
        CommandCluster {
            proposed_pattern: build_proposed_pattern(&normals),
            unique_count: normals.len(),
            commands: originals,
            normalized: normals,
            frequency: total,
        }
    }
}
/// Groups denied commands into clusters of similar invocations using the
/// default Jaccard similarity threshold.
///
/// `commands` pairs each distinct command string with its denial count;
/// clusters holding fewer than `min_cluster_size` records are discarded.
#[must_use]
pub fn cluster_denied_commands(
    commands: &[(String, usize)],
    min_cluster_size: usize,
) -> Vec<CommandCluster> {
    cluster_denied_commands_with_threshold(commands, min_cluster_size, DEFAULT_SIMILARITY_THRESHOLD)
}
/// Core clustering: normalize, group by program, then greedily merge by
/// token-set similarity against each cluster's representative.
///
/// Improvement: the original threaded each record through an
/// `Option`/`take()`/`expect()` dance while searching for a home cluster;
/// that is replaced with a single `iter_mut().find`, removing the panic
/// paths while keeping the same first-match-wins placement.
fn cluster_denied_commands_with_threshold(
    commands: &[(String, usize)],
    min_cluster_size: usize,
    similarity_threshold: f32,
) -> Vec<CommandCluster> {
    if commands.is_empty() {
        return Vec::new();
    }
    // Normalize and tokenize every command once up front.
    let mut records = Vec::with_capacity(commands.len());
    for (command, count) in commands {
        let normalized = normalize_for_clustering(command);
        let tokens = tokenize_for_similarity(&normalized);
        let program = tokens.first().cloned().unwrap_or_default();
        records.push(CommandRecord {
            original: command.clone(),
            normalized,
            tokens,
            program,
            count: *count,
        });
    }
    // Commands can only cluster with others sharing the same program name.
    let mut groups: HashMap<String, Vec<CommandRecord>> = HashMap::new();
    for record in records {
        groups
            .entry(record.program.clone())
            .or_default()
            .push(record);
    }
    let mut clusters = Vec::new();
    for group in groups.into_values() {
        // Greedy single pass: each record joins the first existing cluster
        // whose representative is similar enough, else starts a new one.
        let mut temp_clusters: Vec<TempCluster> = Vec::new();
        for record in group {
            let home = temp_clusters.iter_mut().find(|cluster| {
                jaccard_similarity(&cluster.rep_tokens, &record.tokens) >= similarity_threshold
            });
            match home {
                Some(cluster) => cluster.add(record),
                None => temp_clusters.push(TempCluster::new(record)),
            }
        }
        for cluster in temp_clusters {
            if cluster.records.len() >= min_cluster_size {
                clusters.push(cluster.into_command_cluster());
            }
        }
    }
    // Deterministic ordering: frequency desc, then pattern for stable ties.
    clusters.sort_by(|a, b| {
        b.frequency
            .cmp(&a.frequency)
            .then_with(|| a.proposed_pattern.cmp(&b.proposed_pattern))
    });
    clusters
}
/// Canonicalizes a command for similarity comparison: strips wrapper
/// prefixes (via `strip_wrapper_prefixes`) and collapses whitespace runs.
fn normalize_for_clustering(command: &str) -> String {
    let stripped = strip_wrapper_prefixes(command);
    collapse_whitespace(stripped.normalized.as_ref())
}
/// Replaces every run of whitespace with a single space and trims the ends.
fn collapse_whitespace(input: &str) -> String {
    // `split_whitespace` skips leading/trailing/consecutive whitespace,
    // so rejoining with single spaces yields the collapsed, trimmed form.
    input.split_whitespace().collect::<Vec<_>>().join(" ")
}
/// Splits a command into lowercase ASCII tokens for set comparison.
fn tokenize_for_similarity(command: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for word in command.split_whitespace() {
        tokens.push(word.to_ascii_lowercase());
    }
    tokens
}
/// Jaccard similarity of two token lists, treated as sets (`0.0..=1.0`).
///
/// Two empty inputs are defined as identical (1.0). The original checked
/// emptiness twice — on the slices and again on the derived sets — but a
/// set is empty exactly when its source slice is, so one check suffices.
fn jaccard_similarity(a: &[String], b: &[String]) -> f32 {
    let set_a: HashSet<&str> = a.iter().map(String::as_str).collect();
    let set_b: HashSet<&str> = b.iter().map(String::as_str).collect();
    if set_a.is_empty() && set_b.is_empty() {
        return 1.0;
    }
    let intersection = u32::try_from(set_a.intersection(&set_b).count()).unwrap_or(u32::MAX);
    let union = u32::try_from(set_a.union(&set_b).count()).unwrap_or(u32::MAX);
    // With the both-empty case handled above, union >= 1; the guard is kept
    // purely as defensive coding.
    if union == 0 {
        0.0
    } else {
        #[allow(clippy::cast_precision_loss)]
        {
            intersection as f32 / union as f32
        }
    }
}
/// Builds an anchored, fully escaped pattern covering all given commands:
/// an exact match for one distinct command, otherwise a `(?:a|b|...)`
/// alternation in first-seen order.
fn build_proposed_pattern(commands: &[String]) -> String {
    // Dedup while keeping first-seen order.
    let mut seen = HashSet::new();
    let unique: Vec<&String> = commands.iter().filter(|c| seen.insert(*c)).collect();
    match unique.as_slice() {
        [] => String::new(),
        [only] => format!("^{}$", regex_escape(only.as_str())),
        many => {
            let escaped: Vec<String> = many.iter().map(|c| regex_escape(c.as_str())).collect();
            format!("^(?:{})$", escaped.join("|"))
        }
    }
}
// Unit tests covering clustering, similarity, and pattern generation.
#[cfg(test)]
mod tests {
    use super::*;

    // --- clustering behavior ---

    #[test]
    fn clusters_similar_commands_by_program() {
        let input = vec![
            ("npm run build --production".to_string(), 10),
            ("npm run test --coverage".to_string(), 5),
            ("git status".to_string(), 2),
        ];
        let clusters = cluster_denied_commands(&input, 2);
        assert_eq!(clusters.len(), 1);
        let cluster = &clusters[0];
        assert_eq!(cluster.unique_count, 2);
        assert!(cluster.proposed_pattern.contains("npm"));
        assert!(cluster.proposed_pattern.contains("run"));
    }

    #[test]
    fn respects_min_cluster_size() {
        let input = vec![("git status".to_string(), 1), ("docker ps".to_string(), 1)];
        let clusters = cluster_denied_commands(&input, 2);
        assert!(clusters.is_empty());
    }

    #[test]
    fn proposed_pattern_is_anchored_and_escaped() {
        let input = vec![("echo foo|bar".to_string(), 3)];
        let clusters = cluster_denied_commands(&input, 1);
        assert_eq!(clusters.len(), 1);
        let pattern = &clusters[0].proposed_pattern;
        assert!(pattern.starts_with('^'));
        assert!(pattern.ends_with('$'));
        assert!(pattern.contains("\\|"));
    }

    #[test]
    fn handles_empty_input() {
        let input: Vec<(String, usize)> = vec![];
        let clusters = cluster_denied_commands(&input, 1);
        assert!(clusters.is_empty());
    }

    #[test]
    fn handles_single_command() {
        let input = vec![("git reset --hard".to_string(), 5)];
        let clusters = cluster_denied_commands(&input, 1);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].unique_count, 1);
        assert_eq!(clusters[0].frequency, 5);
        assert!(clusters[0].proposed_pattern.starts_with('^'));
        assert!(clusters[0].proposed_pattern.ends_with('$'));
    }

    #[test]
    fn handles_all_different_programs() {
        let input = vec![
            ("git status".to_string(), 1),
            ("npm install".to_string(), 1),
            ("docker ps".to_string(), 1),
        ];
        let clusters = cluster_denied_commands(&input, 2);
        assert!(
            clusters.is_empty(),
            "No clusters should form when all programs differ"
        );
    }

    #[test]
    fn strips_wrapper_prefixes_before_clustering() {
        let input = vec![
            ("sudo git reset --hard".to_string(), 3),
            ("git reset --soft".to_string(), 2),
        ];
        let clusters = cluster_denied_commands(&input, 2);
        assert_eq!(clusters.len(), 1);
        assert!(
            clusters[0]
                .normalized
                .iter()
                .all(|n| !n.starts_with("sudo"))
        );
    }

    #[test]
    fn accumulates_frequency_across_cluster() {
        let input = vec![
            ("git reset --hard".to_string(), 10),
            ("git reset --soft".to_string(), 5),
            ("git reset --mixed".to_string(), 3),
        ];
        let clusters = cluster_denied_commands(&input, 1);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].frequency, 18);
    }

    #[test]
    fn deduplicates_identical_commands() {
        let input = vec![("git status".to_string(), 5), ("git status".to_string(), 3)];
        let clusters = cluster_denied_commands(&input, 1);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].unique_count, 1);
        assert_eq!(clusters[0].frequency, 8);
    }

    #[test]
    fn sorts_clusters_by_frequency_descending() {
        let input = vec![
            ("npm run build".to_string(), 1),
            ("npm run test".to_string(), 1),
            ("git status".to_string(), 50),
            ("git log".to_string(), 50),
        ];
        let clusters = cluster_denied_commands(&input, 2);
        assert_eq!(clusters.len(), 2);
        assert!(clusters[0].commands[0].starts_with("git"));
        assert!(clusters[1].commands[0].starts_with("npm"));
    }

    // --- Jaccard similarity ---

    #[test]
    fn jaccard_similarity_identical_tokens() {
        let a = vec!["git".to_string(), "reset".to_string(), "--hard".to_string()];
        let b = vec!["git".to_string(), "reset".to_string(), "--hard".to_string()];
        let similarity = jaccard_similarity(&a, &b);
        assert!(
            (similarity - 1.0).abs() < 0.001,
            "Identical tokens should have similarity 1.0"
        );
    }

    #[test]
    fn jaccard_similarity_no_overlap() {
        let a = vec!["git".to_string(), "status".to_string()];
        let b = vec!["npm".to_string(), "install".to_string()];
        let similarity = jaccard_similarity(&a, &b);
        assert!(
            (similarity - 0.0).abs() < 0.001,
            "No overlap should have similarity 0.0"
        );
    }

    #[test]
    fn jaccard_similarity_empty_sets() {
        let a: Vec<String> = vec![];
        let b: Vec<String> = vec![];
        let similarity = jaccard_similarity(&a, &b);
        assert!(
            (similarity - 1.0).abs() < 0.001,
            "Empty sets should have similarity 1.0"
        );
    }

    // --- proposed pattern / normalization ---

    #[test]
    fn proposed_pattern_alternation_for_multiple_commands() {
        let input = vec![("echo hello".to_string(), 1), ("echo world".to_string(), 1)];
        let clusters = cluster_denied_commands(&input, 2);
        assert_eq!(clusters.len(), 1);
        let pattern = &clusters[0].proposed_pattern;
        assert!(pattern.contains("(?:"));
        assert!(pattern.contains('|'));
    }

    #[test]
    fn handles_commands_with_special_regex_chars() {
        let input = vec![("echo $HOME".to_string(), 1), ("echo $PATH".to_string(), 1)];
        let clusters = cluster_denied_commands(&input, 2);
        assert_eq!(clusters.len(), 1);
        let pattern = &clusters[0].proposed_pattern;
        assert!(pattern.contains("\\$"));
    }

    #[test]
    fn normalize_collapses_whitespace() {
        let input = vec![
            ("git reset --hard".to_string(), 1),
            ("git reset --hard".to_string(), 1),
        ];
        let clusters = cluster_denied_commands(&input, 1);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].unique_count, 1);
    }

    // --- generate_pattern_from_cluster ---

    #[test]
    fn generate_pattern_empty_input() {
        let commands: Vec<String> = vec![];
        let pattern = generate_pattern_from_cluster(&commands);
        assert!(pattern.regex.is_empty());
        assert!(pattern.matches_all);
        assert!((pattern.specificity_score - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn generate_pattern_single_command() {
        let commands = vec!["git status".to_string()];
        let pattern = generate_pattern_from_cluster(&commands);
        assert_eq!(pattern.regex, "^git status$");
        assert!(pattern.matches_all);
        assert!((pattern.specificity_score - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn generate_pattern_common_prefix() {
        let commands = vec![
            "npm run build".to_string(),
            "npm run test".to_string(),
            "npm run lint".to_string(),
        ];
        let pattern = generate_pattern_from_cluster(&commands);
        assert!(pattern.regex.starts_with("^npm"));
        assert!(pattern.regex.contains("run"));
        assert!(pattern.matches_all);
        assert!(pattern.regex.contains('|'));
        let re = Regex::new(&pattern.regex).unwrap();
        for cmd in &commands {
            assert!(re.is_match(cmd), "Pattern should match: {cmd}");
        }
    }

    #[test]
    fn generate_pattern_common_prefix_and_suffix() {
        let commands = vec![
            "docker run --rm alpine".to_string(),
            "docker run --rm ubuntu".to_string(),
            "docker run --rm debian".to_string(),
        ];
        let pattern = generate_pattern_from_cluster(&commands);
        assert!(pattern.matches_all);
        let re = Regex::new(&pattern.regex).unwrap();
        for cmd in &commands {
            assert!(re.is_match(cmd), "Pattern should match: {cmd}");
        }
    }

    #[test]
    fn generate_pattern_does_not_match_destructive_variants() {
        let commands = vec![
            "npm run build".to_string(),
            "npm run test".to_string(),
            "npm run lint".to_string(),
        ];
        let pattern = generate_pattern_from_cluster(&commands);
        let re = Regex::new(&pattern.regex).unwrap();
        assert!(
            !re.is_match("rm -rf /"),
            "Pattern should NOT match destructive commands"
        );
        assert!(
            !re.is_match("npm run delete-everything"),
            "Pattern should NOT match non-cluster commands"
        );
    }

    #[test]
    fn generate_pattern_handles_special_chars() {
        let commands = vec![
            "echo $HOME".to_string(),
            "echo $PATH".to_string(),
            "echo $USER".to_string(),
        ];
        let pattern = generate_pattern_from_cluster(&commands);
        let re = Regex::new(&pattern.regex);
        assert!(re.is_ok(), "Pattern should be valid regex");
        assert!(pattern.matches_all);
    }

    #[test]
    fn generate_pattern_specificity_score() {
        let exact = generate_pattern_from_cluster(&["git status".to_string()]);
        assert!(
            exact.specificity_score >= 0.9,
            "Exact match should have high specificity"
        );
        let small = generate_pattern_from_cluster(&[
            "npm run build".to_string(),
            "npm run test".to_string(),
        ]);
        assert!(
            small.specificity_score >= 0.7,
            "Small alternation should have good specificity"
        );
    }

    #[test]
    fn generate_pattern_deduplicates_commands() {
        let commands = vec![
            "git status".to_string(),
            "git status".to_string(),
            "git status".to_string(),
        ];
        let pattern = generate_pattern_from_cluster(&commands);
        assert_eq!(pattern.regex, "^git status$");
        assert!(pattern.matches_all);
    }

    #[test]
    fn generate_pattern_variable_segment_analysis() {
        let numeric_commands = vec![
            "fetch page 1".to_string(),
            "fetch page 2".to_string(),
            "fetch page 3".to_string(),
            "fetch page 4".to_string(),
            "fetch page 5".to_string(),
        ];
        let pattern = generate_pattern_from_cluster(&numeric_commands);
        assert!(pattern.matches_all);
        let re = Regex::new(&pattern.regex).unwrap();
        for cmd in &numeric_commands {
            assert!(re.is_match(cmd), "Pattern should match: {cmd}");
        }
    }

    #[test]
    fn generate_pattern_anchored() {
        let commands = vec!["npm run build".to_string(), "npm run test".to_string()];
        let pattern = generate_pattern_from_cluster(&commands);
        assert!(pattern.regex.starts_with('^'));
        assert!(pattern.regex.ends_with('$'));
    }

    #[test]
    fn generate_pattern_respects_max_alternation_count() {
        let commands: Vec<String> = (0..15).map(|i| format!("cmd arg{i}")).collect();
        let pattern = generate_pattern_from_cluster(&commands);
        assert!(pattern.matches_all);
        assert!(pattern.specificity_score < 1.0);
    }

    // --- prefix/suffix helpers ---

    #[test]
    fn common_prefix_length_calculation() {
        let tokenized = vec![
            vec!["npm", "run", "build"],
            vec!["npm", "run", "test"],
            vec!["npm", "run", "lint"],
        ];
        let prefix_len = find_common_prefix_length(&tokenized);
        assert_eq!(prefix_len, 2);
    }

    #[test]
    fn common_suffix_length_calculation() {
        let tokenized = vec![
            vec!["docker", "run", "--rm", "alpine"],
            vec!["docker", "exec", "--rm", "alpine"],
        ];
        let prefix_len = find_common_prefix_length(&tokenized);
        let suffix_len = find_common_suffix_length(&tokenized, prefix_len);
        assert_eq!(prefix_len, 1);
        assert_eq!(suffix_len, 2);
    }
}