multimatch 0.1.1

Multi-pattern matching engine — Aho-Corasick + regex with optional Hyperscan SIMD acceleration
Documentation
//! The matching engine — Aho-Corasick for literals, regex for patterns.

use aho_corasick::AhoCorasick;
use regex::bytes::Regex;

use crate::pattern::{PatternDef, PatternKind};
use crate::MatchError;

/// A single match produced by scanning an input buffer.
///
/// Offsets are byte offsets into the scanned input and form the half-open
/// range `start..end`. The type is a small plain-data value, so it is `Copy`
/// and `Hash` in addition to being comparable, which lets callers
/// deduplicate results or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MatchResult {
    /// The user-assigned pattern ID that matched.
    pub pattern_id: usize,
    /// Byte offset where the match starts (inclusive).
    pub start: usize,
    /// Byte offset where the match ends (exclusive).
    pub end: usize,
}

/// Compiled matching engine.
///
/// Holds the compiled automata and regexes produced by `MatchEngine::compile`
/// and is queried through `scan` and `is_match`. Each `*_map` vector
/// translates an automaton-local pattern index back into the user-assigned
/// pattern ID.
#[derive(Debug, Clone)]
pub struct MatchEngine {
    /// Aho-Corasick automaton for exact-match literal patterns.
    /// `None` when no such patterns were supplied.
    ac_exact: Option<AhoCorasick>,
    /// Map from exact AC pattern index to user pattern ID.
    ac_exact_map: Vec<usize>,
    /// Aho-Corasick automaton for case-insensitive literal patterns.
    ///
    /// `compile` currently leaves this as `None`: case-insensitive literals
    /// are escaped and compiled as `(?i)` regexes instead.
    ac_ci: Option<AhoCorasick>,
    /// Map from CI AC pattern index to user pattern ID.
    ac_ci_map: Vec<usize>,
    /// Compiled regex patterns with their user IDs.
    regexes: Vec<(usize, Regex)>,
}

impl MatchEngine {
    /// Compile patterns into an engine.
    pub fn compile(patterns: Vec<PatternDef>) -> Result<Self, MatchError> {
        let mut exact_literals: Vec<(String, usize)> = Vec::new();
        let mut regexes: Vec<(usize, Regex)> = Vec::new();

        for pat in &patterns {
            match &pat.kind {
                PatternKind::Literal(s) => {
                    if pat.case_insensitive {
                        let pattern_str = format!("(?i){}", regex::escape(s));
                        let compiled = Regex::new(&pattern_str).map_err(|e| {
                            MatchError::InvalidRegex {
                                id: pat.id,
                                source: regex::Error::Syntax(e.to_string()),
                            }
                        })?;
                        regexes.push((pat.id, compiled));
                    } else {
                        exact_literals.push((s.clone(), pat.id));
                    }
                }
                PatternKind::Regex(r) => {
                    let pattern_str = if pat.case_insensitive {
                        format!("(?i){r}")
                    } else {
                        r.clone()
                    };
                    let compiled = Regex::new(&pattern_str).map_err(|e| {
                        MatchError::InvalidRegex {
                            id: pat.id,
                            source: regex::Error::Syntax(e.to_string()),
                        }
                    })?;
                    regexes.push((pat.id, compiled));
                }
            }
        }

        let (ac_exact, ac_exact_map) = Self::build_ac(exact_literals, false)?;

        Ok(Self {
            ac_exact,
            ac_exact_map,
            ac_ci: None,
            ac_ci_map: Vec::new(),
            regexes,
        })
    }

    fn build_ac(literals: Vec<(String, usize)>, ci: bool) -> Result<(Option<AhoCorasick>, Vec<usize>), MatchError> {
        if literals.is_empty() {
            return Ok((None, Vec::new()));
        }
        let strs: Vec<&str> = literals.iter().map(|(s, _)| s.as_str()).collect();
        let id_map: Vec<usize> = literals.iter().map(|(_, id)| *id).collect();

        let mut builder = aho_corasick::AhoCorasick::builder();
        builder.ascii_case_insensitive(ci);

        let ac = builder.build(&strs).map_err(|e| MatchError::AhoCorasick(e.to_string()))?;
        Ok((Some(ac), id_map))
    }

    /// Scan input for all matches.
    pub fn scan(&self, input: &[u8]) -> Vec<MatchResult> {
        let mut results = Vec::with_capacity(
            self.ac_exact_map.len() + self.ac_ci_map.len() + self.regexes.len()
        );

        // Phase 1a: Aho-Corasick literal scan (Exact match)
        if let Some(ac) = &self.ac_exact {
            for mat in ac.find_overlapping_iter(input) {
                let ac_idx = mat.pattern().as_usize();
                if let Some(&user_id) = self.ac_exact_map.get(ac_idx) {
                    results.push(MatchResult {
                        pattern_id: user_id,
                        start: mat.start(),
                        end: mat.end(),
                    });
                }
            }
        }

        // Phase 1b: Aho-Corasick literal scan (Case-insensitive)
        if let Some(ac) = &self.ac_ci {
            for mat in ac.find_overlapping_iter(input) {
                let ac_idx = mat.pattern().as_usize();
                if let Some(&user_id) = self.ac_ci_map.get(ac_idx) {
                    results.push(MatchResult {
                        pattern_id: user_id,
                        start: mat.start(),
                        end: mat.end(),
                    });
                }
            }
        }

        // Phase 2: Regex scan.
        for (user_id, regex) in &self.regexes {
            for mat in regex.find_iter(input) {
                results.push(MatchResult {
                    pattern_id: *user_id,
                    start: mat.start(),
                    end: mat.end(),
                });
            }
        }

        results
    }

    /// Check if any pattern matches (short-circuit).
    pub fn is_match(&self, input: &[u8]) -> bool {
        if let Some(ac) = &self.ac_exact {
            if ac.is_match(input) {
                return true;
            }
        }
        if let Some(ac) = &self.ac_ci {
            if ac.is_match(input) {
                return true;
            }
        }
        self.regexes.iter().any(|(_, r)| r.is_match(input))
    }
}

#[cfg(test)]
mod tests {
    use crate::{PatternSet, Scanner};

    /// Two case-sensitive literals are both reported when both occur.
    #[test]
    fn literal_match() {
        let set = PatternSet::builder()
            .add_literal("password", 0)
            .add_literal("secret", 1)
            .build()
            .unwrap();

        let hits = set.scan(b"my password is secret");
        for wanted in [0, 1] {
            assert!(hits.iter().any(|hit| hit.pattern_id == wanted));
        }
    }

    /// A literal that does not occur yields no results.
    #[test]
    fn literal_no_match() {
        let set = PatternSet::builder()
            .add_literal("foobar", 0)
            .build()
            .unwrap();

        assert!(set.scan(b"nothing here").is_empty());
    }

    /// A regex match reports the pattern ID and exact byte offsets.
    #[test]
    fn regex_match() {
        let set = PatternSet::builder()
            .add_regex(r"[0-9]{3}-[0-9]{4}", 0)
            .build()
            .unwrap();

        let hits = set.scan(b"call 555-1234 now");
        assert_eq!(hits.len(), 1);

        let only = &hits[0];
        assert_eq!(only.pattern_id, 0);
        assert_eq!(only.start, 5);
        assert_eq!(only.end, 13);
    }

    /// Literal and regex patterns can live in the same set and both fire.
    #[test]
    fn mixed_literal_and_regex() {
        let set = PatternSet::builder()
            .add_literal("token", 0)
            .add_regex(r"[A-Za-z0-9]{20,}", 1)
            .build()
            .unwrap();

        let hits = set.scan(b"token=abcdefghij1234567890XYZ");
        // ID 0 is the literal "token"; ID 1 is the long alphanumeric run.
        let literal_seen = hits.iter().any(|hit| hit.pattern_id == 0);
        let regex_seen = hits.iter().any(|hit| hit.pattern_id == 1);
        assert!(literal_seen);
        assert!(regex_seen);
    }

    /// An upper-case case-insensitive literal matches lower-case input.
    #[test]
    fn case_insensitive_literal() {
        let set = PatternSet::builder()
            .add_literal_ci("SECRET", 0)
            .build()
            .unwrap();

        let hit_count = set.scan(b"my secret key").len();
        assert_eq!(hit_count, 1);
    }

    /// A lower-case case-insensitive regex matches upper-case input.
    #[test]
    fn case_insensitive_regex() {
        let set = PatternSet::builder()
            .add_regex_ci("password", 0)
            .build()
            .unwrap();

        let hit_count = set.scan(b"my PASSWORD is here").len();
        assert_eq!(hit_count, 1);
    }

    /// `is_match` is true when the literal occurs and false otherwise.
    #[test]
    fn is_match_short_circuits() {
        let set = PatternSet::builder()
            .add_literal("needle", 0)
            .build()
            .unwrap();

        let found = set.is_match(b"find the needle");
        let missing = set.is_match(b"nothing here");
        assert!(found);
        assert!(!missing);
    }

    /// Literals that share bytes of the input are each reported.
    #[test]
    fn overlapping_literals() {
        let set = PatternSet::builder()
            .add_literal("ab", 0)
            .add_literal("bc", 1)
            .build()
            .unwrap();

        let hits = set.scan(b"abc");
        for wanted in [0, 1] {
            assert!(hits.iter().any(|hit| hit.pattern_id == wanted));
        }
    }

    /// One regex reports every separate occurrence in the input.
    #[test]
    fn multiple_regex_matches() {
        let set = PatternSet::builder()
            .add_regex(r"\d+", 0)
            .build()
            .unwrap();

        let hit_count = set.scan(b"a1b22c333").len();
        assert_eq!(hit_count, 3);
    }

    /// Empty input produces no results and no match.
    #[test]
    fn empty_input() {
        let set = PatternSet::builder()
            .add_literal("x", 0)
            .build()
            .unwrap();

        let hits = set.scan(b"");
        assert!(hits.is_empty());
        assert!(!set.is_match(b""));
    }

    /// `scan_str` accepts a `&str` and finds the literal in it.
    #[test]
    fn scan_str_convenience() {
        let set = PatternSet::builder()
            .add_literal("hello", 0)
            .build()
            .unwrap();

        let hit_count = set.scan_str("say hello world").len();
        assert_eq!(hit_count, 1);
    }

    /// Building a set with a malformed regex returns an error.
    #[test]
    fn invalid_regex_errors() {
        let outcome = PatternSet::builder()
            .add_regex("[invalid(", 0)
            .build();

        assert!(outcome.is_err());
    }

    /// A hundred literals compile together and the right IDs are reported.
    #[test]
    fn large_pattern_set() {
        let set = (0..100)
            .fold(PatternSet::builder(), |acc, i| {
                acc.add_literal(&format!("pattern_{i}"), i)
            })
            .build()
            .unwrap();
        assert_eq!(set.pattern_count(), 100);

        let hits = set.scan(b"contains pattern_42 and pattern_99");
        for wanted in [42, 99] {
            assert!(hits.iter().any(|hit| hit.pattern_id == wanted));
        }
    }

    /// Start and end offsets delimit the matched bytes (end exclusive).
    #[test]
    fn match_result_offsets_correct() {
        let set = PatternSet::builder()
            .add_literal("xyz", 0)
            .build()
            .unwrap();

        let hits = set.scan(b"--xyz--");
        let first = &hits[0];
        assert_eq!(first.start, 2);
        assert_eq!(first.end, 5);
    }

    /// Literals are found inside input that is not valid UTF-8.
    #[test]
    fn binary_input() {
        let set = PatternSet::builder()
            .add_literal("ELF", 0)
            .build()
            .unwrap();

        let elf_header = b"\x7fELF\x02\x01\x01\x00";
        assert_eq!(set.scan(elf_header).len(), 1);
    }

    /// A multi-byte UTF-8 literal (Russian for "password") is found once.
    #[test]
    fn unicode_input() {
        let set = PatternSet::builder()
            .add_literal("пароль", 0)
            .build()
            .unwrap();

        let haystack = "мой пароль тут".as_bytes();
        assert_eq!(set.scan(haystack).len(), 1);
    }
}