1use std::{
14 collections::HashMap,
15 marker::PhantomData,
16 ops::Range,
17 sync::{LazyLock, Mutex},
18};
19
20use regex_cursor::{
21 Cursor, Input,
22 engines::hybrid::{try_search_fwd, try_search_rev},
23 regex_automata::{
24 PatternID,
25 hybrid::dfa::{Cache, DFA},
26 nfa::thompson,
27 util::syntax,
28 },
29};
30
31use super::TextRange;
32use crate::text::Strs;
33
34#[derive(Clone)]
44pub struct Matches<'m, R> {
45 haystack: [&'m [u8]; 2],
46 b_start: usize,
47 dfas: &'static DFAs,
48 fwd_input: Input<SearchBytes<'m>>,
49 rev_input: Input<SearchBytes<'m>>,
50 rev_match: Option<Range<usize>>,
51 _ghost: PhantomData<R>,
52}
53
54impl<'m, R> Matches<'m, R> {
55 #[track_caller]
69 pub fn range(self, range: impl TextRange) -> Self {
70 let [s0, s1] = self.haystack;
71 let range = range.to_range(s0.len() + s1.len());
72 let i0 = &s0[range.start.min(s0.len())..range.end.min(s0.len())];
73 let i1 = &s1[range.start.saturating_sub(s0.len())..range.end.saturating_sub(s0.len())];
74
75 Self {
76 fwd_input: Input::new(SearchBytes([i0, i1], 0)),
77 rev_input: Input::new(SearchBytes([i0, i1], 0)),
78 b_start: range.start,
79 ..self
80 }
81 }
82}
83
84impl<'m, R: RegexPattern> Iterator for Matches<'m, R> {
85 type Item = R::Match;
86
87 fn next(&mut self) -> Option<Self::Item> {
88 let mut fwd_cache = match self.dfas.fwd_cache.lock() {
89 Ok(cache) => cache,
90 Err(err) => err.into_inner(),
91 };
92
93 let (fwd_input, rev_input) = (&mut self.fwd_input, &mut self.rev_input);
94 let h_end = try_search_fwd(&self.dfas.fwd_dfa, &mut fwd_cache, fwd_input)
95 .ok()
96 .flatten()?
97 .offset();
98
99 let mut rev_cache = match self.dfas.rev_cache.lock() {
100 Ok(cache) => cache,
101 Err(err) => err.into_inner(),
102 };
103
104 rev_input.set_range(fwd_input.start()..h_end);
105 let half = try_search_rev(&self.dfas.rev_dfa, &mut rev_cache, rev_input)
106 .ok()
107 .flatten()?;
108 let h_start = half.offset();
109
110 fwd_input.set_start(h_end);
111
112 if h_start == h_end {
114 fwd_input.set_start(h_end + 1);
115 }
116
117 Some(R::get_match(
118 self.b_start + h_start..self.b_start + h_end,
119 half.pattern(),
120 ))
121 }
122}
123
124impl<'m, R: RegexPattern> DoubleEndedIterator for Matches<'m, R> {
125 fn next_back(&mut self) -> Option<Self::Item> {
126 let mut rev_cache = match self.dfas.rev_cache.lock() {
127 Ok(cache) => cache,
128 Err(err) => err.into_inner(),
129 };
130
131 let (fwd_input, rev_input) = (&mut self.fwd_input, &mut self.rev_input);
132
133 let h_start = try_search_rev(&self.dfas.rev_dfa, &mut rev_cache, fwd_input)
134 .ok()
135 .flatten()?
136 .offset();
137
138 let mut fwd_cache = match self.dfas.fwd_cache.lock() {
139 Ok(cache) => cache,
140 Err(err) => err.into_inner(),
141 };
142
143 rev_input.set_range(h_start..fwd_input.end());
144 let half = try_search_fwd(&self.dfas.fwd_dfa, &mut fwd_cache, rev_input)
145 .ok()
146 .flatten()?;
147 let h_end = half.offset();
148
149 fwd_input.set_end(h_start);
150
151 if h_start == h_end {
153 if self.rev_match == Some(self.b_start + h_start..self.b_start + h_end) {
154 return None;
155 } else if h_start > 0 {
156 fwd_input.set_end(h_start - 1);
157 }
158 }
159
160 self.rev_match = Some(self.b_start + h_start..self.b_start + h_end);
161
162 Some(R::get_match(
163 self.b_start + h_start..self.b_start + h_end,
164 half.pattern(),
165 ))
166 }
167}
168
169pub trait RegexHaystack<'h> {
175 #[track_caller]
193 fn search<R: RegexPattern>(&'h self, pat: R) -> Matches<'h, R> {
194 match self.try_search(pat) {
195 Ok(matches) => matches,
196 Err(err) => panic!("{err}"),
197 }
198 }
199
200 fn try_search<R: RegexPattern>(
216 &'h self,
217 pat: R,
218 ) -> Result<Matches<'h, R>, Box<regex_syntax::Error>>;
219
220 fn contains_pat(&'h self, pat: impl RegexPattern) -> Result<bool, Box<regex_syntax::Error>> {
228 Ok(self.try_search(pat)?.next().is_some())
229 }
230
231 fn matches_pat(&'h self, pat: impl RegexPattern) -> Result<bool, Box<regex_syntax::Error>> {
236 let mut matches = self.try_search(pat)?;
237 Ok(matches
238 .next()
239 .is_some_and(|_| matches.fwd_input.start() == matches.fwd_input.end()))
240 }
241}
242
243impl<'b> RegexHaystack<'b> for Strs {
244 fn try_search<R: RegexPattern>(
245 &'b self,
246 pat: R,
247 ) -> Result<Matches<'b, R>, Box<regex_syntax::Error>> {
248 let dfas = dfas_from_pat(pat)?;
249
250 let haystack = self.to_array().map(str::as_bytes);
251
252 Ok(Matches {
253 haystack,
254 b_start: 0,
255 dfas,
256 fwd_input: Input::new(SearchBytes(haystack, 0)),
257 rev_input: Input::new(SearchBytes(haystack, 0)),
258 rev_match: None,
259 _ghost: PhantomData,
260 })
261 }
262}
263
264impl<'s, S: std::ops::Deref<Target = str>> RegexHaystack<'s> for S {
265 fn try_search<R: RegexPattern>(
266 &'s self,
267 pat: R,
268 ) -> Result<Matches<'s, R>, Box<regex_syntax::Error>> {
269 let dfas = dfas_from_pat(pat)?;
270
271 let haystack = [self.deref().as_bytes(), &[]];
272
273 Ok(Matches {
274 haystack,
275 b_start: 0,
276 dfas,
277 fwd_input: Input::new(SearchBytes(haystack, 0)),
278 rev_input: Input::new(SearchBytes(haystack, 0)),
279 rev_match: None,
280 _ghost: PhantomData,
281 })
282 }
283}
284
285struct DFAs {
286 fwd_dfa: DFA,
287 fwd_cache: Mutex<Cache>,
288 rev_dfa: DFA,
289 rev_cache: Mutex<Cache>,
290}
291
292fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
293 static DFA_LIST: LazyLock<Mutex<HashMap<Patterns<'static>, &'static DFAs>>> =
294 LazyLock::new(Mutex::default);
295
296 let mut list = DFA_LIST.lock().unwrap();
297
298 let mut bytes = [0; 4];
299 let pat = pat.as_patterns(&mut bytes);
300
301 if let Some(dfas) = list.get(&pat) {
302 Ok(*dfas)
303 } else {
304 let pat = pat.leak();
305 let (fwd_dfa, rev_dfa) = pat.dfas()?;
306
307 let (fwd_cache, rev_cache) = (Cache::new(&fwd_dfa), Cache::new(&rev_dfa));
308 let dfas = Box::leak(Box::new(DFAs {
309 fwd_dfa,
310 fwd_cache: Mutex::new(fwd_cache),
311 rev_dfa,
312 rev_cache: Mutex::new(rev_cache),
313 }));
314 let _ = list.insert(pat, dfas);
315 Ok(dfas)
316 }
317}
318
319#[derive(Clone, Copy, PartialEq, Eq, Hash)]
320enum Patterns<'a> {
321 One(&'a str),
322 Many(&'a [&'a str]),
323}
324
325impl Patterns<'_> {
326 fn leak(&self) -> Patterns<'static> {
327 match self {
328 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
329 Patterns::Many(strs) => Patterns::Many(
330 strs.iter()
331 .map(|s| {
332 let str: &'static str = s.to_string().leak();
333 str
334 })
335 .collect::<Vec<&'static str>>()
336 .leak(),
337 ),
338 }
339 }
340
341 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
342 let mut fwd_builder = DFA::builder();
343 fwd_builder.syntax(syntax::Config::new().multi_line(true));
344 let mut rev_builder = DFA::builder();
345 rev_builder
346 .syntax(syntax::Config::new().multi_line(true))
347 .thompson(thompson::Config::new().reverse(true));
348
349 match self {
350 Patterns::One(pat) => {
351 let pat = pat.replace("\\b", "(?-u:\\b)");
352 syntax::parse(&pat)?;
353 let fwd = fwd_builder.build(&pat).unwrap();
354 let rev = rev_builder.build(&pat).unwrap();
355 Ok((fwd, rev))
356 }
357 Patterns::Many(pats) => {
358 let pats: Vec<String> =
359 pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
360 for pat in pats.iter() {
361 regex_syntax::Parser::new().parse(pat)?;
362 }
363 let fwd = fwd_builder.build_many(&pats).unwrap();
364 let rev = rev_builder.build_many(&pats).unwrap();
365 Ok((fwd, rev))
366 }
367 }
368 }
369}
370
371pub trait RegexPattern: InnerRegexPattern {
376 type Match;
378
379 fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match;
381}
382
383impl RegexPattern for &str {
384 type Match = Range<usize>;
385
386 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
387 points
388 }
389}
390
391impl RegexPattern for String {
392 type Match = Range<usize>;
393
394 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
395 points
396 }
397}
398
399impl RegexPattern for &String {
400 type Match = Range<usize>;
401
402 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
403 points
404 }
405}
406
407impl RegexPattern for char {
408 type Match = Range<usize>;
409
410 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
411 points
412 }
413}
414
415impl<const N: usize> RegexPattern for [&str; N] {
416 type Match = (usize, Range<usize>);
417
418 fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
419 (pattern.as_usize(), points)
420 }
421}
422
423impl RegexPattern for &[&str] {
424 type Match = (usize, Range<usize>);
425
426 fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
427 (pattern.as_usize(), points)
428 }
429}
430
431trait InnerRegexPattern {
432 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
433}
434
435impl InnerRegexPattern for &str {
436 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
437 Patterns::One(self)
438 }
439}
440
441impl InnerRegexPattern for String {
442 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
443 Patterns::One(self)
444 }
445}
446
447impl InnerRegexPattern for &String {
448 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
449 Patterns::One(self)
450 }
451}
452
453impl InnerRegexPattern for char {
454 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
455 Patterns::One(self.encode_utf8(bytes) as &str)
456 }
457}
458
459impl<const N: usize> InnerRegexPattern for [&str; N] {
460 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
461 Patterns::Many(self)
462 }
463}
464
465impl InnerRegexPattern for &[&str] {
466 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
467 Patterns::Many(self)
468 }
469}
470
471#[derive(Clone, Copy)]
472struct SearchBytes<'a>([&'a [u8]; 2], usize);
473
474impl Cursor for SearchBytes<'_> {
475 fn chunk(&self) -> &[u8] {
476 self.0[self.1]
477 }
478
479 fn advance(&mut self) -> bool {
480 if self.1 == 0 {
481 self.1 += 1;
482 true
483 } else {
484 false
485 }
486 }
487
488 fn backtrack(&mut self) -> bool {
489 if self.1 == 1 {
490 self.1 -= 1;
491 true
492 } else {
493 false
494 }
495 }
496
497 fn total_bytes(&self) -> Option<usize> {
498 Some(self.0[0].len() + self.0[1].len())
499 }
500
501 fn offset(&self) -> usize {
502 match self.1 {
503 1 => self.0[0].len(),
504 _ => 0,
505 }
506 }
507}