aimds_detection/
pattern_matcher.rs

1//! Pattern matching for threat detection
2
3use aimds_core::{DetectionResult, Result, ThreatSeverity, ThreatType};
4use aho_corasick::AhoCorasick;
5use chrono::Utc;
6use dashmap::DashMap;
7use regex::RegexSet;
8use std::sync::Arc;
9use midstreamer_temporal_compare::{TemporalComparator, Sequence, ComparisonAlgorithm};
10use uuid::Uuid;
11
12/// Pattern matcher using multiple detection strategies
13pub struct PatternMatcher {
14    /// Fast string matching for known patterns
15    aho_corasick: Arc<AhoCorasick>,
16    /// Regex patterns for complex matching
17    regex_set: Arc<RegexSet>,
18    /// Temporal comparison for behavioral patterns (using i32 for character codes)
19    temporal_comparator: TemporalComparator<i32>,
20    /// Pattern cache for performance
21    cache: Arc<DashMap<String, DetectionResult>>,
22}
23
24impl PatternMatcher {
25    /// Create a new pattern matcher with default patterns
26    pub fn new() -> Result<Self> {
27        let patterns = Self::default_patterns();
28        let regexes = Self::default_regexes();
29
30        let aho_corasick = AhoCorasick::new(patterns)
31            .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;
32
33        let regex_set = RegexSet::new(regexes)
34            .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;
35
36        Ok(Self {
37            aho_corasick: Arc::new(aho_corasick),
38            regex_set: Arc::new(regex_set),
39            temporal_comparator: TemporalComparator::new(1000, 1000), // cache_size, max_length
40            cache: Arc::new(DashMap::new()),
41        })
42    }
43
44    /// Match patterns in the input text
45    pub async fn match_patterns(&self, input: &str) -> Result<DetectionResult> {
46        // Check cache first
47        let hash = blake3::hash(input.as_bytes());
48        let input_hash = hash.to_hex().to_string();
49        if let Some(cached) = self.cache.get(&input_hash) {
50            return Ok(cached.clone());
51        }
52
53        // Perform pattern matching
54        let mut matched_patterns = Vec::new();
55        let mut max_severity = ThreatSeverity::Low;
56        let mut threat_type = ThreatType::Unknown;
57
58        // Fast string matching
59        for mat in self.aho_corasick.find_iter(input) {
60            let pattern_id = mat.pattern().as_usize();
61            matched_patterns.push(format!("pattern_{}", pattern_id));
62
63            // Update severity based on pattern
64            if pattern_id < 10 {
65                max_severity = ThreatSeverity::Critical;
66                threat_type = ThreatType::PromptInjection;
67            }
68        }
69
70        // Regex matching
71        let regex_matches = self.regex_set.matches(input);
72        for pattern_id in regex_matches.iter() {
73            matched_patterns.push(format!("regex_{}", pattern_id));
74
75            if pattern_id < 5 {
76                max_severity = std::cmp::max(max_severity, ThreatSeverity::High);
77                threat_type = ThreatType::JailbreakAttempt;
78            }
79        }
80
81        // Temporal analysis for behavioral patterns
82        let temporal_score = self.analyze_temporal_patterns(input).await?;
83
84        // Calculate confidence based on matches
85        let confidence = self.calculate_confidence(&matched_patterns, temporal_score);
86
87        let result = DetectionResult {
88            id: Uuid::new_v4(),
89            timestamp: Utc::now(),
90            severity: max_severity,
91            threat_type,
92            confidence,
93            input_hash: input_hash.clone(),
94            matched_patterns,
95            context: serde_json::json!({
96                "temporal_score": temporal_score,
97                "input_length": input.len(),
98            }),
99        };
100
101        // Cache the result
102        self.cache.insert(input_hash, result.clone());
103
104        Ok(result)
105    }
106
107    /// Analyze temporal patterns using Midstream's temporal comparator
108    async fn analyze_temporal_patterns(&self, input: &str) -> Result<f64> {
109        // Convert input to temporal sequence for DTW analysis (using i32 for char codes)
110        let mut input_sequence = Sequence::new();
111        for (idx, ch) in input.chars().take(1000).enumerate() {
112            input_sequence.push(ch as i32, idx as u64);
113        }
114
115        // Use temporal-compare DTW (validated: 7.8ms performance)
116        // Compare against known malicious temporal patterns
117        let threat_sequences = Self::threat_temporal_sequences();
118
119        let mut max_similarity: f64 = 0.0;
120        for threat_seq in threat_sequences {
121            match self.temporal_comparator.compare(
122                &input_sequence,
123                &threat_seq,
124                ComparisonAlgorithm::DTW,
125            ) {
126                Ok(result) => {
127                    // Convert distance to similarity (lower distance = higher similarity)
128                    let similarity = 1.0 / (1.0 + result.distance);
129                    max_similarity = max_similarity.max(similarity);
130                }
131                Err(_) => continue,
132            }
133        }
134
135        Ok(max_similarity)
136    }
137
138    /// Known threat temporal sequences for DTW comparison
139    fn threat_temporal_sequences() -> Vec<Sequence<i32>> {
140        vec![
141            // Prompt injection temporal pattern
142            Self::str_to_sequence("ignore previous instructions"),
143            // Jailbreak attempt pattern
144            Self::str_to_sequence("you are no longer bound by"),
145            // System prompt override pattern
146            Self::str_to_sequence("system: you must now"),
147        ]
148    }
149
150    /// Helper to convert string to Sequence
151    fn str_to_sequence(s: &str) -> Sequence<i32> {
152        let mut seq = Sequence::new();
153        for (idx, ch) in s.chars().enumerate() {
154            seq.push(ch as i32, idx as u64);
155        }
156        seq
157    }
158
159    /// Calculate confidence score
160    fn calculate_confidence(&self, patterns: &[String], temporal_score: f64) -> f64 {
161        let pattern_score = (patterns.len() as f64 * 0.1).min(0.7);
162        let combined = (pattern_score * 0.6) + (temporal_score * 0.4);
163        combined.min(1.0)
164    }
165
166    /// Default threat patterns
167    fn default_patterns() -> Vec<&'static str> {
168        vec![
169            "ignore previous instructions",
170            "disregard all prior",
171            "forget everything",
172            "system prompt",
173            "admin mode",
174            "developer mode",
175            "jailbreak",
176            "unrestricted mode",
177            "bypass filter",
178            "override safety",
179        ]
180    }
181
182    /// Default regex patterns
183    fn default_regexes() -> Vec<&'static str> {
184        vec![
185            r"(?i)ignore\s+(all|previous|prior)\s+instructions",
186            r"(?i)system\s*:\s*you\s+are",
187            r"(?i)act\s+as\s+(an?\s+)?unrestricted",
188            r"(?i)pretend\s+you\s+are\s+(not\s+)?bound",
189            r"(?i)disregard\s+your\s+(programming|rules)",
190        ]
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a matcher with the default pattern set, failing the test on error.
    fn default_matcher() -> PatternMatcher {
        PatternMatcher::new().expect("default patterns should build")
    }

    #[tokio::test]
    async fn test_pattern_matcher_creation() {
        assert!(PatternMatcher::new().is_ok());
    }

    #[tokio::test]
    async fn test_simple_pattern_match() {
        let matcher = default_matcher();
        let detection = matcher
            .match_patterns("Please ignore previous instructions")
            .await
            .expect("matching a known injection phrase should succeed");

        assert!(!detection.matched_patterns.is_empty());
        assert!(detection.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_safe_input() {
        let matcher = default_matcher();
        let detection = matcher
            .match_patterns("What is the weather today?")
            .await
            .expect("matching benign input should succeed");

        assert!(detection.matched_patterns.is_empty());
    }
}