1use std::{collections::HashMap, ops::RangeBounds, sync::LazyLock};
13
14use parking_lot::{RwLock, RwLockWriteGuard};
15use regex_automata::{
16 Anchored, Input, PatternID,
17 hybrid::dfa::{Cache, DFA},
18 nfa::thompson::Config,
19};
20
21use super::{Point, Text, TextRange};
22
23impl Text {
24 pub fn search_fwd<R: RegexPattern>(
25 &mut self,
26 pat: R,
27 range: impl TextRange,
28 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
29 let range = range.to_range_fwd(self.len().byte());
30 let dfas = dfas_from_pat(pat)?;
31 let haystack = unsafe {
32 self.make_contiguous_in(range.clone());
33 self.continuous_in_unchecked(range.clone())
34 };
35
36 let mut fwd_input = Input::new(haystack);
37 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
38 let mut fwd_cache = dfas.fwd.1.write();
39 let mut rev_cache = dfas.rev.1.write();
40
41 let ref_self = self as &Text;
42 Ok(std::iter::from_fn(move || {
43 let init = fwd_input.start();
44 let h_end = loop {
45 if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
46 if half.offset() == init {
48 fwd_input.set_start(init + 1);
49 } else {
50 break half.offset();
51 }
52 } else {
53 return None;
54 }
55 };
56
57 fwd_input.set_start(h_end);
58 rev_input.set_end(h_end);
59
60 let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
61 return None;
62 };
63 let h_start = half.offset();
64
65 let p0 = ref_self.point_at(h_start + range.start);
66 let p1 = ref_self.point_at(h_end + range.start);
67
68 Some(R::get_match((p0, p1), half.pattern()))
69 }))
70 }
71
72 pub fn search_rev<R: RegexPattern>(
74 &mut self,
75 pat: R,
76 range: impl TextRange,
77 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
78 let range = range.to_range_rev(self.len().byte());
79 let dfas = dfas_from_pat(pat)?;
80 let haystack = unsafe {
81 self.make_contiguous_in(range.clone());
82 self.continuous_in_unchecked(range.clone())
83 };
84
85 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
86 let mut rev_input = Input::new(haystack);
87 let mut fwd_cache = dfas.fwd.1.write();
88 let mut rev_cache = dfas.rev.1.write();
89
90 let ref_self = self as &Text;
91 let gap = range.start;
92 Ok(std::iter::from_fn(move || {
93 let init = rev_input.end();
94 let start = loop {
95 if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
96 if half.offset() == init {
98 rev_input.set_end(init.checked_sub(1)?);
99 } else {
100 break half.offset();
101 }
102 } else {
103 return None;
104 }
105 };
106
107 rev_input.set_end(start);
108 fwd_input.set_start(start);
109
110 let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
111 return None;
112 };
113 let end = half.offset();
114
115 let p0 = ref_self.point_at(start + gap);
116 let p1 = ref_self.point_at(end + gap);
117
118 Some(R::get_match((p0, p1), half.pattern()))
119 }))
120 }
121
122 pub fn matches(
127 &mut self,
128 pat: impl RegexPattern,
129 range: impl TextRange,
130 ) -> Result<bool, Box<regex_syntax::Error>> {
131 let range = range.to_range_fwd(self.len().byte());
132 let dfas = dfas_from_pat(pat)?;
133
134 let haystack = unsafe {
135 self.make_contiguous_in(range.clone());
136 self.continuous_in_unchecked(range)
137 };
138 let fwd_input = Input::new(haystack);
139
140 let mut fwd_cache = dfas.fwd.1.write();
141 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
142 Ok(true)
143 } else {
144 Ok(false)
145 }
146 }
147}
148
/// Values that can be tested against a regex pattern.
pub trait Matcheable: Sized {
    /// Reports whether `pat` matches within `range` of `self`.
    ///
    /// # Errors
    ///
    /// The error type is a boxed [`regex_syntax::Error`]; the implementation
    /// in this file returns it when `pat` cannot be compiled.
    fn matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
156
157impl<S: AsRef<str>> Matcheable for S {
158 fn matches(
159 &self,
160 pat: impl RegexPattern,
161 range: impl RangeBounds<usize> + Clone,
162 ) -> Result<bool, Box<regex_syntax::Error>> {
163 let s = self.as_ref();
164 let (start, end) = crate::get_ends(range, s.len());
165 let dfas = dfas_from_pat(pat)?;
166 let fwd_input =
167 Input::new(unsafe { std::str::from_utf8_unchecked(&s.as_bytes()[start..end]) });
168
169 let mut fwd_cache = dfas.fwd.1.write();
170 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
171 Ok(true)
172 } else {
173 Ok(false)
174 }
175 }
176}
177
/// A compiled pattern bundled with exclusive access to its search caches.
///
/// The two cache fields are write guards, so the locks they come from stay
/// write-locked for as long as this value exists.
pub struct Searcher {
    /// Source text of the pattern this searcher was built from.
    pat: String,
    /// DFA used for forward scans.
    fwd_dfa: &'static DFA,
    /// DFA used for reverse scans.
    rev_dfa: &'static DFA,
    /// Write guard over the cache belonging to `fwd_dfa`.
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    /// Write guard over the cache belonging to `rev_dfa`.
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
185
186impl Searcher {
187 pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
188 let dfas = dfas_from_pat(&pat)?;
189 Ok(Self {
190 pat,
191 fwd_dfa: &dfas.fwd.0,
192 rev_dfa: &dfas.rev.0,
193 fwd_cache: dfas.fwd.1.write(),
194 rev_cache: dfas.rev.1.write(),
195 })
196 }
197
198 pub fn search_fwd<'b>(
199 &'b mut self,
200 text: &'b mut Text,
201 range: impl TextRange,
202 ) -> impl Iterator<Item = (Point, Point)> + 'b {
203 let range = range.to_range_fwd(text.len().byte());
204 let haystack = unsafe {
205 text.make_contiguous_in(range.clone());
206 text.continuous_in_unchecked(range.clone())
207 };
208 let mut fwd_input = Input::new(haystack).anchored(Anchored::No);
209 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
210 let mut last_point = text.point_at(range.start);
211
212 let fwd_dfa = &self.fwd_dfa;
213 let rev_dfa = &self.rev_dfa;
214 let fwd_cache = &mut self.fwd_cache;
215 let rev_cache = &mut self.rev_cache;
216 let gap = range.start;
217 std::iter::from_fn(move || {
218 let init = fwd_input.start();
219 let end = loop {
220 if let Ok(Some(half)) = fwd_dfa.try_search_fwd(fwd_cache, &fwd_input) {
221 if half.offset() == init {
223 fwd_input.set_start(init + 1);
224 } else {
225 break half.offset();
226 }
227 } else {
228 return None;
229 }
230 };
231
232 fwd_input.set_start(end);
233 rev_input.set_end(end);
234
235 let half = unsafe {
236 rev_dfa
237 .try_search_rev(rev_cache, &rev_input)
238 .unwrap()
239 .unwrap_unchecked()
240 };
241 let start = half.offset();
242
243 let start = unsafe {
244 std::str::from_utf8_unchecked(&haystack.as_bytes()[last_point.byte() - gap..start])
245 }
246 .chars()
247 .fold(last_point, |p, b| p.fwd(b));
248 let end = unsafe {
249 std::str::from_utf8_unchecked(&haystack.as_bytes()[start.byte() - gap..end])
250 }
251 .chars()
252 .fold(start, |p, b| p.fwd(b));
253
254 last_point = end;
255
256 Some((start, end))
257 })
258 }
259
260 pub fn search_rev<'b>(
261 &'b mut self,
262 text: &'b mut Text,
263 range: impl TextRange,
264 ) -> impl Iterator<Item = (Point, Point)> + 'b {
265 let range = range.to_range_rev(text.len().byte());
266 let haystack = unsafe {
267 text.make_contiguous_in(range.clone());
268 text.continuous_in_unchecked(range.clone())
269 };
270 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
271 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
272 let mut last_point = text.point_at(range.end);
273
274 let fwd_dfa = &self.fwd_dfa;
275 let rev_dfa = &self.rev_dfa;
276 let fwd_cache = &mut self.fwd_cache;
277 let rev_cache = &mut self.rev_cache;
278 let gap = range.start;
279 std::iter::from_fn(move || {
280 let init = rev_input.end();
281 let start = loop {
282 if let Ok(Some(half)) = rev_dfa.try_search_rev(rev_cache, &rev_input) {
283 if half.offset() == init {
285 rev_input.set_end(init - 1);
286 } else {
287 break half.offset();
288 }
289 } else {
290 return None;
291 }
292 };
293
294 fwd_input.set_start(start);
295 rev_input.set_end(start);
296
297 let half = fwd_dfa
298 .try_search_fwd(fwd_cache, &fwd_input)
299 .unwrap()
300 .unwrap();
301
302 let end = unsafe {
303 std::str::from_utf8_unchecked(
304 &haystack.as_bytes()[half.offset()..(last_point.byte() - gap)],
305 )
306 }
307 .chars()
308 .fold(last_point, |p, b| p.rev(b));
309 let start = unsafe {
310 std::str::from_utf8_unchecked(&haystack.as_bytes()[start..(end.byte() - gap)])
311 }
312 .chars()
313 .fold(end, |p, b| p.rev(b));
314
315 last_point = start;
316
317 Some((start, end))
318 })
319 }
320
321 pub fn matches(&mut self, query: impl AsRef<[u8]>) -> bool {
323 let input = Input::new(&query).anchored(Anchored::Yes);
324
325 let Ok(Some(half)) = self.fwd_dfa.try_search_fwd(&mut self.fwd_cache, &input) else {
326 return false;
327 };
328
329 half.offset() == query.as_ref().len()
330 }
331
332 pub fn is_empty(&self) -> bool {
334 self.pat.is_empty()
335 }
336}
337
/// The forward and reverse DFAs built for one pattern (or pattern list),
/// each paired with its lock-protected search [`Cache`].
struct DFAs {
    /// Forward DFA and its cache.
    fwd: (DFA, RwLock<Cache>),
    /// Reverse DFA and its cache.
    rev: (DFA, RwLock<Cache>),
}
342
343fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
344 static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
345 LazyLock::new(RwLock::default);
346
347 let mut list = DFA_LIST.write();
348
349 let mut bytes = [0; 4];
350 let pat = pat.as_patterns(&mut bytes);
351
352 if let Some(dfas) = list.get(&pat) {
353 Ok(*dfas)
354 } else {
355 let pat = pat.leak();
356 let (fwd, rev) = pat.dfas()?;
357
358 let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
359 let dfas = Box::leak(Box::new(DFAs {
360 fwd: (fwd, RwLock::new(fwd_cache)),
361 rev: (rev, RwLock::new(rev_cache)),
362 }));
363 let _ = list.insert(pat, dfas);
364 Ok(dfas)
365 }
366}
367
/// Borrowed form of a search pattern, used as the key of the DFA cache.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single pattern string.
    One(&'a str),
    /// A list of pattern strings that are compiled together.
    Many(&'a [&'static str]),
}
373
374impl Patterns<'_> {
375 fn leak(&self) -> Patterns<'static> {
376 match self {
377 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
378 Patterns::Many(strs) => Patterns::Many(Vec::from(*strs).leak()),
379 }
380 }
381
382 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
383 let mut fwd_builder = DFA::builder();
384 fwd_builder.thompson(Config::new().utf8(false));
385 let mut rev_builder = DFA::builder();
386 rev_builder.thompson(Config::new().reverse(true).utf8(false));
387
388 match self {
389 Patterns::One(pat) => {
390 regex_syntax::Parser::new().parse(pat)?;
391 let fwd = fwd_builder.build(pat).unwrap();
392 let rev = rev_builder.build(pat).unwrap();
393 Ok((fwd, rev))
394 }
395 Patterns::Many(pats) => {
396 for pat in *pats {
397 regex_syntax::Parser::new().parse(pat)?;
398 }
399 let fwd = fwd_builder.build_many(pats).unwrap();
400 let rev = rev_builder.build_many(pats).unwrap();
401 Ok((fwd, rev))
402 }
403 }
404 }
405}
406
/// Pattern types accepted by the search functions in this file.
pub trait RegexPattern: InnerRegexPattern {
    /// The value yielded for each match found with this kind of pattern.
    type Match: 'static;

    /// Builds a [`Self::Match`] from the `(start, end)` points of a match and
    /// the id of the pattern that matched.
    fn get_match(points: (Point, Point), pattern: PatternID) -> Self::Match;
}
412
413impl RegexPattern for &str {
414 type Match = (Point, Point);
415
416 fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
417 points
418 }
419}
420
421impl RegexPattern for String {
422 type Match = (Point, Point);
423
424 fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
425 points
426 }
427}
428
429impl RegexPattern for &String {
430 type Match = (Point, Point);
431
432 fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
433 points
434 }
435}
436
437impl RegexPattern for char {
438 type Match = (Point, Point);
439
440 fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
441 points
442 }
443}
444
445impl<const N: usize> RegexPattern for [&'static str; N] {
446 type Match = (Point, Point, usize);
447
448 fn get_match(points: (Point, Point), pattern: PatternID) -> Self::Match {
449 (points.0, points.1, pattern.as_usize())
450 }
451}
452
453impl RegexPattern for &[&'static str] {
454 type Match = (Point, Point, usize);
455
456 fn get_match(points: (Point, Point), pattern: PatternID) -> Self::Match {
457 (points.0, points.1, pattern.as_usize())
458 }
459}
460
/// Private half of [`RegexPattern`]: conversion into the borrowed
/// [`Patterns`] form.
trait InnerRegexPattern {
    /// Borrows `self` as [`Patterns`].
    ///
    /// `bytes` is scratch space the result may borrow from; the `char`
    /// implementation encodes itself into it, the others ignore it.
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
464
465impl InnerRegexPattern for &str {
466 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
467 Patterns::One(self)
468 }
469}
470
471impl InnerRegexPattern for String {
472 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
473 Patterns::One(self)
474 }
475}
476
477impl InnerRegexPattern for &String {
478 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
479 Patterns::One(self)
480 }
481}
482
483impl InnerRegexPattern for char {
484 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
485 Patterns::One(self.encode_utf8(bytes) as &str)
486 }
487}
488
489impl<const N: usize> InnerRegexPattern for [&'static str; N] {
490 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
491 Patterns::Many(self)
492 }
493}
494
495impl InnerRegexPattern for &[&'static str] {
496 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
497 Patterns::Many(self)
498 }
499}