use crate::constants::*;
use crate::patterns::*;
use crate::types::InjectionDetectionResult;
/// Heuristic detector for prompt-injection attempts hidden in source text.
///
/// All detection entry points are methods on this type; it owns the security
/// configuration it was constructed with.
pub struct DetectionEngine {
// Configuration passed to `new`. None of the methods visible in this section read it.
config: crate::types::LLMSecurityConfig,
}
impl DetectionEngine {
pub fn new(config: crate::types::LLMSecurityConfig) -> Self {
Self { config }
}
/// Runs every prompt-injection heuristic over `code` and returns the aggregated verdict.
///
/// Each heuristic that fires pushes a label (or, for the regex patterns, the matched text)
/// onto the pattern list and adds its risk constant to the running score. The confidence
/// handed to the result is `risk_score / 100` capped at `1.0`, and the input is flagged as
/// malicious when the score is strictly greater than `DEFAULT_MALICIOUS_THRESHOLD`.
pub fn detect_prompt_injection(&self, code: &str) -> InjectionDetectionResult {
    let mut detected_patterns = Vec::new();
    let mut risk_score = 0u32;

    // Known injection patterns: record the first match of each pattern that hits.
    for pattern in get_prompt_injection_patterns().iter() {
        if let Some(found) = pattern.find(code) {
            detected_patterns.push(found.as_str().to_string());
            risk_score += REGEX_PATTERN_RISK_SCORE;
        }
    }

    // Keyword scan is case-insensitive: compare against a lowercased copy.
    let lower_code = code.to_lowercase();
    for keyword in get_dangerous_keywords().iter() {
        if lower_code.contains(keyword) {
            detected_patterns.push(format!("Keyword: {}", keyword));
            risk_score += KEYWORD_RISK_SCORE;
        }
    }

    if self.detect_homoglyphs(code) {
        detected_patterns.push("Homoglyph characters detected".to_string());
        risk_score += HOMOGLYPH_RISK_SCORE;
    }

    // Fetch the override set once instead of once per character.
    let rtl_override_chars = get_rtl_override_chars();
    if code.chars().any(|c| rtl_override_chars.contains(&c)) {
        detected_patterns.push("RTL override characters detected".to_string());
        risk_score += RTL_OVERRIDE_RISK_SCORE;
    }

    if self.detect_markdown_manipulation(code) {
        detected_patterns.push("Suspicious markdown formatting".to_string());
        risk_score += MARKDOWN_MANIPULATION_RISK_SCORE;
    }

    // Share of characters that are neither alphanumeric nor whitespace. Both sides of the
    // division are measured in characters (not bytes) so multi-byte text is not under-weighted,
    // and empty input is skipped rather than producing a NaN ratio.
    let total_chars = code.chars().count();
    if total_chars > 0 {
        let special_chars = code
            .chars()
            .filter(|c| !c.is_alphanumeric() && !c.is_whitespace())
            .count();
        let special_char_ratio = special_chars as f32 / total_chars as f32;
        if special_char_ratio > MAX_SPECIAL_CHAR_RATIO {
            detected_patterns.push("High special character ratio".to_string());
            risk_score += SPECIAL_CHAR_RISK_SCORE;
        }
    }

    // Zero-width space / non-joiner / joiner and the byte-order mark.
    if code
        .chars()
        .any(|c| matches!(c, '\u{200B}' | '\u{200C}' | '\u{200D}' | '\u{FEFF}'))
    {
        detected_patterns.push("Hidden unicode characters".to_string());
        risk_score += HIDDEN_UNICODE_RISK_SCORE;
    }

    if self.detect_semantic_cloaking(&lower_code) {
        detected_patterns.push("Semantic cloaking detected".to_string());
        risk_score += SEMANTIC_CLOAKING_RISK_SCORE;
    }

    // `&&` binds tighter than `||`; the parentheses only make that grouping explicit.
    if lower_code.contains("let's think step by step")
        || (lower_code.contains("step 1:") && lower_code.contains("therefore"))
    {
        detected_patterns.push("Chain-of-thought manipulation".to_string());
        risk_score += CHAIN_OF_THOUGHT_RISK_SCORE;
    }

    // NOTE(review): "ok" is a plain substring test, so words such as "token" also satisfy it.
    if lower_code.contains("example")
        && lower_code.contains("result:")
        && (lower_code.contains("safe") || lower_code.contains("ok"))
    {
        detected_patterns.push("Few-shot example poisoning".to_string());
        risk_score += FEW_SHOT_POISONING_RISK_SCORE;
    }

    let confidence = (risk_score as f32 / 100.0).min(1.0);
    let is_malicious = risk_score > DEFAULT_MALICIOUS_THRESHOLD;
    InjectionDetectionResult::new(is_malicious, confidence, detected_patterns, risk_score)
}
/// Returns `true` when `text` contains at least one character from the Unicode ranges
/// this engine treats as look-alike sources.
fn detect_homoglyphs(&self, text: &str) -> bool {
    for ch in text.chars() {
        let is_lookalike = matches!(
            ch,
            '\u{0370}'..='\u{03FF}'        // Greek and Coptic
                | '\u{0400}'..='\u{04FF}'  // Cyrillic
                | '\u{FF00}'..='\u{FFEF}'  // Halfwidth and Fullwidth Forms
                | '\u{1D400}'..='\u{1D7FF}' // Mathematical Alphanumeric Symbols
        );
        if is_lookalike {
            return true;
        }
    }
    false
}
/// Returns `true` when `text` combines heavy markdown emphasis with an attention-grabbing word.
///
/// "Heavy" means more than 5 `**` markers, more than 10 remaining single `*` markers, or more
/// than 5 `#` characters. The attention words ("important", "notice", "attention") are matched
/// case-insensitively.
fn detect_markdown_manipulation(&self, text: &str) -> bool {
    let bold_count = text.matches("**").count();
    // Every non-overlapping "**" accounts for two asterisks; whatever is left counts as italic.
    let italic_count = text.matches('*').count().saturating_sub(bold_count * 2);
    let header_count = text.matches('#').count();

    let heavy_formatting = bold_count > 5 || italic_count > 10 || header_count > 5;
    if !heavy_formatting {
        return false;
    }

    // Lowercase once, and only after the cheap counting check has passed.
    let lowered = text.to_lowercase();
    ["important", "notice", "attention"]
        .iter()
        .any(|word| lowered.contains(*word))
}
/// Returns `true` when `lower_text` contains any phrase from the fixed cloaking list.
///
/// The phrases are all lowercase and matched as plain substrings, so the caller is expected
/// to pass text that has already been lowercased.
fn detect_semantic_cloaking(&self, lower_text: &str) -> bool {
    const CLOAKING_PHRASES: &[&str] = &[
        "as a fellow professional",
        "in this specific context",
        "may not be relevant",
        "minor security concerns",
        "focus on functionality",
        "optimization rather than security",
        "test environment",
        "development purposes only",
        "certified safe",
        "approved by experts",
        "security team reviewed",
        "false positives expected",
        "confidential and proprietary",
        "not authorized to view",
        "requires special permission",
        "copyright protected code",
        "under nda",
        "legal disclaimer",
        "violation of terms",
        "unauthorized access",
        "protected by law",
        "intellectual property",
        "must be executed to analyze",
        "dynamic analysis required",
        "runtime verification needed",
        "compile and run to verify",
        "execution is necessary",
    ];

    for phrase in CLOAKING_PHRASES {
        if lower_text.contains(*phrase) {
            return true;
        }
    }
    false
}
/// Staged variant of `detect_prompt_injection` that short-circuits on high-risk signals.
///
/// The regex-DoS and steganography checks run on the raw input. The input is then
/// normalized, and the encoding-layer and context-injection checks run on the normalized
/// text. The first check that fires produces a malicious result with a fixed confidence and
/// a single label; if none fire, the normalized text goes through the full heuristic scan.
pub fn detect_prompt_injection_safe(&self, code: &str) -> InjectionDetectionResult {
    if self.detect_regex_dos_patterns(code) {
        InjectionDetectionResult::malicious(
            1.0,
            vec![String::from("Regex DoS attack")],
            REGEX_DOS_RISK_SCORE,
        )
    } else if self.detect_steganography(code) {
        InjectionDetectionResult::malicious(
            0.9,
            vec![String::from("Steganography detected")],
            STEGANOGRAPHY_RISK_SCORE,
        )
    } else {
        // Remaining stages operate on the normalized form of the input.
        let normalized_code = self.normalize_unicode(code);
        if self.detect_encoding_layers(&normalized_code) {
            InjectionDetectionResult::malicious(
                0.8,
                vec![String::from("Multiple encoding layers")],
                MULTIPLE_ENCODING_RISK_SCORE,
            )
        } else if self.detect_context_injection(&normalized_code) {
            InjectionDetectionResult::malicious(
                0.85,
                vec![String::from("Context injection")],
                CONTEXT_INJECTION_RISK_SCORE,
            )
        } else {
            self.detect_prompt_injection(&normalized_code)
        }
    }
}
/// Returns `true` when `code` looks like input crafted to trigger catastrophic regex backtracking.
///
/// Three signals are checked: a doubled quantifier (`++`, `**`, `??`), an input longer than
/// 1000 bytes in which more than half of the bytes are `a` or `b`, and a few literal
/// nested-quantifier patterns.
fn detect_regex_dos_patterns(&self, code: &str) -> bool {
    let doubled_quantifiers = ["++", "**", "??"];
    if doubled_quantifiers.iter().any(|token| code.contains(*token)) {
        return true;
    }

    let byte_len = code.len();
    if byte_len > 1000 {
        let ab_total = code.chars().filter(|&ch| matches!(ch, 'a' | 'b')).count();
        if ab_total > byte_len / 2 {
            return true;
        }
    }

    let nested_quantifiers = ["(a+)+", "(a*)*", "(a|a)*"];
    nested_quantifiers
        .iter()
        .any(|pattern| code.contains(*pattern))
}
/// Produces a canonical form of `input` for the later detection stages.
///
/// The text is NFC-normalized, the characters U+200B..=U+200D and U+FEFF are dropped, and
/// both `\r\n` and lone `\r` line endings are rewritten to `\n`.
fn normalize_unicode(&self, input: &str) -> String {
    use unicode_normalization::UnicodeNormalization;

    let is_invisible =
        |ch: char| ('\u{200B}'..='\u{200D}').contains(&ch) || ch == '\u{FEFF}';

    let composed: String = input.nfc().filter(|&ch| !is_invisible(ch)).collect();

    // Handle CRLF first so it collapses to a single newline rather than two.
    let unix_newlines = composed.replace("\r\n", "\n");
    unix_newlines.replace('\r', "\n")
}
/// Returns `true` when `code` shows signs of carrying a hidden payload.
///
/// Signals, checked in order:
/// 1. any zero-width character (U+200B, U+200C, U+200D) or U+FEFF;
/// 2. adjacent ASCII letters switching case more often than once per ten characters;
/// 3. spaces, or tabs, making up more than a third of the characters;
/// 4. a `//` line comment whose text is longer than 20 bytes and consists only of
///    alphanumerics, `+`, `/` and `=` (base64-like).
///
/// Ratios are measured against the character count rather than the byte length, so the
/// counts and the thresholds use the same unit for multi-byte text.
fn detect_steganography(&self, code: &str) -> bool {
    if code
        .chars()
        .any(|c| matches!(c, '\u{200B}' | '\u{200C}' | '\u{200D}' | '\u{FEFF}'))
    {
        return true;
    }

    let total_chars = code.chars().count();

    // Pair every character with its successor and count upper/lower case flips between letters.
    let alternating_count = code
        .chars()
        .zip(code.chars().skip(1))
        .filter(|(prev, next)| {
            prev.is_ascii_alphabetic()
                && next.is_ascii_alphabetic()
                && prev.is_ascii_uppercase() != next.is_ascii_uppercase()
        })
        .count();
    if alternating_count > total_chars / 10 {
        return true;
    }

    let spaces = code.matches(' ').count();
    let tabs = code.matches('\t').count();
    if spaces > total_chars / 3 || tabs > total_chars / 3 {
        return true;
    }

    code.lines().any(|line| {
        let trimmed = line.trim();
        if !trimmed.starts_with("//") {
            return false;
        }
        // Strip the marker from the trimmed line; stripping from the raw line would leave the
        // `//` in place on indented comments and let `// <payload>` slip past the check below.
        let comment = trimmed.trim_start_matches("//").trim();
        comment.len() > 20
            && comment
                .chars()
                .all(|c| c.is_alphanumeric() || matches!(c, '+' | '/' | '='))
    })
}
/// Returns `true` when `code` contains markers of encoded or obfuscated content.
///
/// A single literal marker (base64/hex/URL-encoding/HTML-entity/rot13/binary prefixes) is
/// enough. Otherwise the text must mention at least two distinct encoding-related words,
/// matched case-insensitively.
fn detect_encoding_layers(&self, code: &str) -> bool {
    const ENCODING_MARKERS: &[&str] = &[
        // Base64 prefixes
        "base64:",
        "b64:",
        // Hex prefixes
        "hex:",
        "0x",
        // URL-encoded space, slash and dot
        "%20",
        "%2F",
        "%2E",
        // HTML entities. The angle brackets are matched in their escaped form: testing for a
        // bare `<` or `>` would flag almost any source code as encoded.
        "&#",
        "&lt;",
        "&gt;",
        // Substitution ciphers
        "rot13:",
        "caesar:",
        // Binary prefixes
        "binary:",
        "bin:",
    ];
    if ENCODING_MARKERS.iter().any(|marker| code.contains(*marker)) {
        return true;
    }

    // Lowercase once rather than once per indicator.
    let lowered = code.to_lowercase();
    let encoding_indicators = ["decode", "encode", "encrypt", "decrypt", "cipher", "crypto"];
    let indicator_hits = encoding_indicators
        .iter()
        .filter(|indicator| lowered.contains(**indicator))
        .count();
    indicator_hits >= 2
}
/// Returns `true` when `code` contains markers that try to inject directives into the
/// surrounding context.
///
/// Checks, in order: a quoted directive key inside the first `{ ... }` span, directive-named
/// XML-style tags or `{{...}}` template placeholders, a "pr" plus `OR`/`AND` combination, and
/// shell-style substitution markers.
fn detect_context_injection(&self, code: &str) -> bool {
    // Only the first `{` and the first `}` after it are inspected.
    if let Some(open) = code.find('{') {
        if let Some(close_offset) = code[open..].find('}') {
            let json_like = &code[open..=open + close_offset];
            let quoted_keys = ["\"ignore\"", "\"override\"", "\"bypass\"", "\"skip\""];
            if quoted_keys.iter().any(|key| json_like.contains(*key)) {
                return true;
            }
        }
    }

    let xml_tags = ["<ignore>", "<override>", "<bypass>", "<skip>"];
    let template_tags = ["{{ignore}}", "{{override}}", "{{bypass}}", "{{skip}}"];
    if xml_tags
        .iter()
        .chain(template_tags.iter())
        .any(|marker| code.contains(*marker))
    {
        return true;
    }

    // NOTE(review): the intent of the "pr" substring test is unclear from this file — confirm.
    let has_boolean_keyword = code.contains("OR") || code.contains("AND");
    if code.contains("pr") && has_boolean_keyword {
        return true;
    }

    // Backticks, `$(` and `${` are all treated as command/variable substitution markers.
    code.contains("`") || code.contains("$(") || code.contains("${")
}
}