1use std::collections::{HashMap, VecDeque};
18use std::fmt;
19use std::mem;
20use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
21
22pub use matchy_match_mode::MatchMode;
24
25pub mod validation;
27
28pub use validation::{
30 validate_ac_reachability, validate_ac_structure, validate_pattern_references, ACStats,
31 ACValidationResult,
32};
33
/// Errors reported while building an Aho-Corasick automaton.
///
/// Every variant carries a human-readable detail message that `Display`
/// appends after a fixed, variant-specific prefix.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ACError {
    /// A pattern, or the pattern list as a whole, was rejected
    /// (for example an empty pattern or an empty pattern list).
    InvalidPattern(String),
    /// A count or byte offset grew past what the serialized format can hold.
    ResourceLimitExceeded(String),
    /// The supplied input was rejected.
    // NOTE(review): not constructed by the builder code in this file;
    // presumably raised by callers or the `validation` module - confirm.
    InvalidInput(String),
}

impl fmt::Display for ACError {
    /// Writes `"<prefix>: <detail>"`, where the prefix names the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Select the fixed prefix first so the actual write happens once.
        let (prefix, detail) = match self {
            Self::InvalidPattern(detail) => ("Invalid pattern: ", detail),
            Self::ResourceLimitExceeded(detail) => ("Resource limit exceeded: ", detail),
            Self::InvalidInput(detail) => ("Invalid input: ", detail),
        };
        write!(f, "{prefix}{detail}")
    }
}

impl std::error::Error for ACError {}
56
/// How a node's outgoing transitions are encoded in the serialized buffer.
///
/// The discriminant is the byte stored in `ACNodeHot::state_kind`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateKind {
    /// No outgoing transitions.
    Empty = 0,
    /// Exactly one outgoing transition, stored inline in the node.
    One = 1,
    /// Two to eight outgoing transitions, stored as a run of `ACEdge` records.
    Sparse = 2,
    /// More than eight outgoing transitions (and always the root), stored as
    /// a 256-entry `DenseLookup` table.
    Dense = 3,
}
72
73#[repr(C)]
75#[derive(Debug, Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout)]
76pub struct ACNodeHot {
77 pub state_kind: u8,
79 pub one_char: u8,
81 pub edge_count: u8,
83 pub pattern_count: u8,
85 pub one_target: u32,
87 pub failure_offset: u32,
89 pub edges_offset: u32,
91 pub patterns_offset: u32,
93}
94
95#[repr(C)]
97#[derive(Debug, Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout)]
98pub struct ACEdge {
99 pub character: u8,
101 pub reserved: [u8; 3],
103 pub target_offset: u32,
105}
106
107impl ACEdge {
108 fn new(character: u8, target_offset: u32) -> Self {
109 Self {
110 character,
111 reserved: [0; 3],
112 target_offset,
113 }
114 }
115}
116
117#[repr(C)]
119#[derive(Debug, Clone, Copy, FromBytes, IntoBytes, Immutable, KnownLayout)]
120pub struct DenseLookup {
121 pub targets: [u32; 256],
124}
125
126struct ACBuilder {
140 states: Vec<BuilderState>,
142 mode: MatchMode,
144 patterns: Vec<String>,
146}
147
/// A single trie state while the automaton is under construction.
#[derive(Clone, Debug)]
struct BuilderState {
    /// Outgoing transitions: input byte -> id of the next state.
    transitions: HashMap<u8, u32>,
    /// Id of the failure state; 0 is the root.
    failure: u32,
    /// Ids of the patterns reported when this state is reached.
    outputs: Vec<u32>,
}
155
156impl BuilderState {
157 fn new(_id: u32, _depth: u8) -> Self {
158 Self {
159 transitions: HashMap::new(),
160 failure: 0,
161 outputs: Vec::new(),
162 }
163 }
164
165 fn classify_state_kind(&self) -> StateKind {
174 match self.transitions.len() {
175 0 => StateKind::Empty,
176 1 => StateKind::One,
177 2..=8 => StateKind::Sparse,
178 _ => StateKind::Dense, }
180 }
181}
182
183impl ACBuilder {
184 fn new(mode: MatchMode) -> Self {
185 Self {
186 states: vec![BuilderState::new(0, 0)], mode,
188 patterns: Vec::new(),
189 }
190 }
191
192 fn add_pattern(&mut self, pattern: &str) -> Result<u32, ACError> {
202 let pattern_id = u32::try_from(self.patterns.len())
203 .map_err(|_| ACError::ResourceLimitExceeded("Pattern count exceeds u32::MAX".into()))?;
204 self.patterns.push(pattern.to_string());
205
206 let pattern_bytes: Vec<u8> = match self.mode {
209 MatchMode::CaseSensitive => pattern.as_bytes().to_vec(),
210 MatchMode::CaseInsensitive => pattern.to_lowercase().into_bytes(),
211 };
212
213 let mut current = 0u32;
215 let mut depth = 0u8;
216
217 for &ch in &pattern_bytes {
218 depth += 1;
219
220 if let Some(&next) = self.states[current as usize].transitions.get(&ch) {
222 current = next;
223 } else {
224 let new_id = u32::try_from(self.states.len()).map_err(|_| {
226 ACError::ResourceLimitExceeded("State count exceeds u32::MAX".into())
227 })?;
228 self.states.push(BuilderState::new(new_id, depth));
229 self.states[current as usize].transitions.insert(ch, new_id);
230 current = new_id;
231 }
232 }
233
234 self.states[current as usize].outputs.push(pattern_id);
236
237 Ok(pattern_id)
238 }
239
240 fn build_failure_links(&mut self) {
241 let mut queue = VecDeque::new();
242
243 let root_children: Vec<u32> = self.states[0].transitions.values().copied().collect();
245
246 for child in root_children {
247 self.states[child as usize].failure = 0;
248 queue.push_back(child);
249 }
250
251 while let Some(state_id) = queue.pop_front() {
253 let transitions: Vec<(u8, u32)> = self.states[state_id as usize]
254 .transitions
255 .iter()
256 .map(|(&ch, &next)| (ch, next))
257 .collect();
258
259 for (ch, next_state) in transitions {
260 queue.push_back(next_state);
261
262 let mut fail = self.states[state_id as usize].failure;
264 let mut failure_found = false;
265
266 while fail != 0 {
268 if let Some(&target) = self.states[fail as usize].transitions.get(&ch) {
269 self.states[next_state as usize].failure = target;
270 failure_found = true;
271 break;
272 }
273 fail = self.states[fail as usize].failure;
274 }
275
276 if !failure_found {
278 if let Some(&target) = self.states[0].transitions.get(&ch) {
279 if target == next_state {
281 self.states[next_state as usize].failure = 0;
282 } else {
283 self.states[next_state as usize].failure = target;
284 }
285 } else {
286 self.states[next_state as usize].failure = 0;
287 }
288 }
289
290 let mut suffix_state = self.states[next_state as usize].failure;
293 while suffix_state != 0 {
294 let suffix_outputs = self.states[suffix_state as usize].outputs.clone();
295 if !suffix_outputs.is_empty() {
296 self.states[next_state as usize]
297 .outputs
298 .extend(suffix_outputs);
299 }
300 suffix_state = self.states[suffix_state as usize].failure;
301 }
302 }
303 }
304 }
305
306 fn serialize(self) -> Result<Vec<u8>, ACError> {
308 let mut buffer = Vec::new();
309
310 let node_size = mem::size_of::<ACNodeHot>();
312 let edge_size = mem::size_of::<ACEdge>();
313 let dense_size = mem::size_of::<DenseLookup>();
314
315 let nodes_start = 0;
316 let nodes_size = self.states.len() * node_size;
317
318 let state_kinds: Vec<StateKind> = self
321 .states
322 .iter()
323 .enumerate()
324 .map(|(i, s)| {
325 if i == 0 {
326 StateKind::Dense
327 } else {
328 s.classify_state_kind()
329 }
330 })
331 .collect();
332
333 let dense_count = state_kinds
334 .iter()
335 .filter(|&&k| k == StateKind::Dense)
336 .count();
337 let sparse_edges: usize = self
338 .states
339 .iter()
340 .zip(&state_kinds)
341 .filter(|(_, &kind)| kind == StateKind::Sparse)
342 .map(|(s, _)| s.transitions.len())
343 .sum();
344
345 let total_patterns: usize = self.states.iter().map(|s| s.outputs.len()).sum();
347
348 let edges_start = nodes_size;
350 let edges_size = sparse_edges * edge_size;
351
352 let unaligned_dense_start = edges_start + edges_size;
355 let dense_alignment = mem::align_of::<DenseLookup>(); let (dense_padding, dense_start) = if dense_count > 0 {
357 let padding =
358 (dense_alignment - (unaligned_dense_start % dense_alignment)) % dense_alignment;
359 (padding, unaligned_dense_start + padding)
360 } else {
361 (0, unaligned_dense_start)
363 };
364 let dense_size_total = dense_count * dense_size;
365
366 let patterns_start = dense_start + dense_size_total;
367 let patterns_size = total_patterns * mem::size_of::<u32>();
368
369 let total_size = nodes_size + edges_size + dense_padding + dense_size_total + patterns_size;
371
372 const MAX_BUFFER_SIZE: usize = 2_000_000_000; if total_size > MAX_BUFFER_SIZE {
378 return Err(ACError::ResourceLimitExceeded(format!(
379 "Pattern database too large: {} bytes ({} states, {} sparse edges, {} dense, {} patterns). \
380 Maximum allowed is {} bytes. This may be caused by pathological patterns \
381 with many null bytes or special characters.",
382 total_size,
383 self.states.len(),
384 sparse_edges,
385 dense_count,
386 total_patterns,
387 MAX_BUFFER_SIZE
388 )));
389 }
390
391 buffer.resize(total_size, 0);
393
394 debug_assert_eq!(
396 dense_start % dense_alignment,
397 0,
398 "Dense section must be {}-byte aligned, but starts at offset {} ({}% alignment)",
399 dense_alignment,
400 dense_start,
401 dense_start % dense_alignment
402 );
403
404 let mut edge_offset = edges_start;
406 let mut dense_offset = dense_start;
407 let mut pattern_offset = patterns_start;
408
409 let node_offsets: Vec<usize> = (0..self.states.len())
410 .map(|i| nodes_start + i * node_size)
411 .collect();
412
413 for (i, state) in self.states.iter().enumerate() {
415 let node_offset = node_offsets[i];
416 let kind = state_kinds[i];
417
418 let mut edges: Vec<(u8, u32)> = state
420 .transitions
421 .iter()
422 .map(|(&ch, &target)| {
423 let offset = node_offsets[target as usize];
424 let offset_u32 = u32::try_from(offset).map_err(|_| {
425 ACError::ResourceLimitExceeded("Node offset exceeds u32::MAX".into())
426 });
427 offset_u32.map(|o| (ch, o))
428 })
429 .collect::<Result<Vec<_>, _>>()?;
430 edges.sort_by_key(|(ch, _)| *ch); let (edges_offset_for_node, one_char, _one_target) = match kind {
434 StateKind::Empty => (0u32, 0u8, 0u32),
435
436 StateKind::One => {
437 let (ch, target) = edges[0];
439 (target, ch, 0u32) }
441
442 StateKind::Sparse => {
443 let sparse_offset = u32::try_from(edge_offset).map_err(|_| {
445 ACError::ResourceLimitExceeded("Sparse edge offset exceeds u32::MAX".into())
446 })?;
447
448 for (ch, target) in &edges {
449 let edge = ACEdge::new(*ch, *target);
450 buffer[edge_offset..edge_offset + edge_size]
451 .copy_from_slice(edge.as_bytes());
452 edge_offset += edge_size;
453 }
454
455 (sparse_offset, 0u8, 0u32)
456 }
457
458 StateKind::Dense => {
459 let lookup_offset = u32::try_from(dense_offset).map_err(|_| {
461 ACError::ResourceLimitExceeded(
462 "Dense lookup offset exceeds u32::MAX".into(),
463 )
464 })?;
465 let mut lookup = DenseLookup {
466 targets: [0u32; 256],
467 };
468
469 for (ch, target) in &edges {
470 lookup.targets[*ch as usize] = *target;
471 }
472
473 buffer[dense_offset..dense_offset + dense_size]
474 .copy_from_slice(lookup.as_bytes());
475 dense_offset += dense_size;
476
477 (lookup_offset, 0u8, 0u32)
478 }
479 };
480
481 let patterns_offset_for_node = if state.outputs.is_empty() {
483 0u32
484 } else {
485 u32::try_from(pattern_offset).map_err(|_| {
486 ACError::ResourceLimitExceeded("Pattern offset exceeds u32::MAX".into())
487 })?
488 };
489
490 for &pattern_id in &state.outputs {
491 buffer[pattern_offset..pattern_offset + 4]
492 .copy_from_slice(&pattern_id.to_le_bytes());
493 pattern_offset += mem::size_of::<u32>();
494 }
495
496 let failure_offset = if state.failure == 0 {
498 0u32
499 } else {
500 u32::try_from(node_offsets[state.failure as usize]).map_err(|_| {
501 ACError::ResourceLimitExceeded("Failure offset exceeds u32::MAX".into())
502 })?
503 };
504
505 let edge_count_u8 = match kind {
507 StateKind::One => 0, _ => u8::try_from(state.transitions.len()).unwrap_or(u8::MAX),
509 };
510 let pattern_count_u8 = u8::try_from(state.outputs.len()).unwrap_or(u8::MAX);
511
512 let one_target = match kind {
514 StateKind::One => edges[0].1,
515 _ => 0,
516 };
517
518 let node = ACNodeHot {
519 state_kind: kind as u8,
520 one_char,
521 edge_count: edge_count_u8,
522 pattern_count: pattern_count_u8,
523 one_target,
524 failure_offset,
525 edges_offset: edges_offset_for_node,
526 patterns_offset: patterns_offset_for_node,
527 };
528
529 buffer[node_offset..node_offset + node_size].copy_from_slice(node.as_bytes());
530 }
531
532 Ok(buffer)
533 }
534}
535
/// A built Aho-Corasick automaton held as one serialized byte buffer.
pub struct ACAutomaton {
    /// Serialized automaton; empty for an automaton created with `new`.
    buffer: Vec<u8>,
    /// Number of states (nodes) stored in `buffer`.
    node_count: usize,
}
547
548impl ACAutomaton {
549 #[must_use]
551 pub fn new(_mode: MatchMode) -> Self {
552 Self {
553 buffer: Vec::new(),
554 node_count: 0,
555 }
556 }
557
558 pub fn build(patterns: &[&str], mode: MatchMode) -> Result<Self, ACError> {
562 if patterns.is_empty() {
563 return Err(ACError::InvalidPattern("No patterns provided".to_string()));
564 }
565
566 let mut builder = ACBuilder::new(mode);
567
568 for pattern in patterns {
569 if pattern.is_empty() {
570 return Err(ACError::InvalidPattern("Empty pattern".to_string()));
571 }
572 builder.add_pattern(pattern)?; }
574
575 builder.build_failure_links();
576 let node_count = builder.states.len();
577 let buffer = builder.serialize()?;
578
579 Ok(Self { buffer, node_count })
580 }
581
582 #[must_use]
584 pub fn buffer(&self) -> &[u8] {
585 &self.buffer
586 }
587
588 #[must_use]
590 pub fn node_count(&self) -> usize {
591 self.node_count
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a small pattern set builds into a non-empty buffer.
    #[test]
    fn test_build_simple() {
        let patterns = ["he", "she", "his", "hers"];
        let automaton = ACAutomaton::build(&patterns, MatchMode::CaseSensitive).unwrap();

        assert!(!automaton.buffer.is_empty());
    }
}
607
608