Skip to main content

matchy_ac/
lib.rs

1//! Offset-based Aho-Corasick Automaton
2//!
3//! This module implements an Aho-Corasick automaton that builds directly into
4//! the binary offset-based format. Unlike traditional implementations, this
5//! creates the serialized format during construction, allowing zero-copy
6//! memory-mapped operation.
7//!
8//! # Design
9//!
10//! The automaton is stored as a single `Vec<u8>` containing:
11//! - AC nodes with offset-based transitions
12//! - Edge arrays referenced by nodes
13//! - Pattern ID arrays referenced by nodes
14//!
//! The buffer is assembled by serializing temporary builder states; querying is
//! implemented outside this module (in `paraglob_offset`).
16
17use std::collections::{HashMap, VecDeque};
18use std::fmt;
19use std::mem;
20use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
21
22// Re-export MatchMode from shared crate
23pub use matchy_match_mode::MatchMode;
24
25// Validation module for AC automaton structures
26pub mod validation;
27
28// Re-export validation types for convenience
29pub use validation::{
30    validate_ac_reachability, validate_ac_structure, validate_pattern_references, ACStats,
31    ACValidationResult,
32};
33
/// Failures that can occur while constructing an AC automaton.
///
/// Each variant carries a human-readable description of what went wrong.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ACError {
    /// A supplied pattern was rejected (for example, it was empty).
    InvalidPattern(String),
    /// A size or count limit was hit (e.g., too many states or too large a buffer).
    ResourceLimitExceeded(String),
    /// The input handed to the automaton was not acceptable.
    InvalidInput(String),
}
44
45impl fmt::Display for ACError {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match self {
48            Self::InvalidPattern(msg) => write!(f, "Invalid pattern: {msg}"),
49            Self::ResourceLimitExceeded(msg) => write!(f, "Resource limit exceeded: {msg}"),
50            Self::InvalidInput(msg) => write!(f, "Invalid input: {msg}"),
51        }
52    }
53}
54
55impl std::error::Error for ACError {}
56
57// Binary format structures for offset-based AC automaton
58
/// How a node's outgoing transitions are encoded in the serialized buffer.
///
/// The discriminant is stored as a single byte in `ACNodeHot::state_kind`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateKind {
    /// The node has no outgoing transitions (terminal state only).
    Empty = 0,
    /// Exactly one transition, kept inline in the node itself (75-80% of states).
    One = 1,
    /// 2-8 transitions, kept in a sparse edge array (10-15% of states).
    Sparse = 2,
    /// 9 or more transitions, kept in a dense lookup table (2-5% of states).
    Dense = 3,
}
72
73/// AC Node hot data (checked every transition)
74#[repr(C)]
75#[derive(Debug, Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout)]
76pub struct ACNodeHot {
77    /// State encoding type (StateKind enum)
78    pub state_kind: u8,
79    /// ONE encoding: character for single transition
80    pub one_char: u8,
81    /// Number of edges (SPARSE/DENSE states)
82    pub edge_count: u8,
83    /// Number of pattern IDs at this node
84    pub pattern_count: u8,
85    /// ONE encoding: target offset for single transition
86    pub one_target: u32,
87    /// Failure link offset
88    pub failure_offset: u32,
89    /// Offset to edges array (SPARSE/DENSE states)
90    pub edges_offset: u32,
91    /// Offset to pattern IDs array
92    pub patterns_offset: u32,
93}
94
95/// AC Edge (for sparse/dense states)
96#[repr(C)]
97#[derive(Debug, Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout)]
98pub struct ACEdge {
99    /// Input character (0-255)
100    pub character: u8,
101    /// Reserved for alignment
102    pub reserved: [u8; 3],
103    /// Offset to target node
104    pub target_offset: u32,
105}
106
107impl ACEdge {
108    fn new(character: u8, target_offset: u32) -> Self {
109        Self {
110            character,
111            reserved: [0; 3],
112            target_offset,
113        }
114    }
115}
116
117/// Dense lookup table (256 entries)
118#[repr(C)]
119#[derive(Debug, Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout)]
120pub struct DenseLookup {
121    /// Target offsets indexed by character (0-255)
122    /// 0 means no transition for that character
123    pub targets: [u32; 256],
124}
125
126// Note: Case-Insensitive Implementation
127//
128// Case-insensitive mode uses a memory-efficient approach:
129// - Patterns are normalized to lowercase during automaton construction
130// - Input text is normalized to lowercase during search (using SIMD)
131// - This avoids doubling the automaton size (compared to storing both upper/lower transitions)
132//
133// For ~16K PSL patterns, this reduces memory usage by approximately 50%.
134
135/// Builder for constructing the offset-based AC automaton
136///
137/// This uses temporary in-memory structures during construction,
138/// then serializes them into the final offset-based format.
139struct ACBuilder {
140    /// Temporary states during construction
141    states: Vec<BuilderState>,
142    /// Matching mode
143    mode: MatchMode,
144    /// Original patterns
145    patterns: Vec<String>,
146}
147
/// In-memory trie state that exists only while the automaton is being built.
#[derive(Debug, Clone)]
struct BuilderState {
    /// Outgoing transitions: input byte -> index of the target state.
    transitions: HashMap<u8, u32>,
    /// Index of the failure state (0 is the root).
    failure: u32,
    /// IDs of the patterns reported when this state is reached.
    outputs: Vec<u32>,
}
155
156impl BuilderState {
157    fn new(_id: u32, _depth: u8) -> Self {
158        Self {
159            transitions: HashMap::new(),
160            failure: 0,
161            outputs: Vec::new(),
162        }
163    }
164
165    /// Classify state encoding based on transition count
166    ///
167    /// # State Encoding Selection
168    ///
169    /// - **Empty** (0 transitions): Terminal states only, no lookups needed
170    /// - **One** (1 transition): Store inline, eliminates cache miss (75-80% of states)
171    /// - **Sparse** (2-8 transitions): Linear search is optimal for this range
172    /// - **Dense** (9+ transitions): O(1) lookup table worth the 1KB overhead
173    fn classify_state_kind(&self) -> StateKind {
174        match self.transitions.len() {
175            0 => StateKind::Empty,
176            1 => StateKind::One,
177            2..=8 => StateKind::Sparse,
178            _ => StateKind::Dense, // 9+ transitions
179        }
180    }
181}
182
183impl ACBuilder {
184    fn new(mode: MatchMode) -> Self {
185        Self {
186            states: vec![BuilderState::new(0, 0)], // Root
187            mode,
188            patterns: Vec::new(),
189        }
190    }
191
192    /// Add a pattern to the automaton
193    ///
194    /// # Case-Insensitive Mode
195    ///
196    /// For case-insensitive matching, patterns are normalized to lowercase here.
197    /// This avoids the memory overhead of storing both uppercase and lowercase transitions.
198    ///
199    /// Example: Pattern "Hello" becomes "hello" with a single transition path,
200    /// rather than 2^5 = 32 paths for all case combinations.
201    fn add_pattern(&mut self, pattern: &str) -> Result<u32, ACError> {
202        let pattern_id = u32::try_from(self.patterns.len())
203            .map_err(|_| ACError::ResourceLimitExceeded("Pattern count exceeds u32::MAX".into()))?;
204        self.patterns.push(pattern.to_string());
205
206        // For case-insensitive mode, normalize pattern to lowercase during build
207        // We'll normalize text to lowercase during search instead of doubling transitions
208        let pattern_bytes: Vec<u8> = match self.mode {
209            MatchMode::CaseSensitive => pattern.as_bytes().to_vec(),
210            MatchMode::CaseInsensitive => pattern.to_lowercase().into_bytes(),
211        };
212
213        // Build trie path
214        let mut current = 0u32;
215        let mut depth = 0u8;
216
217        for &ch in &pattern_bytes {
218            depth += 1;
219
220            // Check if transition already exists
221            if let Some(&next) = self.states[current as usize].transitions.get(&ch) {
222                current = next;
223            } else {
224                // Create new state
225                let new_id = u32::try_from(self.states.len()).map_err(|_| {
226                    ACError::ResourceLimitExceeded("State count exceeds u32::MAX".into())
227                })?;
228                self.states.push(BuilderState::new(new_id, depth));
229                self.states[current as usize].transitions.insert(ch, new_id);
230                current = new_id;
231            }
232        }
233
234        // Add output
235        self.states[current as usize].outputs.push(pattern_id);
236
237        Ok(pattern_id)
238    }
239
240    fn build_failure_links(&mut self) {
241        let mut queue = VecDeque::new();
242
243        // Depth-1 states fail to root
244        let root_children: Vec<u32> = self.states[0].transitions.values().copied().collect();
245
246        for child in root_children {
247            self.states[child as usize].failure = 0;
248            queue.push_back(child);
249        }
250
251        // BFS to compute failure links
252        while let Some(state_id) = queue.pop_front() {
253            let transitions: Vec<(u8, u32)> = self.states[state_id as usize]
254                .transitions
255                .iter()
256                .map(|(&ch, &next)| (ch, next))
257                .collect();
258
259            for (ch, next_state) in transitions {
260                queue.push_back(next_state);
261
262                // Find failure state
263                let mut fail = self.states[state_id as usize].failure;
264                let mut failure_found = false;
265
266                // Follow failure links looking for a state with a transition for 'ch'
267                while fail != 0 {
268                    if let Some(&target) = self.states[fail as usize].transitions.get(&ch) {
269                        self.states[next_state as usize].failure = target;
270                        failure_found = true;
271                        break;
272                    }
273                    fail = self.states[fail as usize].failure;
274                }
275
276                // If not found, check root
277                if !failure_found {
278                    if let Some(&target) = self.states[0].transitions.get(&ch) {
279                        // Only set if target is not the node itself (avoid self-loop)
280                        if target == next_state {
281                            self.states[next_state as usize].failure = 0;
282                        } else {
283                            self.states[next_state as usize].failure = target;
284                        }
285                    } else {
286                        self.states[next_state as usize].failure = 0;
287                    }
288                }
289
290                // Merge outputs from ALL suffix states (via failure links)
291                // This is critical: we need to inherit patterns from the entire failure link chain
292                let mut suffix_state = self.states[next_state as usize].failure;
293                while suffix_state != 0 {
294                    let suffix_outputs = self.states[suffix_state as usize].outputs.clone();
295                    if !suffix_outputs.is_empty() {
296                        self.states[next_state as usize]
297                            .outputs
298                            .extend(suffix_outputs);
299                    }
300                    suffix_state = self.states[suffix_state as usize].failure;
301                }
302            }
303        }
304    }
305
306    /// Serialize into offset-based format with state-specific encoding
307    fn serialize(self) -> Result<Vec<u8>, ACError> {
308        let mut buffer = Vec::new();
309
310        // Calculate section sizes - using cache-optimized ACNodeHot (16 bytes)
311        let node_size = mem::size_of::<ACNodeHot>();
312        let edge_size = mem::size_of::<ACEdge>();
313        let dense_size = mem::size_of::<DenseLookup>();
314
315        let nodes_start = 0;
316        let nodes_size = self.states.len() * node_size;
317
318        // Classify states and count by type
319        // Root node (index 0) is ALWAYS Dense for O(1) lookup performance
320        let state_kinds: Vec<StateKind> = self
321            .states
322            .iter()
323            .enumerate()
324            .map(|(i, s)| {
325                if i == 0 {
326                    StateKind::Dense
327                } else {
328                    s.classify_state_kind()
329                }
330            })
331            .collect();
332
333        let dense_count = state_kinds
334            .iter()
335            .filter(|&&k| k == StateKind::Dense)
336            .count();
337        let sparse_edges: usize = self
338            .states
339            .iter()
340            .zip(&state_kinds)
341            .filter(|(_, &kind)| kind == StateKind::Sparse)
342            .map(|(s, _)| s.transitions.len())
343            .sum();
344
345        // ONE states don't need edge arrays!
346        let total_patterns: usize = self.states.iter().map(|s| s.outputs.len()).sum();
347
348        // Layout: [Nodes][Sparse Edges][Padding][Dense Lookups][Patterns]
349        let edges_start = nodes_size;
350        let edges_size = sparse_edges * edge_size;
351
352        // Calculate padding to align dense section to 64-byte boundary (only if we have dense lookups)
353        // DenseLookup now has #[repr(C, align(64))] for cache-line alignment
354        let unaligned_dense_start = edges_start + edges_size;
355        let dense_alignment = mem::align_of::<DenseLookup>(); // 64 bytes
356        let (dense_padding, dense_start) = if dense_count > 0 {
357            let padding =
358                (dense_alignment - (unaligned_dense_start % dense_alignment)) % dense_alignment;
359            (padding, unaligned_dense_start + padding)
360        } else {
361            // No dense lookups, so no need for padding - patterns come right after edges
362            (0, unaligned_dense_start)
363        };
364        let dense_size_total = dense_count * dense_size;
365
366        let patterns_start = dense_start + dense_size_total;
367        let patterns_size = total_patterns * mem::size_of::<u32>();
368
369        // Calculate total size (including alignment padding only if we have dense lookups)
370        let total_size = nodes_size + edges_size + dense_padding + dense_size_total + patterns_size;
371
372        // Reasonable size limit to prevent pathological inputs from causing OOM
373        // Set to 2GB which is large enough for legitimate databases but catches
374        // pathological inputs early
375        const MAX_BUFFER_SIZE: usize = 2_000_000_000; // 2GB
376
377        if total_size > MAX_BUFFER_SIZE {
378            return Err(ACError::ResourceLimitExceeded(format!(
379                "Pattern database too large: {} bytes ({} states, {} sparse edges, {} dense, {} patterns). \
380                     Maximum allowed is {} bytes. This may be caused by pathological patterns \
381                     with many null bytes or special characters.",
382                total_size,
383                self.states.len(),
384                sparse_edges,
385                dense_count,
386                total_patterns,
387                MAX_BUFFER_SIZE
388            )));
389        }
390
391        // Allocate buffer
392        buffer.resize(total_size, 0);
393
394        // Verify alignment of dense section
395        debug_assert_eq!(
396            dense_start % dense_alignment,
397            0,
398            "Dense section must be {}-byte aligned, but starts at offset {} ({}% alignment)",
399            dense_alignment,
400            dense_start,
401            dense_start % dense_alignment
402        );
403
404        // Track offsets for writing data
405        let mut edge_offset = edges_start;
406        let mut dense_offset = dense_start;
407        let mut pattern_offset = patterns_start;
408
409        let node_offsets: Vec<usize> = (0..self.states.len())
410            .map(|i| nodes_start + i * node_size)
411            .collect();
412
413        // Write each node with state-specific encoding
414        for (i, state) in self.states.iter().enumerate() {
415            let node_offset = node_offsets[i];
416            let kind = state_kinds[i];
417
418            // Prepare sorted edges for this state
419            let mut edges: Vec<(u8, u32)> = state
420                .transitions
421                .iter()
422                .map(|(&ch, &target)| {
423                    let offset = node_offsets[target as usize];
424                    let offset_u32 = u32::try_from(offset).map_err(|_| {
425                        ACError::ResourceLimitExceeded("Node offset exceeds u32::MAX".into())
426                    });
427                    offset_u32.map(|o| (ch, o))
428                })
429                .collect::<Result<Vec<_>, _>>()?;
430            edges.sort_by_key(|(ch, _)| *ch); // Sort for efficient lookup
431
432            // Write state-specific transition data
433            let (edges_offset_for_node, one_char, _one_target) = match kind {
434                StateKind::Empty => (0u32, 0u8, 0u32),
435
436                StateKind::One => {
437                    // Store single transition inline in node!
438                    let (ch, target) = edges[0];
439                    (target, ch, 0u32) // edges_offset stores target for ONE states
440                }
441
442                StateKind::Sparse => {
443                    // Write edges to sparse edge array
444                    let sparse_offset = u32::try_from(edge_offset).map_err(|_| {
445                        ACError::ResourceLimitExceeded("Sparse edge offset exceeds u32::MAX".into())
446                    })?;
447
448                    for (ch, target) in &edges {
449                        let edge = ACEdge::new(*ch, *target);
450                        buffer[edge_offset..edge_offset + edge_size]
451                            .copy_from_slice(edge.as_bytes());
452                        edge_offset += edge_size;
453                    }
454
455                    (sparse_offset, 0u8, 0u32)
456                }
457
458                StateKind::Dense => {
459                    // Write dense lookup table
460                    let lookup_offset = u32::try_from(dense_offset).map_err(|_| {
461                        ACError::ResourceLimitExceeded(
462                            "Dense lookup offset exceeds u32::MAX".into(),
463                        )
464                    })?;
465                    let mut lookup = DenseLookup {
466                        targets: [0u32; 256],
467                    };
468
469                    for (ch, target) in &edges {
470                        lookup.targets[*ch as usize] = *target;
471                    }
472
473                    buffer[dense_offset..dense_offset + dense_size]
474                        .copy_from_slice(lookup.as_bytes());
475                    dense_offset += dense_size;
476
477                    (lookup_offset, 0u8, 0u32)
478                }
479            };
480
481            // Write pattern IDs
482            let patterns_offset_for_node = if state.outputs.is_empty() {
483                0u32
484            } else {
485                u32::try_from(pattern_offset).map_err(|_| {
486                    ACError::ResourceLimitExceeded("Pattern offset exceeds u32::MAX".into())
487                })?
488            };
489
490            for &pattern_id in &state.outputs {
491                buffer[pattern_offset..pattern_offset + 4]
492                    .copy_from_slice(&pattern_id.to_le_bytes());
493                pattern_offset += mem::size_of::<u32>();
494            }
495
496            // Write cache-optimized hot node (16 bytes)
497            let failure_offset = if state.failure == 0 {
498                0u32
499            } else {
500                u32::try_from(node_offsets[state.failure as usize]).map_err(|_| {
501                    ACError::ResourceLimitExceeded("Failure offset exceeds u32::MAX".into())
502                })?
503            };
504
505            // Edge and pattern counts saturate at u8::MAX (255)
506            let edge_count_u8 = match kind {
507                StateKind::One => 0, // Single edge stored inline, not in edge array
508                _ => u8::try_from(state.transitions.len()).unwrap_or(u8::MAX),
509            };
510            let pattern_count_u8 = u8::try_from(state.outputs.len()).unwrap_or(u8::MAX);
511
512            // Create hot node with optimal field ordering for cache access
513            let one_target = match kind {
514                StateKind::One => edges[0].1,
515                _ => 0,
516            };
517
518            let node = ACNodeHot {
519                state_kind: kind as u8,
520                one_char,
521                edge_count: edge_count_u8,
522                pattern_count: pattern_count_u8,
523                one_target,
524                failure_offset,
525                edges_offset: edges_offset_for_node,
526                patterns_offset: patterns_offset_for_node,
527            };
528
529            buffer[node_offset..node_offset + node_size].copy_from_slice(node.as_bytes());
530        }
531
532        Ok(buffer)
533    }
534}
535
/// Build-side handle for an offset-based Aho-Corasick automaton.
///
/// The whole automaton lives in one byte buffer addressed by offsets.
/// This type only constructs that buffer; querying is done via
/// paraglob_offset's optimized implementation.
pub struct ACAutomaton {
    /// Serialized automaton bytes.
    buffer: Vec<u8>,
    /// How many AC nodes the automaton contains.
    node_count: usize,
}
547
548impl ACAutomaton {
549    /// Create a new AC automaton (initially empty)
550    #[must_use]
551    pub fn new(_mode: MatchMode) -> Self {
552        Self {
553            buffer: Vec::new(),
554            node_count: 0,
555        }
556    }
557
558    /// Build the automaton from patterns
559    ///
560    /// This constructs the offset-based binary format directly.
561    pub fn build(patterns: &[&str], mode: MatchMode) -> Result<Self, ACError> {
562        if patterns.is_empty() {
563            return Err(ACError::InvalidPattern("No patterns provided".to_string()));
564        }
565
566        let mut builder = ACBuilder::new(mode);
567
568        for pattern in patterns {
569            if pattern.is_empty() {
570                return Err(ACError::InvalidPattern("Empty pattern".to_string()));
571            }
572            builder.add_pattern(pattern)?; // Propagate error
573        }
574
575        builder.build_failure_links();
576        let node_count = builder.states.len();
577        let buffer = builder.serialize()?;
578
579        Ok(Self { buffer, node_count })
580    }
581
582    /// Get the buffer (for serialization)
583    #[must_use]
584    pub fn buffer(&self) -> &[u8] {
585        &self.buffer
586    }
587
588    /// Get the number of AC nodes in the automaton
589    #[must_use]
590    pub fn node_count(&self) -> usize {
591        self.node_count
592    }
593}
594
#[cfg(test)]
mod tests {
    use super::*;

    /// Building from a small set of overlapping patterns yields a non-empty buffer.
    #[test]
    fn test_build_simple() {
        let patterns = ["he", "she", "his", "hers"];
        let automaton = ACAutomaton::build(&patterns, MatchMode::CaseSensitive).unwrap();

        assert!(!automaton.buffer.is_empty());
    }
}
607
608// Note: Query method tests removed - ACAutomaton is now only used for building.
609// Querying is done via paraglob_offset's optimized inline implementation.