use aimds_core::{DetectionResult, Result, ThreatSeverity, ThreatType};
use aho_corasick::AhoCorasick;
use chrono::Utc;
use dashmap::DashMap;
use regex::RegexSet;
use std::sync::Arc;
use midstreamer_temporal_compare::{TemporalComparator, Sequence, ComparisonAlgorithm};
use uuid::Uuid;
/// Matches input text against built-in threat patterns and produces a
/// `DetectionResult`.
///
/// Combines three signals: literal multi-pattern matching, a regex set, and a
/// temporal (sequence-distance) comparison against reference threat phrases.
/// Results are memoised per distinct input.
pub struct PatternMatcher {
/// Literal multi-pattern automaton; `new` builds it from `default_patterns()`.
aho_corasick: Arc<AhoCorasick>,
/// Compiled regex set; `new` builds it from `default_regexes()`.
regex_set: Arc<RegexSet>,
/// Sequence comparator used by `analyze_temporal_patterns` with the DTW
/// algorithm. Created via `TemporalComparator::new(1000, 1000)`.
// NOTE(review): the meaning of the two `1000` arguments is defined by the
// `midstreamer_temporal_compare` crate and is not visible here — confirm.
temporal_comparator: TemporalComparator<i32>,
/// Memoised results, keyed by the hex-encoded blake3 hash of the input text.
cache: Arc<DashMap<String, DetectionResult>>,
}
impl PatternMatcher {
    /// Literal pattern indices below this bound are treated as critical
    /// prompt-injection indicators. With the current table this covers every
    /// entry of `default_patterns()`.
    const CRITICAL_LITERAL_PATTERNS: usize = 10;

    /// Regex indices below this bound are treated as high-severity jailbreak
    /// indicators. With the current table this covers every entry of
    /// `default_regexes()`.
    const HIGH_SEVERITY_REGEXES: usize = 5;

    /// Maximum number of leading characters of the input that are fed into
    /// the temporal comparison.
    const MAX_TEMPORAL_CHARS: usize = 1000;

    /// Upper bound on the number of memoised results. The cache is keyed by
    /// caller-supplied input, so without a bound it would grow for as long as
    /// the process keeps seeing new inputs.
    const MAX_CACHE_ENTRIES: usize = 10_000;

    /// Builds a matcher from the built-in literal patterns and regexes.
    ///
    /// # Errors
    ///
    /// Returns `AimdsError::Detection` if the literal automaton or the regex
    /// set fails to build.
    pub fn new() -> Result<Self> {
        // Literal matching is ASCII case-insensitive so that trivially
        // capitalised variants ("Ignore previous instructions", "JAILBREAK")
        // are still caught, consistent with the `(?i)` flag carried by every
        // entry of `default_regexes()`.
        let aho_corasick = AhoCorasick::builder()
            .ascii_case_insensitive(true)
            .build(Self::default_patterns())
            .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;
        let regex_set = RegexSet::new(Self::default_regexes())
            .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;

        Ok(Self {
            aho_corasick: Arc::new(aho_corasick),
            regex_set: Arc::new(regex_set),
            temporal_comparator: TemporalComparator::new(1000, 1000),
            cache: Arc::new(DashMap::new()),
        })
    }

    /// Scans `input` for threat patterns and returns the detection result.
    ///
    /// The result records one `pattern_<index>` entry per literal match, one
    /// `regex_<index>` entry per matching regex, the highest severity seen,
    /// the threat type belonging to that highest-severity finding, and a
    /// confidence derived from the match count and the temporal similarity.
    ///
    /// Results are memoised by the blake3 hash of `input`; a cache hit returns
    /// a clone of the stored result, including its original `id` and
    /// `timestamp`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the temporal analysis step.
    pub async fn match_patterns(&self, input: &str) -> Result<DetectionResult> {
        let input_hash = blake3::hash(input.as_bytes()).to_hex().to_string();
        if let Some(cached) = self.cache.get(&input_hash) {
            return Ok(cached.clone());
        }

        let mut matched_patterns = Vec::new();
        let mut max_severity = ThreatSeverity::Low;
        let mut threat_type = ThreatType::Unknown;

        for mat in self.aho_corasick.find_iter(input) {
            let pattern_id = mat.pattern().as_usize();
            matched_patterns.push(format!("pattern_{}", pattern_id));
            if pattern_id < Self::CRITICAL_LITERAL_PATTERNS {
                max_severity = ThreatSeverity::Critical;
                threat_type = ThreatType::PromptInjection;
            }
        }

        let regex_matches = self.regex_set.matches(input);
        for pattern_id in regex_matches.iter() {
            matched_patterns.push(format!("regex_{}", pattern_id));
            // Only let a regex hit set the threat type when it actually raises
            // the severity. Otherwise a High-severity regex match would
            // relabel a Critical literal match found above.
            if pattern_id < Self::HIGH_SEVERITY_REGEXES && max_severity < ThreatSeverity::High {
                max_severity = ThreatSeverity::High;
                threat_type = ThreatType::JailbreakAttempt;
            }
        }

        let temporal_score = self.analyze_temporal_patterns(input).await?;
        let confidence = self.calculate_confidence(&matched_patterns, temporal_score);

        let result = DetectionResult {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            severity: max_severity,
            threat_type,
            confidence,
            input_hash: input_hash.clone(),
            matched_patterns,
            context: serde_json::json!({
                "temporal_score": temporal_score,
                "input_length": input.len(),
            }),
        };

        // Coarse eviction: drop everything once the bound is reached rather
        // than tracking recency. Keeps memory bounded at the cost of a cold
        // cache afterwards.
        if self.cache.len() >= Self::MAX_CACHE_ENTRIES {
            self.cache.clear();
        }
        self.cache.insert(input_hash, result.clone());
        Ok(result)
    }

    /// Returns the highest similarity between `input` and any reference threat
    /// sequence, where similarity is `1 / (1 + DTW distance)`.
    ///
    /// Only the first `MAX_TEMPORAL_CHARS` characters of `input` are compared.
    /// Returns `0.0` when no comparison succeeds.
    async fn analyze_temporal_patterns(&self, input: &str) -> Result<f64> {
        let mut input_sequence = Sequence::new();
        for (idx, ch) in input.chars().take(Self::MAX_TEMPORAL_CHARS).enumerate() {
            input_sequence.push(ch as i32, idx as u64);
        }

        let threat_sequences = Self::threat_temporal_sequences();
        let mut max_similarity: f64 = 0.0;
        for threat_seq in &threat_sequences {
            // Best effort: a failed comparison against one reference sequence
            // is skipped so the remaining sequences can still contribute.
            if let Ok(result) = self.temporal_comparator.compare(
                &input_sequence,
                threat_seq,
                ComparisonAlgorithm::DTW,
            ) {
                let similarity = 1.0 / (1.0 + result.distance);
                max_similarity = max_similarity.max(similarity);
            }
        }
        Ok(max_similarity)
    }

    /// Reference phrases, encoded as character sequences, that the temporal
    /// analysis compares inputs against.
    fn threat_temporal_sequences() -> Vec<Sequence<i32>> {
        vec![
            Self::str_to_sequence("ignore previous instructions"),
            Self::str_to_sequence("you are no longer bound by"),
            Self::str_to_sequence("system: you must now"),
        ]
    }

    /// Encodes `s` as a sequence with one element per character: the value is
    /// the character's code point and the timestamp is its index.
    fn str_to_sequence(s: &str) -> Sequence<i32> {
        let mut seq = Sequence::new();
        for (idx, ch) in s.chars().enumerate() {
            seq.push(ch as i32, idx as u64);
        }
        seq
    }

    /// Combines the number of matched patterns and the temporal similarity
    /// into a single confidence value.
    ///
    /// Each match contributes 0.1 to the pattern score, capped at 0.7. The
    /// pattern score is weighted 60 % and `temporal_score` 40 %, and the sum
    /// is capped at 1.0.
    fn calculate_confidence(&self, patterns: &[String], temporal_score: f64) -> f64 {
        let pattern_score = (patterns.len() as f64 * 0.1).min(0.7);
        let combined = (pattern_score * 0.6) + (temporal_score * 0.4);
        combined.min(1.0)
    }

    /// Literal phrases searched for by the Aho-Corasick matcher. The position
    /// in this list is the index reported as `pattern_<index>`.
    fn default_patterns() -> Vec<&'static str> {
        vec![
            "ignore previous instructions",
            "disregard all prior",
            "forget everything",
            "system prompt",
            "admin mode",
            "developer mode",
            "jailbreak",
            "unrestricted mode",
            "bypass filter",
            "override safety",
        ]
    }

    /// Case-insensitive regexes evaluated as a set. The position in this list
    /// is the index reported as `regex_<index>`.
    fn default_regexes() -> Vec<&'static str> {
        vec![
            r"(?i)ignore\s+(all|previous|prior)\s+instructions",
            r"(?i)system\s*:\s*you\s+are",
            r"(?i)act\s+as\s+(an?\s+)?unrestricted",
            r"(?i)pretend\s+you\s+are\s+(not\s+)?bound",
            r"(?i)disregard\s+your\s+(programming|rules)",
        ]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a matcher from the built-in pattern tables must succeed.
    #[tokio::test]
    async fn test_pattern_matcher_creation() {
        assert!(PatternMatcher::new().is_ok());
    }

    /// A well-known injection phrase yields at least one match and a
    /// non-zero confidence.
    #[tokio::test]
    async fn test_simple_pattern_match() {
        let detector = PatternMatcher::new().unwrap();
        let prompt = "Please ignore previous instructions";

        let detection = detector.match_patterns(prompt).await.unwrap();

        let found_any = !detection.matched_patterns.is_empty();
        assert!(found_any);
        assert!(detection.confidence > 0.0);
    }

    /// A harmless question produces no pattern matches.
    #[tokio::test]
    async fn test_safe_input() {
        let detector = PatternMatcher::new().unwrap();
        let prompt = "What is the weather today?";

        let detection = detector.match_patterns(prompt).await.unwrap();

        assert!(detection.matched_patterns.is_empty());
    }
}