aimds_detection/
pattern_matcher.rs1use aimds_core::{DetectionResult, Result, ThreatSeverity, ThreatType};
4use aho_corasick::AhoCorasick;
5use chrono::Utc;
6use dashmap::DashMap;
7use regex::RegexSet;
8use std::sync::Arc;
9use midstreamer_temporal_compare::{TemporalComparator, Sequence, ComparisonAlgorithm};
10use uuid::Uuid;
11
12pub struct PatternMatcher {
14 aho_corasick: Arc<AhoCorasick>,
16 regex_set: Arc<RegexSet>,
18 temporal_comparator: TemporalComparator<i32>,
20 cache: Arc<DashMap<String, DetectionResult>>,
22}
23
24impl PatternMatcher {
25 pub fn new() -> Result<Self> {
27 let patterns = Self::default_patterns();
28 let regexes = Self::default_regexes();
29
30 let aho_corasick = AhoCorasick::new(patterns)
31 .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;
32
33 let regex_set = RegexSet::new(regexes)
34 .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;
35
36 Ok(Self {
37 aho_corasick: Arc::new(aho_corasick),
38 regex_set: Arc::new(regex_set),
39 temporal_comparator: TemporalComparator::new(1000, 1000), cache: Arc::new(DashMap::new()),
41 })
42 }
43
44 pub async fn match_patterns(&self, input: &str) -> Result<DetectionResult> {
46 let hash = blake3::hash(input.as_bytes());
48 let input_hash = hash.to_hex().to_string();
49 if let Some(cached) = self.cache.get(&input_hash) {
50 return Ok(cached.clone());
51 }
52
53 let mut matched_patterns = Vec::new();
55 let mut max_severity = ThreatSeverity::Low;
56 let mut threat_type = ThreatType::Unknown;
57
58 for mat in self.aho_corasick.find_iter(input) {
60 let pattern_id = mat.pattern().as_usize();
61 matched_patterns.push(format!("pattern_{}", pattern_id));
62
63 if pattern_id < 10 {
65 max_severity = ThreatSeverity::Critical;
66 threat_type = ThreatType::PromptInjection;
67 }
68 }
69
70 let regex_matches = self.regex_set.matches(input);
72 for pattern_id in regex_matches.iter() {
73 matched_patterns.push(format!("regex_{}", pattern_id));
74
75 if pattern_id < 5 {
76 max_severity = std::cmp::max(max_severity, ThreatSeverity::High);
77 threat_type = ThreatType::JailbreakAttempt;
78 }
79 }
80
81 let temporal_score = self.analyze_temporal_patterns(input).await?;
83
84 let confidence = self.calculate_confidence(&matched_patterns, temporal_score);
86
87 let result = DetectionResult {
88 id: Uuid::new_v4(),
89 timestamp: Utc::now(),
90 severity: max_severity,
91 threat_type,
92 confidence,
93 input_hash: input_hash.clone(),
94 matched_patterns,
95 context: serde_json::json!({
96 "temporal_score": temporal_score,
97 "input_length": input.len(),
98 }),
99 };
100
101 self.cache.insert(input_hash, result.clone());
103
104 Ok(result)
105 }
106
107 async fn analyze_temporal_patterns(&self, input: &str) -> Result<f64> {
109 let mut input_sequence = Sequence::new();
111 for (idx, ch) in input.chars().take(1000).enumerate() {
112 input_sequence.push(ch as i32, idx as u64);
113 }
114
115 let threat_sequences = Self::threat_temporal_sequences();
118
119 let mut max_similarity: f64 = 0.0;
120 for threat_seq in threat_sequences {
121 match self.temporal_comparator.compare(
122 &input_sequence,
123 &threat_seq,
124 ComparisonAlgorithm::DTW,
125 ) {
126 Ok(result) => {
127 let similarity = 1.0 / (1.0 + result.distance);
129 max_similarity = max_similarity.max(similarity);
130 }
131 Err(_) => continue,
132 }
133 }
134
135 Ok(max_similarity)
136 }
137
138 fn threat_temporal_sequences() -> Vec<Sequence<i32>> {
140 vec![
141 Self::str_to_sequence("ignore previous instructions"),
143 Self::str_to_sequence("you are no longer bound by"),
145 Self::str_to_sequence("system: you must now"),
147 ]
148 }
149
150 fn str_to_sequence(s: &str) -> Sequence<i32> {
152 let mut seq = Sequence::new();
153 for (idx, ch) in s.chars().enumerate() {
154 seq.push(ch as i32, idx as u64);
155 }
156 seq
157 }
158
159 fn calculate_confidence(&self, patterns: &[String], temporal_score: f64) -> f64 {
161 let pattern_score = (patterns.len() as f64 * 0.1).min(0.7);
162 let combined = (pattern_score * 0.6) + (temporal_score * 0.4);
163 combined.min(1.0)
164 }
165
166 fn default_patterns() -> Vec<&'static str> {
168 vec![
169 "ignore previous instructions",
170 "disregard all prior",
171 "forget everything",
172 "system prompt",
173 "admin mode",
174 "developer mode",
175 "jailbreak",
176 "unrestricted mode",
177 "bypass filter",
178 "override safety",
179 ]
180 }
181
182 fn default_regexes() -> Vec<&'static str> {
184 vec![
185 r"(?i)ignore\s+(all|previous|prior)\s+instructions",
186 r"(?i)system\s*:\s*you\s+are",
187 r"(?i)act\s+as\s+(an?\s+)?unrestricted",
188 r"(?i)pretend\s+you\s+are\s+(not\s+)?bound",
189 r"(?i)disregard\s+your\s+(programming|rules)",
190 ]
191 }
192}
193
194#[cfg(test)]
195mod tests {
196 use super::*;
197
198 #[tokio::test]
199 async fn test_pattern_matcher_creation() {
200 let matcher = PatternMatcher::new();
201 assert!(matcher.is_ok());
202 }
203
204 #[tokio::test]
205 async fn test_simple_pattern_match() {
206 let matcher = PatternMatcher::new().unwrap();
207 let result = matcher
208 .match_patterns("Please ignore previous instructions")
209 .await
210 .unwrap();
211
212 assert!(!result.matched_patterns.is_empty());
213 assert!(result.confidence > 0.0);
214 }
215
216 #[tokio::test]
217 async fn test_safe_input() {
218 let matcher = PatternMatcher::new().unwrap();
219 let result = matcher
220 .match_patterns("What is the weather today?")
221 .await
222 .unwrap();
223
224 assert!(result.matched_patterns.is_empty());
225 }
226}