// fuzzy_regex/engine/guard_nfa.rs

//! Guard-based Levenshtein NFA for fuzzy substring search.
//!
//! Similar in spirit to mrab-regex's fuzzy guards: NFA states are deduplicated
//! by (pattern position, edit count), and the scan stops at the first match
//! that satisfies the edit limits and similarity threshold.

// Lint exceptions for this module: the matcher is one long function.
// NOTE(review): no `let _ =` binding or float `==`/`!=` comparison is visible
// in this file, so `let_underscore_drop` and `clippy::float_cmp` may no longer
// be needed — confirm and drop them (or move to `#[expect]`).
#![allow(
    clippy::too_many_lines,
    clippy::float_cmp,
    clippy::allow_attributes,
    let_underscore_drop
)]

use crate::engine::damlev::{DamLevMatch, EditLimits};

15/// Guard-based NFA for fast fuzzy matching.
16#[derive(Debug)]
17pub struct GuardNfa {
18    pattern: Vec<char>,
19    pattern_len: usize,
20    edit_limits: EditLimits,
21    case_insensitive: bool,
22    first_char: char,
23}
24
25impl GuardNfa {
26    /// Create a new guard-based NFA.
27    #[must_use]
28    pub fn new(pattern: &str, edit_limits: EditLimits, case_insensitive: bool) -> Self {
29        let pattern: Vec<char> = if case_insensitive {
30            pattern.to_lowercase().chars().collect()
31        } else {
32            pattern.chars().collect()
33        };
34        let pattern_len = pattern.len();
35        let first_char = pattern.first().copied().unwrap_or('\0');
36
37        GuardNfa {
38            pattern,
39            pattern_len,
40            edit_limits,
41            case_insensitive,
42            first_char,
43        }
44    }
45
46    /// Find the first match in text with early termination.
47    #[inline]
48    #[must_use]
49    pub fn find_first(&self, text: &str, threshold: f32) -> Option<DamLevMatch> {
50        let max_edits = self.edit_limits.max_edits as usize;
51
52        if self.pattern_len == 0 {
53            return Some(DamLevMatch {
54                start: 0,
55                end: 0,
56                insertions: 0,
57                deletions: 0,
58                substitutions: 0,
59                swaps: 0,
60                similarity: 1.0,
61            });
62        }
63
64        let text_chars: Vec<char> = if self.case_insensitive {
65            text.chars()
66                .map(|c| c.to_lowercase().next().unwrap_or(c))
67                .collect()
68        } else {
69            text.chars().collect()
70        };
71        let text_len = text_chars.len();
72
73        let mut char_to_byte: Vec<usize> = vec![0; text_len + 1];
74        let mut byte_pos = 0;
75        for (i, c) in text.char_indices() {
76            char_to_byte[i] = byte_pos;
77            byte_pos += c.len_utf8();
78        }
79        char_to_byte[text_len] = byte_pos;
80
81        if text_len == 0 {
82            if self.pattern_len <= max_edits {
83                let edits = self.pattern_len;
84                let sim = 1.0 - (edits as f32 / (self.pattern_len + max_edits) as f32);
85                return Some(DamLevMatch {
86                    start: 0,
87                    end: 0,
88                    insertions: 0,
89                    deletions: edits as u8,
90                    substitutions: 0,
91                    swaps: 0,
92                    similarity: sim,
93                });
94            }
95            return None;
96        }
97
98        let m = self.pattern_len;
99
100        // Vec-based states but with bit-packed seen for O(1) deduplication
101        let mut active: Vec<(usize, usize, usize, u8, u8, u8, usize)> = Vec::with_capacity(32);
102        let mut next_active: Vec<(usize, usize, usize, u8, u8, u8, usize)> = Vec::with_capacity(32);
103
104        // Bit-packed seen: 12 bits (6 for pat_pos + 6 for edits)
105        let encode_key =
106            |pat_pos: usize, edits: usize| -> u128 { ((pat_pos as u128) << 6) | (edits as u128) };
107
108        let mut pos = 0;
109        while pos < text_len {
110            let text_char = text_chars[pos];
111            let mut new_seen: u128 = 0;
112
113            // Start new match at this position
114            active.clear();
115            active.push((0, 0, pos, 0, 0, 0, pos));
116
117            next_active.clear();
118
119            // Process all active states
120            for &(pat_pos, edits, start_pos, ins, del, sub, last_consumed) in &active {
121                if pat_pos >= m || edits > max_edits {
122                    continue;
123                }
124
125                let pat_char = self.pattern[pat_pos];
126
127                // Exact match
128                if text_char == pat_char {
129                    let key = encode_key(pat_pos + 1, edits);
130                    if new_seen & (1u128 << key) == 0 {
131                        new_seen |= 1u128 << key;
132                        next_active.push((pat_pos + 1, edits, start_pos, ins, del, sub, pos));
133                    }
134                }
135
136                // Substitution
137                if edits < max_edits && text_char != pat_char {
138                    let key = encode_key(pat_pos + 1, edits + 1);
139                    if new_seen & (1u128 << key) == 0 {
140                        new_seen |= 1u128 << key;
141                        next_active.push((
142                            pat_pos + 1,
143                            edits + 1,
144                            start_pos,
145                            ins,
146                            del,
147                            sub + 1,
148                            pos,
149                        ));
150                    }
151                }
152
153                // Insertion
154                if edits < max_edits {
155                    let key = encode_key(pat_pos, edits + 1);
156                    if new_seen & (1u128 << key) == 0 {
157                        new_seen |= 1u128 << key;
158                        next_active.push((
159                            pat_pos,
160                            edits + 1,
161                            start_pos,
162                            ins + 1,
163                            del,
164                            sub,
165                            last_consumed,
166                        ));
167                    }
168                }
169
170                // Deletion
171                if pat_pos + 1 < m && edits < max_edits {
172                    let key = encode_key(pat_pos + 1, edits + 1);
173                    if new_seen & (1u128 << key) == 0 {
174                        new_seen |= 1u128 << key;
175                        next_active.push((
176                            pat_pos + 1,
177                            edits + 1,
178                            start_pos,
179                            ins,
180                            del + 1,
181                            sub,
182                            last_consumed,
183                        ));
184                    }
185                }
186            }
187
188            // Check for matches - return immediately on first match
189            for &(pat_pos, edits, start_pos, ins, del, sub, last_consumed) in &next_active {
190                if pat_pos >= m && edits <= max_edits {
191                    let sim = 1.0 - (edits as f32 / (m + max_edits) as f32);
192                    if sim >= threshold {
193                        let match_end = last_consumed + 1;
194                        let byte_start = char_to_byte[start_pos];
195                        let byte_end = char_to_byte[match_end];
196                        return Some(DamLevMatch {
197                            start: byte_start,
198                            end: byte_end,
199                            insertions: ins,
200                            deletions: del,
201                            substitutions: sub,
202                            swaps: 0,
203                            similarity: sim,
204                        });
205                    }
206                }
207            }
208
209            std::mem::swap(&mut active, &mut next_active);
210
211            // Guard pruning - skip to first character if no active states
212            if active.is_empty() {
213                pos += 1;
214                while pos < text_len && text_chars[pos] != self.first_char {
215                    pos += 1;
216                }
217            } else {
218                pos += 1;
219            }
220        }
221
222        // Handle trailing deletions
223        for &(pat_pos, edits, start_pos, ins, del, sub, _last_consumed) in &active {
224            if pat_pos >= m && edits <= max_edits {
225                let sim = 1.0 - (edits as f32 / (m + max_edits) as f32);
226                if sim >= threshold {
227                    let byte_start = char_to_byte[start_pos];
228                    let byte_end = char_to_byte[text_len];
229                    return Some(DamLevMatch {
230                        start: byte_start,
231                        end: byte_end,
232                        insertions: ins,
233                        deletions: del,
234                        substitutions: sub,
235                        swaps: 0,
236                        similarity: sim,
237                    });
238                }
239            }
240
241            let remaining = m - pat_pos;
242            let new_edits = edits + remaining;
243            if new_edits <= max_edits {
244                let new_del = del + remaining as u8;
245                let sim = 1.0 - (new_edits as f32 / (m + max_edits) as f32);
246                if sim >= threshold {
247                    let byte_start = char_to_byte[start_pos];
248                    let byte_end = char_to_byte[text_len];
249                    return Some(DamLevMatch {
250                        start: byte_start,
251                        end: byte_end,
252                        insertions: ins,
253                        deletions: new_del,
254                        substitutions: sub,
255                        swaps: 0,
256                        similarity: sim,
257                    });
258                }
259            }
260        }
261
262        None
263    }
264}