use serde::{Deserialize, Serialize};
/// Outcome of scanning one input string with `PromptInjectionDetector::check_input`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InjectionResult {
/// `true` when no heuristic fired, i.e. `detected_patterns` is empty.
pub safe: bool,
/// Heuristic score in `0.0..=1.0`: 0.25 per detected pattern, capped at 1.0,
/// and 0.0 when nothing was detected.
pub confidence: f64,
/// Labels of the heuristics that matched, in the order the detector evaluated them.
pub detected_patterns: Vec<String>,
}
/// A named heuristic: the label reported in `InjectionResult::detected_patterns`,
/// paired with a boxed predicate that returns `true` when the text it is given matches.
/// The predicate is required to be `Send + Sync`.
type PatternMatcher = (String, Box<dyn Fn(&str) -> bool + Send + Sync>);
/// Heuristic scanner that checks text against a list of labelled pattern matchers.
///
/// Build one with `PromptInjectionDetector::new` (or `Default::default`) and scan
/// text with `check_input`.
pub struct PromptInjectionDetector {
/// Labelled predicates, evaluated in order against every input.
patterns: Vec<PatternMatcher>,
}
impl PromptInjectionDetector {
    /// Confidence contributed by each heuristic that fires; the total is capped at 1.0.
    const CONFIDENCE_PER_PATTERN: f64 = 0.25;

    /// Creates a detector loaded with the built-in heuristics.
    ///
    /// The heuristics, in evaluation order, are:
    ///
    /// * `ignore_previous_instructions` – phrases asking to drop earlier instructions.
    /// * `system_prompt_leak` – phrases asking for, or imitating, the system prompt.
    /// * `role_confusion` – phrases that try to reassign the assistant's role.
    /// * `excessive_special_chars` – inputs of at least 20 characters where more than
    ///   40% are neither alphanumeric, whitespace, `.` nor `,`.
    /// * `base64_payload` – inputs of at least 40 characters where more than 85% are
    ///   base64 alphabet characters and at least one `=` is present.
    /// * `delimiter_injection` – fake "system" section delimiters.
    ///
    /// Phrase heuristics are case-insensitive.
    #[must_use]
    pub fn new() -> Self {
        let patterns: Vec<PatternMatcher> = vec![
            Self::matcher("ignore_previous_instructions", |s: &str| {
                Self::contains_any_ignore_case(
                    s,
                    &[
                        "ignore previous instructions",
                        "ignore all previous",
                        "disregard previous",
                        "forget previous instructions",
                        "ignore your instructions",
                    ],
                )
            }),
            Self::matcher("system_prompt_leak", |s: &str| {
                Self::contains_any_ignore_case(
                    s,
                    &[
                        "system prompt:",
                        "system message:",
                        "reveal your system prompt",
                        "show me your instructions",
                        "print your system prompt",
                    ],
                )
            }),
            Self::matcher("role_confusion", |s: &str| {
                Self::contains_any_ignore_case(
                    s,
                    &[
                        "you are now",
                        "act as a",
                        "pretend you are",
                        "roleplay as",
                        "switch to role",
                    ],
                )
            }),
            Self::matcher("excessive_special_chars", |s: &str| {
                let ratio = Self::matching_ratio(s, 20, |c: char| {
                    !c.is_alphanumeric() && !c.is_whitespace() && c != '.' && c != ','
                });
                matches!(ratio, Some(r) if r > 0.4)
            }),
            Self::matcher("base64_payload", |s: &str| {
                let ratio = Self::matching_ratio(s, 40, |c: char| {
                    c.is_ascii_alphanumeric() || matches!(c, '+' | '/' | '=')
                });
                // Requiring '=' keeps long plain alphanumeric text from being flagged.
                matches!(ratio, Some(r) if r > 0.85) && s.contains('=')
            }),
            Self::matcher("delimiter_injection", |s: &str| {
                Self::contains_any_ignore_case(s, &["```system", "---system", "[system]"])
            }),
        ];
        Self { patterns }
    }

    /// Pairs a label with a boxed predicate to form a [`PatternMatcher`].
    fn matcher(
        label: &str,
        check: impl Fn(&str) -> bool + Send + Sync + 'static,
    ) -> PatternMatcher {
        (label.to_owned(), Box::new(check))
    }

    /// Returns `true` if the lowercased `text` contains any of `needles`.
    ///
    /// `needles` must already be lowercase.
    fn contains_any_ignore_case(text: &str, needles: &[&str]) -> bool {
        let lowered = text.to_lowercase();
        needles.iter().any(|needle| lowered.contains(*needle))
    }

    /// Returns the fraction of characters in `text` accepted by `is_match`, or `None`
    /// when `text` has fewer than `min_chars` characters (too short to judge).
    fn matching_ratio(
        text: &str,
        min_chars: usize,
        is_match: impl Fn(char) -> bool,
    ) -> Option<f64> {
        let total = text.chars().count();
        if total < min_chars {
            return None;
        }
        let matching = text.chars().filter(|&c| is_match(c)).count();
        Some(matching as f64 / total as f64)
    }

    /// Returns `true` for invisible format characters that can be inserted into a
    /// phrase to split it without changing how it renders.
    fn is_invisible_format_char(c: char) -> bool {
        matches!(
            c,
            '\u{00AD}' // soft hyphen
                | '\u{061C}' // Arabic letter mark
                | '\u{200B}'..='\u{200F}' // zero-width space/joiners, LRM, RLM
                | '\u{202A}'..='\u{202E}' // bidi embeddings and overrides
                | '\u{2060}'..='\u{2064}' // word joiner and invisible operators
                | '\u{2066}'..='\u{2069}' // bidi isolates (LRI, RLI, FSI, PDI)
                | '\u{FE00}'..='\u{FE0F}' // variation selectors
                | '\u{FEFF}' // zero-width no-break space / BOM
        )
    }

    /// Removes invisible format characters so they cannot be used to break up the
    /// phrases the heuristics look for.
    fn normalize_input(input: &str) -> String {
        input
            .chars()
            .filter(|c| !Self::is_invisible_format_char(*c))
            .collect()
    }

    /// Scans `input` with every heuristic and reports which ones matched.
    ///
    /// The input is normalized first (invisible format characters are removed), then
    /// each heuristic runs against the normalized text. The result is `safe` only
    /// when nothing matched; `confidence` is 0.25 per matched heuristic, capped at 1.0.
    #[must_use]
    pub fn check_input(&self, input: &str) -> InjectionResult {
        let normalized = Self::normalize_input(input);
        let detected_patterns: Vec<String> = self
            .patterns
            .iter()
            .filter(|(_, check)| check(&normalized))
            .map(|(label, _)| label.clone())
            .collect();
        // An empty match list yields 0.0 here, so no special case is needed.
        let confidence =
            (detected_patterns.len() as f64 * Self::CONFIDENCE_PER_PATTERN).min(1.0);
        InjectionResult {
            safe: detected_patterns.is_empty(),
            confidence,
            detected_patterns,
        }
    }
}
impl Default for PromptInjectionDetector {
fn default() -> Self {
Self::new()
}
}