use chrono::Utc;
use regex::Regex;
use sqlx::SqlitePool;
use std::sync::OnceLock;
use tracing::{debug, info};
/// Raw tuple decoded from a `command_patterns` row by `sqlx::query_as`.
///
/// Positions, in order: `id`, `pattern`, `original_example`, `approval_count`,
/// `denial_count`, `last_approved_at`, `last_denied_at` — the same column order the
/// `SELECT` statements in `find_matching_pattern` request.
type PatternRow = (
i64,
String,
String,
i32,
i32,
Option<String>,
Option<String>,
);
/// A learned command pattern together with its approval/denial history.
///
/// Each value mirrors one row of the `command_patterns` table.
#[derive(Debug, Clone)]
// NOTE(review): presumably allowed because some fields (e.g. `id`, the timestamps)
// are never read elsewhere — confirm and narrow or drop the attribute.
#[allow(dead_code)]
pub struct CommandPattern {
    /// The row's `id` column.
    pub id: i64,
    /// Generalized command text, as produced by `generalize_command`.
    pub pattern: String,
    /// The concrete command stored when the row was first inserted.
    pub original_example: String,
    /// Number of recorded approvals for this pattern.
    pub approval_count: i32,
    /// Number of recorded denials for this pattern.
    pub denial_count: i32,
    /// RFC 3339 timestamp of the latest approval, if any.
    pub last_approved_at: Option<String>,
    /// RFC 3339 timestamp of the latest denial, if any.
    pub last_denied_at: Option<String>,
}

impl CommandPattern {
    /// Confidence in `[0, 1]` that commands of this pattern are acceptable.
    ///
    /// Computed as the approval ratio `approvals / (approvals + denials)` scaled by
    /// a volume factor `min(total / 10, 1)`, so fewer than ten observations cap the
    /// score below `1.0`. A pattern with no observations scores `0.0`.
    pub fn confidence(&self) -> f32 {
        match self.approval_count + self.denial_count {
            0 => 0.0,
            observations => {
                let observations = observations as f32;
                let ratio = self.approval_count as f32 / observations;
                let volume = (observations / 10.0).min(1.0);
                ratio * volume
            }
        }
    }

    /// Whether the pattern has at least three approvals and a confidence of at
    /// least `0.8`.
    pub fn is_trusted(&self) -> bool {
        let enough_approvals = self.approval_count >= 3;
        enough_approvals && self.confidence() >= 0.8
    }
}
/// Lazily compiled regex matching a filesystem path argument.
///
/// A path is either absolute (`/...`) or relative starting with `./` or `../`, and
/// consists of word characters, `.`, `/` and `-`. It must start the string or
/// follow whitespace or `=`; that leading separator character is part of the
/// match, so replacing the whole match also removes it.
fn path_regex() -> &'static Regex {
    static REGEX: OnceLock<Regex> = OnceLock::new();
    REGEX.get_or_init(|| {
        Regex::new(r"(?:^|[\s=])(/[\w./-]+|\.{1,2}/[\w./-]+)")
            .expect("path pattern is a valid regex literal")
    })
}
/// Lazily compiled regex matching an `http://` or `https://` URL up to the next
/// whitespace character.
///
/// Note: `[^\s]+` also consumes trailing punctuation or quotes, so in
/// `curl "https://x"` the closing `"` becomes part of the match.
fn url_regex() -> &'static Regex {
    static REGEX: OnceLock<Regex> = OnceLock::new();
    REGEX.get_or_init(|| {
        Regex::new(r"https?://[^\s]+").expect("URL pattern is a valid regex literal")
    })
}
/// Lazily compiled regex matching any run of four or more ASCII digits.
///
/// Runs embedded in a longer word match too (e.g. the `2024` in `file2024.txt`).
fn number_arg_regex() -> &'static Regex {
    static REGEX: OnceLock<Regex> = OnceLock::new();
    REGEX.get_or_init(|| {
        Regex::new(r"[0-9]{4,}").expect("number pattern is a valid regex literal")
    })
}
/// Lazily compiled regex matching a single-quoted span `'...'` that contains no
/// inner single quote. Escapes are not interpreted.
fn single_quoted_regex() -> &'static Regex {
    static REGEX: OnceLock<Regex> = OnceLock::new();
    REGEX.get_or_init(|| {
        Regex::new(r"'[^']*'").expect("single-quote pattern is a valid regex literal")
    })
}
/// Lazily compiled regex matching a double-quoted span `"..."` that contains no
/// inner double quote. Backslash escapes are not interpreted.
fn double_quoted_regex() -> &'static Regex {
    static REGEX: OnceLock<Regex> = OnceLock::new();
    REGEX.get_or_init(|| {
        Regex::new(r#""[^"]*""#).expect("double-quote pattern is a valid regex literal")
    })
}
/// Reduces a concrete shell command to a reusable pattern string.
///
/// Substitutions run in this order, each on the previous result:
/// 1. URLs → `<url>`
/// 2. path arguments → ` <path>` (the regex consumes the whitespace or `=` before
///    the path, and the replacement starts with a space)
/// 3. single-quoted spans → `<string>`
/// 4. double-quoted spans → `<string>`
/// 5. runs of four or more digits → `<num>`
///
/// Runs of ASCII spaces are then collapsed to a single space and surrounding
/// whitespace is trimmed.
///
/// Stored patterns are looked up by exact equality with this output, so any
/// change to these rules changes which stored rows match.
pub fn generalize_command(command: &str) -> String {
    let substitutions = [
        (url_regex(), "<url>"),
        (path_regex(), " <path>"),
        (single_quoted_regex(), "<string>"),
        (double_quoted_regex(), "<string>"),
        (number_arg_regex(), "<num>"),
    ];
    let substituted = substitutions
        .iter()
        .fold(command.to_string(), |text, (re, replacement)| {
            re.replace_all(&text, *replacement).into_owned()
        });

    // Drop every space that directly follows an already-kept space.
    let mut collapsed = String::with_capacity(substituted.len());
    for ch in substituted.chars() {
        if ch == ' ' && collapsed.ends_with(' ') {
            continue;
        }
        collapsed.push(ch);
    }
    collapsed.trim().to_string()
}
/// Scores how closely `command` matches a stored, already-generalized `pattern`.
///
/// `command` is first run through `generalize_command`. An exact match scores
/// `1.0`. Otherwise both strings are split on whitespace:
/// - If either is empty, or the first tokens (the program) differ, the score is `0.0`.
/// - Each later aligned position counts as a match when the tokens are equal, or
///   when the *pattern* token is a placeholder emitted by `generalize_command`
///   (`<url>`, `<path>`, `<string>`, `<num>`) and the command token is concrete.
/// - Matches (including the program) are divided by the longer token count, so
///   extra or missing arguments lower the score.
///
/// A placeholder in the *command* never covers a concrete token in the pattern,
/// because the command is broader than what was approved there. For example, an
/// approved `rm -rf build` must not vouch for `rm -rf /etc`, which generalizes to
/// `rm -rf <path>`. Likewise, arbitrary `<`-prefixed tokens such as shell
/// redirections are not treated as wildcards.
pub fn pattern_similarity(command: &str, pattern: &str) -> f32 {
    // True only for the exact placeholder tokens `generalize_command` produces.
    fn is_placeholder(token: &str) -> bool {
        matches!(token, "<url>" | "<path>" | "<string>" | "<num>")
    }

    let generalized = generalize_command(command);
    if generalized == pattern {
        return 1.0;
    }
    let cmd_tokens: Vec<&str> = generalized.split_whitespace().collect();
    let pat_tokens: Vec<&str> = pattern.split_whitespace().collect();

    // Both sides must be non-empty and name the same program.
    let (Some(cmd_head), Some(pat_head)) = (cmd_tokens.first(), pat_tokens.first()) else {
        return 0.0;
    };
    if cmd_head != pat_head {
        return 0.0;
    }

    let matched_args = cmd_tokens
        .iter()
        .zip(&pat_tokens)
        .skip(1)
        .filter(|&(&cmd_tok, &pat_tok)| {
            // Only the approved pattern may generalize: its placeholder covers a
            // concrete command token, never the other way round.
            cmd_tok == pat_tok || (is_placeholder(pat_tok) && !is_placeholder(cmd_tok))
        })
        .count();
    let max_len = cmd_tokens.len().max(pat_tokens.len());
    (1 + matched_args) as f32 / max_len as f32
}
/// Records that the user approved `command`.
///
/// The command is reduced to its pattern with `generalize_command`.
/// - If a row with that pattern exists, its `approval_count` is incremented and
///   `last_approved_at` is set to the current UTC time (RFC 3339).
/// - Otherwise a new row is inserted with `approval_count = 1`, `denial_count = 0`,
///   and `command` stored as `original_example`.
///
/// NOTE(review): the UPDATE-then-INSERT pair is not atomic. Two concurrent
/// first-time calls for the same pattern could both insert unless the schema
/// enforces uniqueness on `pattern` — confirm against the migration.
///
/// # Errors
///
/// Returns any database error raised by either statement.
pub async fn record_approval(pool: &SqlitePool, command: &str) -> anyhow::Result<()> {
    let pattern = generalize_command(command);
    let now = Utc::now().to_rfc3339();
    let result = sqlx::query(
        "UPDATE command_patterns SET approval_count = approval_count + 1, last_approved_at = ? WHERE pattern = ?",
    )
    .bind(&now)
    .bind(&pattern)
    .execute(pool)
    .await?;
    if result.rows_affected() == 0 {
        // Set denial_count explicitly, as record_denial does for approval_count,
        // instead of relying on a column default that is not visible here.
        sqlx::query(
            "INSERT INTO command_patterns (pattern, original_example, approval_count, denial_count, last_approved_at, created_at) VALUES (?, ?, 1, 0, ?, ?)",
        )
        .bind(&pattern)
        .bind(command)
        .bind(&now)
        .bind(&now)
        .execute(pool)
        .await?;
        info!(pattern = %pattern, "Learned new command pattern from approval");
    } else {
        debug!(pattern = %pattern, "Updated existing pattern approval count");
    }
    Ok(())
}
pub async fn record_denial(pool: &SqlitePool, command: &str) -> anyhow::Result<()> {
let pattern = generalize_command(command);
let now = Utc::now().to_rfc3339();
let result = sqlx::query(
"UPDATE command_patterns SET denial_count = denial_count + 1, last_denied_at = ? WHERE pattern = ?"
)
.bind(&now)
.bind(&pattern)
.execute(pool)
.await?;
if result.rows_affected() == 0 {
sqlx::query(
"INSERT INTO command_patterns (pattern, original_example, approval_count, denial_count, last_denied_at, created_at) VALUES (?, ?, 0, 1, ?, ?)"
)
.bind(&pattern)
.bind(command)
.bind(&now)
.bind(&now)
.execute(pool)
.await?;
info!(pattern = %pattern, "Recorded denied command pattern");
}
Ok(())
}
/// Converts a decoded `command_patterns` row into a [`CommandPattern`].
fn command_pattern_from_row(row: PatternRow) -> CommandPattern {
    let (
        id,
        pattern,
        original_example,
        approval_count,
        denial_count,
        last_approved_at,
        last_denied_at,
    ) = row;
    CommandPattern {
        id,
        pattern,
        original_example,
        approval_count,
        denial_count,
        last_approved_at,
        last_denied_at,
    }
}

/// Escapes `\`, `%` and `_` so `value` is matched literally by
/// `LIKE ? ESCAPE '\'`.
fn escape_like_literal(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}

/// Finds the stored pattern that best matches `command`, with its similarity score.
///
/// 1. Looks for a row whose `pattern` equals `generalize_command(command)` and
///    returns it with score `1.0`.
/// 2. Otherwise loads up to 20 candidates whose pattern starts with the
///    generalized command's first token, most-approved first.
/// 3. Scores each candidate with `pattern_similarity` and returns the
///    highest-scoring one at or above `0.7`. On ties the earlier (more-approved)
///    candidate wins.
///
/// Returns `Ok(None)` when no row qualifies.
///
/// SQLite's `LIKE` is ASCII case-insensitive by default, so the candidate set may
/// be wider than strictly needed. `pattern_similarity` still requires an exact
/// first-token match.
///
/// # Errors
///
/// Returns any database or row-decoding error.
pub async fn find_matching_pattern(
    pool: &SqlitePool,
    command: &str,
) -> anyhow::Result<Option<(CommandPattern, f32)>> {
    let generalized = generalize_command(command);
    let exact: Option<PatternRow> = sqlx::query_as(
        "SELECT id, pattern, original_example, approval_count, denial_count, last_approved_at, last_denied_at FROM command_patterns WHERE pattern = ?",
    )
    .bind(&generalized)
    .fetch_optional(pool)
    .await?;
    if let Some(row) = exact {
        return Ok(Some((command_pattern_from_row(row), 1.0)));
    }

    // Stored patterns are generalized, so the prefix must come from the generalized
    // command as well. With the raw token, a command such as `./run.sh --x` (stored
    // as `<path> --x`) would never find its candidates.
    let Some(base_cmd) = generalized.split_whitespace().next() else {
        return Ok(None);
    };
    let candidates: Vec<PatternRow> = sqlx::query_as(
        "SELECT id, pattern, original_example, approval_count, denial_count, last_approved_at, last_denied_at FROM command_patterns WHERE pattern LIKE ? ESCAPE '\\' ORDER BY approval_count DESC LIMIT 20",
    )
    // Escape LIKE metacharacters so a `_` or `%` in the program name does not act
    // as a wildcard and crowd real candidates out of the LIMIT.
    .bind(format!("{}%", escape_like_literal(base_cmd)))
    .fetch_all(pool)
    .await?;

    let mut best_match: Option<(CommandPattern, f32)> = None;
    for candidate in candidates.into_iter().map(command_pattern_from_row) {
        let similarity = pattern_similarity(command, &candidate.pattern);
        let beats_current = match &best_match {
            Some((_, best)) => similarity > *best,
            None => true,
        };
        if similarity >= 0.7 && beats_current {
            best_match = Some((candidate, similarity));
        }
    }
    Ok(best_match)
}
/// Aggregate statistics over the `command_patterns` table, as returned by
/// `get_pattern_stats`.
#[derive(Debug)]
// NOTE(review): presumably allowed because no caller reads these fields yet —
// confirm and drop the attribute once they are used.
#[allow(dead_code)]
pub struct PatternStats {
/// Number of rows in `command_patterns`.
pub total_patterns: usize,
/// Number of patterns `get_pattern_stats` counts as trusted.
pub trusted_patterns: usize,
/// Up to ten `(pattern, approval_count)` pairs among patterns with at least three
/// approvals, most-approved first.
pub top_patterns: Vec<(String, i32)>,
}
#[allow(dead_code)]
pub async fn get_pattern_stats(pool: &SqlitePool) -> anyhow::Result<PatternStats> {
let total: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM command_patterns")
.fetch_one(pool)
.await?;
let trusted: (i64,) =
sqlx::query_as("SELECT COUNT(*) FROM command_patterns WHERE approval_count >= 3")
.fetch_one(pool)
.await?;
let top_patterns: Vec<(String, i32)> = sqlx::query_as(
"SELECT pattern, approval_count FROM command_patterns WHERE approval_count >= 3 ORDER BY approval_count DESC LIMIT 10"
)
.fetch_all(pool)
.await?;
Ok(PatternStats {
total_patterns: total.0 as usize,
trusted_patterns: trusted.0 as usize,
top_patterns,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pattern with the given counts; the other fields do not affect scoring.
    fn pattern_with_counts(approval_count: i32, denial_count: i32) -> CommandPattern {
        CommandPattern {
            id: 1,
            pattern: "cargo test".to_string(),
            original_example: "cargo test".to_string(),
            approval_count,
            denial_count,
            last_approved_at: None,
            last_denied_at: None,
        }
    }

    #[test]
    fn test_generalize_command() {
        assert_eq!(generalize_command("cargo test"), "cargo test");
        assert!(generalize_command("curl https://api.example.com").contains("<url>"));
        assert!(generalize_command("kill 12345").contains("<num>"));
        assert_eq!(generalize_command("ls /tmp/foo"), "ls <path>");
        assert_eq!(generalize_command("echo 'hello world'"), "echo <string>");
        assert_eq!(generalize_command("  cargo   build  "), "cargo build");
    }

    #[test]
    fn test_pattern_similarity() {
        assert_eq!(pattern_similarity("cargo test", "cargo test"), 1.0);
        assert_eq!(pattern_similarity("cargo test", "npm test"), 0.0);
        assert_eq!(pattern_similarity("", "cargo test"), 0.0);
        let partial = pattern_similarity("cargo test --release", "cargo test");
        assert!((partial - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_confidence_calculation() {
        let pattern = pattern_with_counts(10, 0);
        assert_eq!(pattern.confidence(), 1.0);
        assert!(pattern.is_trusted());
    }

    #[test]
    fn test_confidence_edge_cases() {
        let unseen = pattern_with_counts(0, 0);
        assert_eq!(unseen.confidence(), 0.0);
        assert!(!unseen.is_trusted());

        // Few observations cap confidence even without denials.
        let young = pattern_with_counts(2, 0);
        assert!((young.confidence() - 0.2).abs() < 1e-6);
        assert!(!young.is_trusted());

        // Enough approvals, but too many denials.
        let contested = pattern_with_counts(10, 5);
        assert!(contested.confidence() < 0.8);
        assert!(!contested.is_trusted());

        let solid = pattern_with_counts(9, 1);
        assert!((solid.confidence() - 0.9).abs() < 1e-6);
        assert!(solid.is_trusted());
    }
}