use crate::context::{CommandSpans, SpanKind, classify_command};
use smallvec::SmallVec;
/// One piece of evidence about whether a pattern match sits in code that runs.
///
/// Every signal has a multiplicative [`weight`](ConfidenceSignal::weight) and a
/// human-readable [`description`](ConfidenceSignal::description).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfidenceSignal {
    /// The match lies in a span classified as executed code.
    ExecutedSpan,
    /// The match lies in inline code (e.g. the payload of `bash -c`).
    InlineCodeSpan,
    /// The match lies in a data span (e.g. a single-quoted string).
    DataSpan,
    /// The match lies in an argument span.
    ArgumentSpan,
    /// The match lies in a comment span.
    CommentSpan,
    /// The match lies in a heredoc body span.
    HeredocBodySpan,
    /// The span containing the match is unknown or could not be determined.
    UnknownSpan,
    /// The matched bytes differ between the original and sanitized command.
    SanitizedRegion,
    /// An execution operator was found close to the match.
    ExecutionOperatorsNearby,
    /// The match starts where a new command may begin.
    CommandPosition,
    /// The match does not start at a command position.
    ArgumentPosition,
}

impl ConfidenceSignal {
    /// Multiplicative factor this signal contributes to a confidence score.
    ///
    /// Factors below `1.0` lower confidence and factors above `1.0` raise it.
    /// `ConfidenceScore::add_signal` clamps the running product to `0.0..=1.0`.
    #[must_use]
    pub const fn weight(self) -> f32 {
        match self {
            Self::ExecutedSpan => 1.0,
            Self::InlineCodeSpan => 1.0,
            Self::DataSpan => 0.1,
            Self::ArgumentSpan => 0.3,
            Self::CommentSpan => 0.05,
            Self::HeredocBodySpan => 0.7,
            Self::UnknownSpan => 0.8,
            Self::SanitizedRegion => 0.2,
            Self::ExecutionOperatorsNearby => 1.1,
            Self::CommandPosition => 1.1,
            Self::ArgumentPosition => 0.6,
        }
    }

    /// Short English explanation of the signal, suitable for diagnostics.
    #[must_use]
    pub const fn description(self) -> &'static str {
        match self {
            Self::ExecutedSpan => "match is in executed code",
            Self::InlineCodeSpan => "match is in inline code (bash -c, python -c, etc.)",
            Self::DataSpan => "match is in a data string (single-quoted)",
            Self::ArgumentSpan => "match is in a string argument to a safe command",
            Self::CommentSpan => "match is in a comment",
            Self::HeredocBodySpan => "match is in a heredoc body",
            Self::UnknownSpan => "match context is ambiguous",
            Self::SanitizedRegion => "match was in a region masked by sanitization",
            Self::ExecutionOperatorsNearby => "execution operators (|, ;, &&) found nearby",
            Self::CommandPosition => "match is at command position",
            Self::ArgumentPosition => "match is in argument position",
        }
    }
}
/// Accumulated confidence that a pattern match is a real, executable hit.
///
/// `value` starts at `1.0` (see [`ConfidenceScore::high`]) and is multiplied by
/// the weight of every recorded signal. The result is clamped to `0.0..=1.0`.
/// `signals` keeps the evidence in insertion order.
#[derive(Debug, Clone)]
pub struct ConfidenceScore {
    /// Combined confidence, always within `0.0..=1.0` when built via this API.
    pub value: f32,
    /// Signals that contributed to `value`, in the order they were added.
    pub signals: SmallVec<[ConfidenceSignal; 4]>,
}

impl Default for ConfidenceScore {
    /// Full confidence with no signals; identical to [`ConfidenceScore::high`].
    fn default() -> Self {
        Self::high()
    }
}

impl ConfidenceScore {
    /// Full confidence (`1.0`) with no recorded signals.
    #[must_use]
    pub fn high() -> Self {
        Self {
            value: 1.0,
            signals: SmallVec::new(),
        }
    }

    /// A score seeded from a single signal.
    ///
    /// The value is the signal's weight clamped to `0.0..=1.0`. Boosting signals
    /// such as [`ConfidenceSignal::CommandPosition`] (weight `1.1`) therefore
    /// cannot push the value above `1.0`, matching the invariant kept by
    /// [`Self::add_signal`].
    #[must_use]
    pub fn low(signal: ConfidenceSignal) -> Self {
        // Route through `add_signal` so both constructors share one clamping rule;
        // previously a weight > 1.0 was stored unclamped here.
        let mut score = Self::high();
        score.add_signal(signal);
        score
    }

    /// Records `signal` and multiplies `value` by its weight, clamping the
    /// product to `0.0..=1.0`.
    pub fn add_signal(&mut self, signal: ConfidenceSignal) {
        self.signals.push(signal);
        self.value = (self.value * signal.weight()).clamp(0.0, 1.0);
    }

    /// Returns `true` when `value` is strictly below `threshold`.
    #[must_use]
    pub fn is_low(&self, threshold: f32) -> bool {
        self.value < threshold
    }

    /// Returns `true` when `value` is strictly below [`DEFAULT_WARN_THRESHOLD`].
    #[must_use]
    pub fn should_warn(&self) -> bool {
        self.is_low(DEFAULT_WARN_THRESHOLD)
    }
}
/// Confidence value below which `ConfidenceScore::should_warn` returns `true`
/// (the comparison is strict: a value of exactly `0.5` does not warn).
pub const DEFAULT_WARN_THRESHOLD: f32 = 0.5;
/// Input to `compute_match_confidence`: a command and the byte range of a
/// pattern match inside it.
pub struct ConfidenceContext<'a> {
    /// The original, unmodified command text.
    pub command: &'a str,
    /// The command after sanitization, if one is available. The matched byte
    /// range is compared between this and `command` to detect altered regions.
    pub sanitized_command: Option<&'a str>,
    /// Byte offset in `command` where the match starts (inclusive).
    pub match_start: usize,
    /// Byte offset in `command` where the match ends (exclusive).
    pub match_end: usize,
}
#[must_use]
pub fn compute_match_confidence(ctx: &ConfidenceContext<'_>) -> ConfidenceScore {
let mut score = ConfidenceScore::high();
if let Some(sanitized) = ctx.sanitized_command {
if ctx.match_start < sanitized.len()
&& ctx.match_end <= sanitized.len()
&& sanitized != ctx.command
{
let original_slice = ctx.command.get(ctx.match_start..ctx.match_end);
let sanitized_slice = sanitized.get(ctx.match_start..ctx.match_end);
if original_slice != sanitized_slice {
score.add_signal(ConfidenceSignal::SanitizedRegion);
}
}
}
let spans = classify_command(ctx.command);
let signal = classify_match_span(&spans, ctx.match_start, ctx.match_end);
score.add_signal(signal);
if has_execution_operators_nearby(ctx.command, ctx.match_start, ctx.match_end) {
score.add_signal(ConfidenceSignal::ExecutionOperatorsNearby);
}
if is_command_position(ctx.command, ctx.match_start) {
score.add_signal(ConfidenceSignal::CommandPosition);
} else {
score.add_signal(ConfidenceSignal::ArgumentPosition);
}
score
}
fn classify_match_span(
spans: &CommandSpans,
match_start: usize,
match_end: usize,
) -> ConfidenceSignal {
for span in spans.spans() {
if span.byte_range.start <= match_start && match_end <= span.byte_range.end {
return match span.kind {
SpanKind::Executed => ConfidenceSignal::ExecutedSpan,
SpanKind::InlineCode => ConfidenceSignal::InlineCodeSpan,
SpanKind::Data => ConfidenceSignal::DataSpan,
SpanKind::Argument => ConfidenceSignal::ArgumentSpan,
SpanKind::Comment => ConfidenceSignal::CommentSpan,
SpanKind::HeredocBody => ConfidenceSignal::HeredocBodySpan,
SpanKind::Unknown => ConfidenceSignal::UnknownSpan,
};
}
}
ConfidenceSignal::UnknownSpan
}
/// Reports whether an execution operator (`|`, `;`, `&&`, `$(`, or a backtick)
/// appears within 20 bytes before `match_start` or 20 bytes after `match_end`.
/// `||` is covered by `|`.
///
/// The window edges are snapped inward to UTF-8 character boundaries. A
/// multi-byte character straddling an edge therefore no longer causes that whole
/// side of the window to be skipped. Out-of-range or non-boundary match offsets
/// yield an empty window on that side instead of panicking.
fn has_execution_operators_nearby(command: &str, match_start: usize, match_end: usize) -> bool {
    const WINDOW: usize = 20;
    const OPERATORS: [&str; 5] = ["|", ";", "&&", "$(", "`"];

    // Move the window start forward to the next char boundary (never past
    // `match_start`); `str::get` would otherwise return `None` for the whole prefix.
    let mut start = match_start.saturating_sub(WINDOW);
    while start < match_start && !command.is_char_boundary(start) {
        start += 1;
    }

    // Likewise move the window end back to the previous char boundary.
    let mut end = match_end.saturating_add(WINDOW).min(command.len());
    while end > match_end && !command.is_char_boundary(end) {
        end -= 1;
    }

    let prefix = command.get(start..match_start).unwrap_or("");
    let suffix = command.get(match_end..end).unwrap_or("");

    OPERATORS
        .iter()
        .any(|op| prefix.contains(op) || suffix.contains(op))
}
/// Reports whether a command could begin at byte offset `match_start`.
///
/// This holds in three cases:
/// - at the start of the string;
/// - after only whitespace;
/// - when the text before the match, with trailing whitespace trimmed, ends in
///   `|`, `;`, `(`, a backtick, or `&&`. This also covers `||` and `$(`.
///
/// Newlines are trimmed like any other whitespace and are not treated as
/// separators.
///
/// `match_start` is clamped to the string length and rounded down to the
/// nearest UTF-8 character boundary. An out-of-range or mid-character offset
/// therefore no longer panics.
fn is_command_position(command: &str, match_start: usize) -> bool {
    let mut boundary = match_start.min(command.len());
    // Offset 0 is always a char boundary, so this loop cannot underflow.
    while !command.is_char_boundary(boundary) {
        boundary -= 1;
    }
    let trimmed = command[..boundary].trim_end();

    match trimmed.chars().next_back() {
        // Start of input, or nothing but whitespace before the match.
        None => true,
        Some(last) => matches!(last, '|' | ';' | '(' | '`') || trimmed.ends_with("&&"),
    }
}
/// Computes the confidence for `ctx` and pairs it with whether that confidence
/// is low enough to warn.
///
/// The flag is `ConfidenceScore::should_warn`, i.e. the value is strictly below
/// `DEFAULT_WARN_THRESHOLD`. Returns `(score, downgrade)`.
#[must_use]
pub fn should_downgrade_to_warn(ctx: &ConfidenceContext<'_>) -> (ConfidenceScore, bool) {
    let confidence = compute_match_confidence(ctx);
    let warn_only = confidence.should_warn();
    (confidence, warn_only)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ConfidenceContext` from its four fields.
    fn context<'a>(
        command: &'a str,
        sanitized_command: Option<&'a str>,
        match_start: usize,
        match_end: usize,
    ) -> ConfidenceContext<'a> {
        ConfidenceContext {
            command,
            sanitized_command,
            match_start,
            match_end,
        }
    }

    #[test]
    fn test_high_confidence_executed_command() {
        let score = compute_match_confidence(&context("rm -rf /", None, 0, 8));
        assert!(
            score.value > 0.5,
            "Direct command should have high confidence"
        );
    }

    #[test]
    fn test_low_confidence_in_commit_message() {
        let ctx = context(
            "git commit -m 'Fix rm -rf detection'",
            Some("git commit -m ''"),
            18,
            31,
        );
        let score = compute_match_confidence(&ctx);
        assert!(
            score.value < 0.5,
            "Match in sanitized commit message should have low confidence: {}",
            score.value
        );
    }

    #[test]
    fn test_confidence_with_pipe_operator() {
        let score = compute_match_confidence(&context("echo foo | rm -rf /", None, 11, 19));
        let saw_operator = score
            .signals
            .iter()
            .any(|signal| *signal == ConfidenceSignal::ExecutionOperatorsNearby);
        assert!(saw_operator, "Should detect pipe operator");
    }

    #[test]
    fn test_command_position_detection() {
        let cases = [
            ("rm -rf /", 0, true),
            ("echo foo | rm -rf /", 11, true),
            ("foo && rm -rf /", 7, true),
            ("git commit -m 'rm'", 15, false),
        ];
        for (command, offset, expected) in cases {
            assert_eq!(is_command_position(command, offset), expected);
        }
    }

    #[test]
    fn test_confidence_signal_weights() {
        assert!(ConfidenceSignal::ExecutedSpan.weight() >= 1.0);
        assert!(ConfidenceSignal::DataSpan.weight() < 0.5);
        assert!(ConfidenceSignal::CommentSpan.weight() < 0.1);
    }

    #[test]
    fn test_should_warn_threshold() {
        let mut score = ConfidenceScore::high();
        assert!(!score.should_warn(), "High confidence should not warn");
        score.add_signal(ConfidenceSignal::DataSpan);
        assert!(score.should_warn(), "Low confidence should warn");
    }

    #[test]
    fn test_utf8_multibyte_handling() {
        // Each flame emoji is 4 bytes, so "rm -rf /" occupies bytes 13..21.
        let command = "🔥🔥🔥 rm -rf /";
        let score = compute_match_confidence(&context(command, None, 13, 21));
        assert!(score.value > 0.0, "Should compute a valid score");
    }

    #[test]
    fn test_operators_nearby_with_unicode() {
        // 'é' is 2 bytes, so "rm -rf /" occupies bytes 14..22.
        let command = "écho café | rm -rf /";
        assert!(
            has_execution_operators_nearby(command, 14, 22),
            "Should detect pipe operator even with unicode prefix"
        );
    }
}