1use std::{
14 collections::HashMap,
15 marker::PhantomData,
16 ops::Range,
17 sync::{LazyLock, Mutex},
18};
19
20use regex_cursor::{
21 Cursor, Input,
22 engines::hybrid::{try_search_fwd, try_search_rev},
23 regex_automata::{
24 PatternID,
25 hybrid::dfa::{Cache, DFA},
26 nfa::thompson,
27 util::syntax,
28 },
29};
30
31use super::TextRange;
32use crate::text::{Bytes, Strs, Text};
33
/// An iterator over the matches of a [`RegexPattern`] inside a haystack that
/// is stored as two byte slices.
///
/// Matches are produced front to back through [`Iterator`] and back to front
/// through [`DoubleEndedIterator`]. The item type is `R::Match`.
#[derive(Clone)]
pub struct Matches<'m, R> {
    /// The two slices this searcher was created over. [`Matches::range`]
    /// re-slices from these, so they are kept untouched.
    haystack: [&'m [u8]; 2],
    /// Offset added to both ends of every match before it is reported.
    b_start: usize,
    /// The compiled forward/reverse DFAs (and their caches) for the pattern.
    dfas: &'static DFAs,
    /// Input whose span tracks the part of the haystack not yet yielded.
    fwd_input: Input<SearchBytes<'m>>,
    /// Scratch input, re-ranged on every call to locate the other end of a
    /// match once one end is known.
    rev_input: Input<SearchBytes<'m>>,
    /// Ties the iterator to the pattern type `R` without storing a value.
    _ghost: PhantomData<R>,
}
52
53impl<'m, R> Matches<'m, R> {
54 pub fn range(self, range: impl TextRange) -> Self {
60 let [s0, s1] = self.haystack;
61 let range = range.to_range(s0.len() + s1.len());
62 let i0 = &s0[range.start.min(s0.len())..range.end.min(s0.len())];
63 let i1 = &s1[range.start.saturating_sub(s0.len())..range.end.saturating_sub(s0.len())];
64
65 Self {
66 fwd_input: Input::new(SearchBytes([i0, i1], 0)),
67 rev_input: Input::new(SearchBytes([i0, i1], 0)),
68 b_start: range.start,
69 ..self
70 }
71 }
72}
73
74impl<'m, R: RegexPattern> Iterator for Matches<'m, R> {
75 type Item = R::Match;
76
77 fn next(&mut self) -> Option<Self::Item> {
78 let mut fwd_cache = match self.dfas.fwd_cache.lock() {
79 Ok(cache) => cache,
80 Err(err) => err.into_inner(),
81 };
82
83 let (fwd_input, rev_input) = (&mut self.fwd_input, &mut self.rev_input);
84 let h_end = try_search_fwd(&self.dfas.fwd_dfa, &mut fwd_cache, fwd_input)
85 .ok()
86 .flatten()?
87 .offset();
88
89 let mut rev_cache = match self.dfas.rev_cache.lock() {
90 Ok(cache) => cache,
91 Err(err) => err.into_inner(),
92 };
93
94 rev_input.set_range(fwd_input.start()..h_end);
95 let half = try_search_rev(&self.dfas.rev_dfa, &mut rev_cache, rev_input)
96 .ok()
97 .flatten()?;
98 let h_start = half.offset();
99
100 fwd_input.set_start(h_end);
101
102 if h_start == h_end {
104 fwd_input.set_start(h_end + 1);
105 }
106
107 Some(R::get_match(
108 self.b_start + h_start..self.b_start + h_end,
109 half.pattern(),
110 ))
111 }
112}
113
114impl<'m, R: RegexPattern> DoubleEndedIterator for Matches<'m, R> {
115 fn next_back(&mut self) -> Option<Self::Item> {
116 let mut rev_cache = match self.dfas.rev_cache.lock() {
117 Ok(cache) => cache,
118 Err(err) => err.into_inner(),
119 };
120
121 let (fwd_input, rev_input) = (&mut self.fwd_input, &mut self.rev_input);
122 let h_start = try_search_rev(&self.dfas.rev_dfa, &mut rev_cache, fwd_input)
123 .ok()
124 .flatten()?
125 .offset();
126
127 let mut fwd_cache = match self.dfas.fwd_cache.lock() {
128 Ok(cache) => cache,
129 Err(err) => err.into_inner(),
130 };
131
132 rev_input.set_range(h_start..fwd_input.end());
133 let half = try_search_fwd(&self.dfas.fwd_dfa, &mut fwd_cache, rev_input)
134 .ok()
135 .flatten()?;
136 let h_end = half.offset();
137
138 fwd_input.set_end(h_start);
139
140 if h_start == h_end {
141 fwd_input.set_end(h_start.checked_sub(1)?);
142 }
143
144 Some(R::get_match(
145 self.b_start + h_start..self.b_start + h_end,
146 half.pattern(),
147 ))
148 }
149}
150
151pub trait RegexHaystack<'h> {
157 #[track_caller]
175 fn search<R: RegexPattern>(&'h self, pat: R) -> Matches<'h, R> {
176 match self.try_search(pat) {
177 Ok(matches) => matches,
178 Err(err) => panic!("{err}"),
179 }
180 }
181
182 fn try_search<R: RegexPattern>(
198 &'h self,
199 pat: R,
200 ) -> Result<Matches<'h, R>, Box<regex_syntax::Error>>;
201
202 fn contains_pat(&'h self, pat: impl RegexPattern) -> Result<bool, Box<regex_syntax::Error>> {
210 Ok(self.try_search(pat)?.next().is_some())
211 }
212
213 fn matches_pat(&'h self, pat: impl RegexPattern) -> Result<bool, Box<regex_syntax::Error>> {
218 let mut matches = self.try_search(pat)?;
219 Ok(matches
220 .next()
221 .is_some_and(|_| matches.fwd_input.start() == matches.fwd_input.end()))
222 }
223}
224
225impl<'b> RegexHaystack<'b> for Text {
226 fn try_search<R: RegexPattern>(
227 &'b self,
228 pat: R,
229 ) -> Result<Matches<'b, R>, Box<regex_syntax::Error>> {
230 let dfas = dfas_from_pat(pat)?;
231
232 let haystack = self.slices(..).to_array();
233
234 Ok(Matches {
235 haystack,
236 b_start: 0,
237 dfas,
238 fwd_input: Input::new(SearchBytes(haystack, 0)),
239 rev_input: Input::new(SearchBytes(haystack, 0)),
240 _ghost: PhantomData,
241 })
242 }
243}
244
245impl<'b> RegexHaystack<'b> for Bytes {
246 fn try_search<R: RegexPattern>(
247 &'b self,
248 pat: R,
249 ) -> Result<Matches<'b, R>, Box<regex_syntax::Error>> {
250 let dfas = dfas_from_pat(pat)?;
251
252 let haystack = self.slices(..).to_array();
253
254 Ok(Matches {
255 haystack,
256 b_start: 0,
257 dfas,
258 fwd_input: Input::new(SearchBytes(haystack, 0)),
259 rev_input: Input::new(SearchBytes(haystack, 0)),
260 _ghost: PhantomData,
261 })
262 }
263}
264
265impl<'b> RegexHaystack<'b> for Strs<'b> {
266 fn try_search<R: RegexPattern>(
267 &'b self,
268 pat: R,
269 ) -> Result<Matches<'b, R>, Box<regex_syntax::Error>> {
270 let dfas = dfas_from_pat(pat)?;
271
272 let haystack = self.slices(self.byte_range()).to_array();
273
274 Ok(Matches {
275 haystack,
276 b_start: self.byte_range().start,
277 dfas,
278 fwd_input: Input::new(SearchBytes(haystack, 0)),
279 rev_input: Input::new(SearchBytes(haystack, 0)),
280 _ghost: PhantomData,
281 })
282 }
283}
284
285impl<'s, S: std::ops::Deref<Target = str>> RegexHaystack<'s> for S {
286 fn try_search<R: RegexPattern>(
287 &'s self,
288 pat: R,
289 ) -> Result<Matches<'s, R>, Box<regex_syntax::Error>> {
290 let dfas = dfas_from_pat(pat)?;
291
292 let haystack = [self.deref().as_bytes(), &[]];
293
294 Ok(Matches {
295 haystack,
296 b_start: 0,
297 dfas,
298 fwd_input: Input::new(SearchBytes(haystack, 0)),
299 rev_input: Input::new(SearchBytes(haystack, 0)),
300 _ghost: PhantomData,
301 })
302 }
303}
304
/// The compiled automata for one pattern (or pattern set).
///
/// Holds a forward and a reverse hybrid DFA, each paired with its own
/// mutex-guarded [`Cache`]. Instances are leaked by `dfas_from_pat` and handed
/// out as `&'static DFAs`.
struct DFAs {
    /// DFA used for forward searches.
    fwd_dfa: DFA,
    /// Cache that must accompany `fwd_dfa` in every search.
    fwd_cache: Mutex<Cache>,
    /// DFA built with a reversed NFA, used for backward searches.
    rev_dfa: DFA,
    /// Cache that must accompany `rev_dfa` in every search.
    rev_cache: Mutex<Cache>,
}
311
312fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
313 static DFA_LIST: LazyLock<Mutex<HashMap<Patterns<'static>, &'static DFAs>>> =
314 LazyLock::new(Mutex::default);
315
316 let mut list = DFA_LIST.lock().unwrap();
317
318 let mut bytes = [0; 4];
319 let pat = pat.as_patterns(&mut bytes);
320
321 if let Some(dfas) = list.get(&pat) {
322 Ok(*dfas)
323 } else {
324 let pat = pat.leak();
325 let (fwd_dfa, rev_dfa) = pat.dfas()?;
326
327 let (fwd_cache, rev_cache) = (Cache::new(&fwd_dfa), Cache::new(&rev_dfa));
328 let dfas = Box::leak(Box::new(DFAs {
329 fwd_dfa,
330 fwd_cache: Mutex::new(fwd_cache),
331 rev_dfa,
332 rev_cache: Mutex::new(rev_cache),
333 }));
334 let _ = list.insert(pat, dfas);
335 Ok(dfas)
336 }
337}
338
/// The textual form of a [`RegexPattern`]: one regex, or a list of them.
///
/// Used as the key of the compiled-DFA map in `dfas_from_pat`, hence the
/// `Eq` and `Hash` derives.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single regex.
    One(&'a str),
    /// Several regexes searched together; matches report which one was hit.
    Many(&'a [&'a str]),
}
344
345impl Patterns<'_> {
346 fn leak(&self) -> Patterns<'static> {
347 match self {
348 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
349 Patterns::Many(strs) => Patterns::Many(
350 strs.iter()
351 .map(|s| {
352 let str: &'static str = s.to_string().leak();
353 str
354 })
355 .collect::<Vec<&'static str>>()
356 .leak(),
357 ),
358 }
359 }
360
361 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
362 let mut fwd_builder = DFA::builder();
363 fwd_builder.syntax(syntax::Config::new().multi_line(true));
364 let mut rev_builder = DFA::builder();
365 rev_builder
366 .syntax(syntax::Config::new().multi_line(true))
367 .thompson(thompson::Config::new().reverse(true));
368
369 match self {
370 Patterns::One(pat) => {
371 let pat = pat.replace("\\b", "(?-u:\\b)");
372 syntax::parse(&pat)?;
373 let fwd = fwd_builder.build(&pat).unwrap();
374 let rev = rev_builder.build(&pat).unwrap();
375 Ok((fwd, rev))
376 }
377 Patterns::Many(pats) => {
378 let pats: Vec<String> =
379 pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
380 for pat in pats.iter() {
381 regex_syntax::Parser::new().parse(pat)?;
382 }
383 let fwd = fwd_builder.build_many(&pats).unwrap();
384 let rev = rev_builder.build_many(&pats).unwrap();
385 Ok((fwd, rev))
386 }
387 }
388 }
389}
390
/// A pattern that can be used to search a [`RegexHaystack`].
///
/// Its supertrait, `InnerRegexPattern`, is not public.
pub trait RegexPattern: InnerRegexPattern {
    /// What the search iterator yields for each match of this pattern type.
    type Match;

    /// Builds a [`RegexPattern::Match`] from the matched range and the id of
    /// the pattern that matched.
    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match;
}
402
403impl RegexPattern for &str {
404 type Match = Range<usize>;
405
406 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
407 points
408 }
409}
410
411impl RegexPattern for String {
412 type Match = Range<usize>;
413
414 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
415 points
416 }
417}
418
419impl RegexPattern for &String {
420 type Match = Range<usize>;
421
422 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
423 points
424 }
425}
426
427impl RegexPattern for char {
428 type Match = Range<usize>;
429
430 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
431 points
432 }
433}
434
435impl<const N: usize> RegexPattern for [&str; N] {
436 type Match = (usize, Range<usize>);
437
438 fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
439 (pattern.as_usize(), points)
440 }
441}
442
443impl RegexPattern for &[&str] {
444 type Match = (usize, Range<usize>);
445
446 fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
447 (pattern.as_usize(), points)
448 }
449}
450
/// Private half of [`RegexPattern`]: exposes the pattern's text as
/// [`Patterns`].
trait InnerRegexPattern {
    /// Returns the pattern text, borrowing from `self` or from `bytes`.
    ///
    /// `bytes` is caller-provided scratch space; the `char` implementation
    /// encodes itself into it, the other implementations ignore it.
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
454
455impl InnerRegexPattern for &str {
456 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
457 Patterns::One(self)
458 }
459}
460
461impl InnerRegexPattern for String {
462 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
463 Patterns::One(self)
464 }
465}
466
467impl InnerRegexPattern for &String {
468 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
469 Patterns::One(self)
470 }
471}
472
473impl InnerRegexPattern for char {
474 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
475 Patterns::One(self.encode_utf8(bytes) as &str)
476 }
477}
478
479impl<const N: usize> InnerRegexPattern for [&str; N] {
480 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
481 Patterns::Many(self)
482 }
483}
484
485impl InnerRegexPattern for &[&str] {
486 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
487 Patterns::Many(self)
488 }
489}
490
/// A two-chunk [`Cursor`] over a pair of byte slices.
///
/// Field `0` holds the two chunks; field `1` is the index (0 or 1) of the
/// chunk the cursor is currently on.
#[derive(Clone, Copy)]
struct SearchBytes<'a>([&'a [u8]; 2], usize);
493
494impl Cursor for SearchBytes<'_> {
495 fn chunk(&self) -> &[u8] {
496 self.0[self.1]
497 }
498
499 fn advance(&mut self) -> bool {
500 if self.1 == 0 {
501 self.1 += 1;
502 true
503 } else {
504 false
505 }
506 }
507
508 fn backtrack(&mut self) -> bool {
509 if self.1 == 1 {
510 self.1 -= 1;
511 true
512 } else {
513 false
514 }
515 }
516
517 fn total_bytes(&self) -> Option<usize> {
518 Some(self.0[0].len() + self.0[1].len())
519 }
520
521 fn offset(&self) -> usize {
522 match self.1 {
523 1 => self.0[0].len(),
524 _ => 0,
525 }
526 }
527}