aimds-detection 0.1.1

Fast-path detection layer for AIMDS with pattern matching and anomaly detection
Documentation
//! Pattern matching for threat detection

use aimds_core::{DetectionResult, Result, ThreatSeverity, ThreatType};
use aho_corasick::AhoCorasick;
use chrono::Utc;
use dashmap::DashMap;
use regex::RegexSet;
use std::sync::Arc;
use midstreamer_temporal_compare::{TemporalComparator, Sequence, ComparisonAlgorithm};
use uuid::Uuid;

/// Pattern matcher using multiple detection strategies
pub struct PatternMatcher {
    /// Fast string matching for known patterns
    aho_corasick: Arc<AhoCorasick>,
    /// Regex patterns for complex matching
    regex_set: Arc<RegexSet>,
    /// Temporal comparison for behavioral patterns (using i32 for character codes)
    temporal_comparator: TemporalComparator<i32>,
    /// Pattern cache for performance
    cache: Arc<DashMap<String, DetectionResult>>,
}

impl PatternMatcher {
    /// Create a new pattern matcher with default patterns
    pub fn new() -> Result<Self> {
        let patterns = Self::default_patterns();
        let regexes = Self::default_regexes();

        let aho_corasick = AhoCorasick::new(patterns)
            .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;

        let regex_set = RegexSet::new(regexes)
            .map_err(|e| aimds_core::AimdsError::Detection(e.to_string()))?;

        Ok(Self {
            aho_corasick: Arc::new(aho_corasick),
            regex_set: Arc::new(regex_set),
            temporal_comparator: TemporalComparator::new(1000, 1000), // cache_size, max_length
            cache: Arc::new(DashMap::new()),
        })
    }

    /// Match patterns in the input text
    pub async fn match_patterns(&self, input: &str) -> Result<DetectionResult> {
        // Check cache first
        let hash = blake3::hash(input.as_bytes());
        let input_hash = hash.to_hex().to_string();
        if let Some(cached) = self.cache.get(&input_hash) {
            return Ok(cached.clone());
        }

        // Perform pattern matching
        let mut matched_patterns = Vec::new();
        let mut max_severity = ThreatSeverity::Low;
        let mut threat_type = ThreatType::Unknown;

        // Fast string matching
        for mat in self.aho_corasick.find_iter(input) {
            let pattern_id = mat.pattern().as_usize();
            matched_patterns.push(format!("pattern_{}", pattern_id));

            // Update severity based on pattern
            if pattern_id < 10 {
                max_severity = ThreatSeverity::Critical;
                threat_type = ThreatType::PromptInjection;
            }
        }

        // Regex matching
        let regex_matches = self.regex_set.matches(input);
        for pattern_id in regex_matches.iter() {
            matched_patterns.push(format!("regex_{}", pattern_id));

            if pattern_id < 5 {
                max_severity = std::cmp::max(max_severity, ThreatSeverity::High);
                threat_type = ThreatType::JailbreakAttempt;
            }
        }

        // Temporal analysis for behavioral patterns
        let temporal_score = self.analyze_temporal_patterns(input).await?;

        // Calculate confidence based on matches
        let confidence = self.calculate_confidence(&matched_patterns, temporal_score);

        let result = DetectionResult {
            id: Uuid::new_v4(),
            timestamp: Utc::now(),
            severity: max_severity,
            threat_type,
            confidence,
            input_hash: input_hash.clone(),
            matched_patterns,
            context: serde_json::json!({
                "temporal_score": temporal_score,
                "input_length": input.len(),
            }),
        };

        // Cache the result
        self.cache.insert(input_hash, result.clone());

        Ok(result)
    }

    /// Analyze temporal patterns using Midstream's temporal comparator
    async fn analyze_temporal_patterns(&self, input: &str) -> Result<f64> {
        // Convert input to temporal sequence for DTW analysis (using i32 for char codes)
        let mut input_sequence = Sequence::new();
        for (idx, ch) in input.chars().take(1000).enumerate() {
            input_sequence.push(ch as i32, idx as u64);
        }

        // Use temporal-compare DTW (validated: 7.8ms performance)
        // Compare against known malicious temporal patterns
        let threat_sequences = Self::threat_temporal_sequences();

        let mut max_similarity: f64 = 0.0;
        for threat_seq in threat_sequences {
            match self.temporal_comparator.compare(
                &input_sequence,
                &threat_seq,
                ComparisonAlgorithm::DTW,
            ) {
                Ok(result) => {
                    // Convert distance to similarity (lower distance = higher similarity)
                    let similarity = 1.0 / (1.0 + result.distance);
                    max_similarity = max_similarity.max(similarity);
                }
                Err(_) => continue,
            }
        }

        Ok(max_similarity)
    }

    /// Known threat temporal sequences for DTW comparison
    fn threat_temporal_sequences() -> Vec<Sequence<i32>> {
        vec![
            // Prompt injection temporal pattern
            Self::str_to_sequence("ignore previous instructions"),
            // Jailbreak attempt pattern
            Self::str_to_sequence("you are no longer bound by"),
            // System prompt override pattern
            Self::str_to_sequence("system: you must now"),
        ]
    }

    /// Helper to convert string to Sequence
    fn str_to_sequence(s: &str) -> Sequence<i32> {
        let mut seq = Sequence::new();
        for (idx, ch) in s.chars().enumerate() {
            seq.push(ch as i32, idx as u64);
        }
        seq
    }

    /// Calculate confidence score
    fn calculate_confidence(&self, patterns: &[String], temporal_score: f64) -> f64 {
        let pattern_score = (patterns.len() as f64 * 0.1).min(0.7);
        let combined = (pattern_score * 0.6) + (temporal_score * 0.4);
        combined.min(1.0)
    }

    /// Default threat patterns
    fn default_patterns() -> Vec<&'static str> {
        vec![
            "ignore previous instructions",
            "disregard all prior",
            "forget everything",
            "system prompt",
            "admin mode",
            "developer mode",
            "jailbreak",
            "unrestricted mode",
            "bypass filter",
            "override safety",
        ]
    }

    /// Default regex patterns
    fn default_regexes() -> Vec<&'static str> {
        vec![
            r"(?i)ignore\s+(all|previous|prior)\s+instructions",
            r"(?i)system\s*:\s*you\s+are",
            r"(?i)act\s+as\s+(an?\s+)?unrestricted",
            r"(?i)pretend\s+you\s+are\s+(not\s+)?bound",
            r"(?i)disregard\s+your\s+(programming|rules)",
        ]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pattern_matcher_creation() {
        let matcher = PatternMatcher::new();
        assert!(matcher.is_ok());
    }

    #[tokio::test]
    async fn test_simple_pattern_match() {
        let matcher = PatternMatcher::new().unwrap();
        let result = matcher
            .match_patterns("Please ignore previous instructions")
            .await
            .unwrap();

        assert!(!result.matched_patterns.is_empty());
        assert!(result.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_safe_input() {
        let matcher = PatternMatcher::new().unwrap();
        let result = matcher
            .match_patterns("What is the weather today?")
            .await
            .unwrap();

        assert!(result.matched_patterns.is_empty());
    }
}