1use std::{
14 collections::HashMap,
15 ops::{Range, RangeBounds},
16 sync::{LazyLock, RwLock, RwLockWriteGuard},
17};
18
19use regex_cursor::{
20 Cursor, Input,
21 engines::hybrid::{try_search_fwd, try_search_rev},
22 regex_automata::{
23 Anchored, PatternID,
24 hybrid::dfa::{Cache, DFA},
25 nfa::thompson,
26 util::syntax,
27 },
28};
29
30use super::{Bytes, TextRange};
31
32impl Bytes {
33 pub fn search_fwd<R: RegexPattern>(
45 &self,
46 pat: R,
47 range: impl TextRange,
48 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
49 let range = range.to_range(self.len().byte());
50 let dfas = dfas_from_pat(pat)?;
51
52 let b_start = self.point_at_byte(range.start).byte();
53
54 let (mut fwd_input, mut rev_input) = get_inputs(self, range.clone());
55 rev_input.anchored(Anchored::Yes);
56
57 let mut fwd_cache = dfas.fwd.1.write().unwrap();
58 let mut rev_cache = dfas.rev.1.write().unwrap();
59
60 Ok(std::iter::from_fn(move || {
61 let init = fwd_input.start();
62 let h_end = loop {
63 if let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input)
64 {
65 if half.offset() == init {
67 fwd_input.set_start(init + 1);
68 } else {
69 break half.offset();
70 }
71 } else {
72 return None;
73 }
74 };
75
76 fwd_input.set_start(h_end);
77 rev_input.set_end(h_end);
78
79 let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input) else {
80 return None;
81 };
82 let h_start = half.offset();
83
84 Some(R::get_match(
85 b_start + h_start..b_start + h_end,
86 half.pattern(),
87 ))
88 }))
89 }
90
91 pub fn search_rev<R: RegexPattern>(
103 &self,
104 pat: R,
105 range: impl TextRange,
106 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
107 let range = range.to_range(self.len().byte());
108 let dfas = dfas_from_pat(pat)?;
109
110 let (mut fwd_input, mut rev_input) = get_inputs(self, range.clone());
111 fwd_input.anchored(Anchored::Yes);
112
113 let mut fwd_cache = dfas.fwd.1.write().unwrap();
114 let mut rev_cache = dfas.rev.1.write().unwrap();
115
116 Ok(std::iter::from_fn(move || {
117 let init = rev_input.end();
118 let start = loop {
119 if let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input)
120 {
121 if half.offset() == init {
123 rev_input.set_end(init.checked_sub(1)?);
124 } else {
125 break half.offset();
126 }
127 } else {
128 return None;
129 }
130 };
131
132 rev_input.set_end(start);
133 fwd_input.set_start(start);
134
135 let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) else {
136 return None;
137 };
138 let end = half.offset();
139
140 Some(R::get_match(
141 range.start + start..range.start + end,
142 half.pattern(),
143 ))
144 }))
145 }
146
147 pub fn matches(
152 &self,
153 pat: impl RegexPattern,
154 range: impl TextRange,
155 ) -> Result<bool, Box<regex_syntax::Error>> {
156 let range = range.to_range(self.len().byte());
157 let dfas = dfas_from_pat(pat)?;
158
159 let (mut fwd_input, _) = get_inputs(self, range.clone());
160 fwd_input.anchored(Anchored::Yes);
161
162 let mut fwd_cache = dfas.fwd.1.write().unwrap();
163 if let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) {
164 Ok(hm.offset() == range.end)
165 } else {
166 Ok(false)
167 }
168 }
169}
170
/// Regex searching over values that can be viewed as text.
pub trait Matcheable: Sized {
    /// Searches for `pat` within the byte `range` of the text, from its
    /// start towards its end.
    ///
    /// Each item pairs the byte range of a match with the matched text.
    /// Fails with a [`regex_syntax::Error`] if `pat` can't be parsed.
    fn search_fwd(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>>;

    /// Searches for `pat` within the byte `range` of the text, from its
    /// end towards its start.
    ///
    /// Each item pairs the byte range of a match with the matched text.
    /// Fails with a [`regex_syntax::Error`] if `pat` can't be parsed.
    fn search_rev(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>>;

    /// Checks whether `pat` matches the byte `range` of the text.
    ///
    /// Fails with a [`regex_syntax::Error`] if `pat` can't be parsed.
    fn reg_matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
194
195impl<S: AsRef<str>> Matcheable for S {
196 fn search_fwd(
197 &self,
198 pat: impl RegexPattern,
199 range: impl RangeBounds<usize> + Clone,
200 ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>> {
201 let (start, end) = crate::utils::get_ends(range, self.as_ref().len());
202 let str = &self.as_ref()[start..end];
203 let dfas = dfas_from_pat(pat)?;
204
205 let mut fwd_input = Input::new(str);
206 let mut rev_input = Input::new(str);
207 rev_input.anchored(Anchored::Yes);
208
209 let mut fwd_cache = dfas.fwd.1.write().unwrap();
210 let mut rev_cache = dfas.rev.1.write().unwrap();
211
212 Ok(std::iter::from_fn(move || {
213 let init = fwd_input.start();
214 let h_end = loop {
215 if let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input)
216 {
217 if half.offset() == init {
219 fwd_input.set_start(init + 1);
220 } else {
221 break half.offset();
222 }
223 } else {
224 return None;
225 }
226 };
227
228 fwd_input.set_start(h_end);
229 rev_input.set_end(h_end);
230
231 let Ok(Some(hm)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input) else {
232 return None;
233 };
234 let h_start = hm.offset();
235
236 Some((start + h_start..start + h_end, &str[h_start..h_end]))
237 }))
238 }
239
240 fn search_rev(
241 &self,
242 pat: impl RegexPattern,
243 range: impl RangeBounds<usize> + Clone,
244 ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>> {
245 let (start, end) = crate::utils::get_ends(range, self.as_ref().len());
246 let str = &self.as_ref()[start..end];
247 let dfas = dfas_from_pat(pat)?;
248
249 let mut fwd_input = Input::new(str);
250 fwd_input.anchored(Anchored::Yes);
251 let mut rev_input = Input::new(str);
252
253 let mut fwd_cache = dfas.fwd.1.write().unwrap();
254 let mut rev_cache = dfas.rev.1.write().unwrap();
255
256 Ok(std::iter::from_fn(move || {
257 let init = rev_input.end();
258 let h_start = loop {
259 if let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input)
260 {
261 if half.offset() == init {
263 rev_input.set_end(init.checked_sub(1)?);
264 } else {
265 break half.offset();
266 }
267 } else {
268 return None;
269 }
270 };
271
272 rev_input.set_end(h_start);
273 fwd_input.set_start(h_start);
274
275 let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) else {
276 return None;
277 };
278 let h_end = hm.offset();
279
280 Some((start + h_start..start + h_end, &str[h_start..h_end]))
281 }))
282 }
283
284 fn reg_matches(
285 &self,
286 pat: impl RegexPattern,
287 range: impl RangeBounds<usize> + Clone,
288 ) -> Result<bool, Box<regex_syntax::Error>> {
289 let (start, end) = crate::utils::get_ends(range, self.as_ref().len());
290 let str = &self.as_ref()[start..end];
291 let dfas = dfas_from_pat(pat)?;
292
293 let mut fwd_input = Input::new(str);
294 fwd_input.anchored(Anchored::Yes);
295
296 let mut fwd_cache = dfas.fwd.1.write().unwrap();
297 if let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) {
298 Ok(start + hm.offset() == end)
299 } else {
300 Ok(false)
301 }
302 }
303}
304
/// A reusable searcher for a single regex pattern.
///
/// It stores the pattern text, references to the forward and reverse
/// [`DFA`]s, and write guards on their [`Cache`]s, which stay held for as
/// long as the `Searcher` is alive.
pub struct Searcher {
    // The pattern text this searcher was created from.
    pat: String,
    // DFA used to find where matches end.
    fwd_dfa: &'static DFA,
    // DFA used to find where matches start.
    rev_dfa: &'static DFA,
    // Exclusive access to the cache of the forward DFA.
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    // Exclusive access to the cache of the reverse DFA.
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
315
316impl Searcher {
317 pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
319 let dfas = dfas_from_pat(&pat)?;
320 Ok(Self {
321 pat,
322 fwd_dfa: &dfas.fwd.0,
323 rev_dfa: &dfas.rev.0,
324 fwd_cache: dfas.fwd.1.write().unwrap(),
325 rev_cache: dfas.rev.1.write().unwrap(),
326 })
327 }
328
329 pub fn search_fwd<'b>(
333 &'b mut self,
334 ref_bytes: &'b impl AsRef<Bytes>,
335 range: impl TextRange,
336 ) -> impl Iterator<Item = Range<usize>> + 'b {
337 let bytes = ref_bytes.as_ref();
338 let range = range.to_range(bytes.len().byte());
339
340 let (mut fwd_input, mut rev_input) = get_inputs(bytes, range.clone());
341 rev_input.set_anchored(Anchored::Yes);
342
343 let fwd_dfa = &self.fwd_dfa;
344 let rev_dfa = &self.rev_dfa;
345 let fwd_cache = &mut self.fwd_cache;
346 let rev_cache = &mut self.rev_cache;
347
348 std::iter::from_fn(move || {
349 let init = fwd_input.start();
350 let h_end = loop {
351 if let Ok(Some(half)) = try_search_fwd(fwd_dfa, fwd_cache, &mut fwd_input) {
352 if half.offset() == init {
354 fwd_input.set_start(init + 1);
355 } else {
356 break half.offset();
357 }
358 } else {
359 return None;
360 }
361 };
362
363 fwd_input.set_start(h_end);
364 rev_input.set_end(h_end);
365
366 let h_start = unsafe {
367 try_search_rev(rev_dfa, rev_cache, &mut rev_input)
368 .unwrap()
369 .unwrap_unchecked()
370 .offset()
371 };
372
373 Some(range.start + h_start..range.start + h_end)
374 })
375 }
376
377 pub fn search_rev<'b>(
381 &'b mut self,
382 ref_bytes: &'b impl AsRef<Bytes>,
383 range: impl TextRange,
384 ) -> impl Iterator<Item = Range<usize>> + 'b {
385 let bytes = ref_bytes.as_ref();
386 let range = range.to_range(bytes.len().byte());
387
388 let (mut fwd_input, mut rev_input) = get_inputs(bytes, range.clone());
389 fwd_input.anchored(Anchored::Yes);
390
391 let fwd_dfa = &self.fwd_dfa;
392 let rev_dfa = &self.rev_dfa;
393 let fwd_cache = &mut self.fwd_cache;
394 let rev_cache = &mut self.rev_cache;
395 std::iter::from_fn(move || {
396 let init = rev_input.end();
397 let h_start = loop {
398 if let Ok(Some(half)) = try_search_rev(rev_dfa, rev_cache, &mut rev_input) {
399 if half.offset() == init {
401 rev_input.set_end(init - 1);
402 } else {
403 break half.offset();
404 }
405 } else {
406 return None;
407 }
408 };
409
410 fwd_input.set_start(h_start);
411 rev_input.set_end(h_start);
412
413 let h_end = unsafe {
414 try_search_fwd(fwd_dfa, fwd_cache, &mut fwd_input)
415 .unwrap()
416 .unwrap_unchecked()
417 .offset()
418 };
419
420 Some(range.start + h_start..range.start + h_end)
421 })
422 }
423
424 pub fn matches(&mut self, cursor: impl Cursor) -> bool {
426 let total_bytes = cursor.total_bytes();
427
428 let mut input = Input::new(cursor);
429 input.anchored(Anchored::Yes);
430
431 let Ok(Some(half)) = try_search_fwd(self.fwd_dfa, &mut self.fwd_cache, &mut input) else {
432 return false;
433 };
434
435 total_bytes.is_some_and(|len| len == half.offset())
436 }
437
438 pub fn is_empty(&self) -> bool {
440 self.pat.is_empty()
441 }
442}
443
/// The compiled form of a pattern: a forward and a reverse hybrid [`DFA`],
/// each stored next to the lock-protected [`Cache`] it searches with.
struct DFAs {
    // Forward DFA and its cache.
    fwd: (DFA, RwLock<Cache>),
    // Reverse DFA and its cache.
    rev: (DFA, RwLock<Cache>),
}
448
449fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
450 static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
451 LazyLock::new(RwLock::default);
452
453 let mut list = DFA_LIST.write().unwrap();
454
455 let mut bytes = [0; 4];
456 let pat = pat.as_patterns(&mut bytes);
457
458 if let Some(dfas) = list.get(&pat) {
459 Ok(*dfas)
460 } else {
461 let pat = pat.leak();
462 let (fwd, rev) = pat.dfas()?;
463
464 let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
465 let dfas = Box::leak(Box::new(DFAs {
466 fwd: (fwd, RwLock::new(fwd_cache)),
467 rev: (rev, RwLock::new(rev_cache)),
468 }));
469 let _ = list.insert(pat, dfas);
470 Ok(dfas)
471 }
472}
473
/// Borrowed pattern text, used as the key of the compiled-pattern map.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single pattern.
    One(&'a str),
    /// Several patterns compiled together.
    Many(&'a [&'a str]),
}
479
480impl Patterns<'_> {
481 fn leak(&self) -> Patterns<'static> {
482 match self {
483 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
484 Patterns::Many(strs) => Patterns::Many(
485 strs.iter()
486 .map(|s| {
487 let str: &'static str = s.to_string().leak();
488 str
489 })
490 .collect::<Vec<&'static str>>()
491 .leak(),
492 ),
493 }
494 }
495
496 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
497 let mut fwd_builder = DFA::builder();
498 fwd_builder.syntax(syntax::Config::new().multi_line(true));
499 let mut rev_builder = DFA::builder();
500 rev_builder
501 .syntax(syntax::Config::new().multi_line(true))
502 .thompson(thompson::Config::new().reverse(true));
503
504 match self {
505 Patterns::One(pat) => {
506 let pat = pat.replace("\\b", "(?-u:\\b)");
507 syntax::parse(&pat)?;
508 let fwd = fwd_builder.build(&pat).unwrap();
509 let rev = rev_builder.build(&pat).unwrap();
510 Ok((fwd, rev))
511 }
512 Patterns::Many(pats) => {
513 let pats: Vec<String> =
514 pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
515 for pat in pats.iter() {
516 regex_syntax::Parser::new().parse(pat)?;
517 }
518 let fwd = fwd_builder.build_many(&pats).unwrap();
519 let rev = rev_builder.build_many(&pats).unwrap();
520 Ok((fwd, rev))
521 }
522 }
523 }
524}
525
/// A pattern, or group of patterns, that can be used in regex searches.
pub trait RegexPattern: InnerRegexPattern {
    /// The value that searches with this pattern yield for every match.
    type Match: 'static;

    /// Builds a [`Self::Match`] from the byte range of a match and the id
    /// of the pattern that produced it.
    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match;
}
537
538impl RegexPattern for &str {
539 type Match = Range<usize>;
540
541 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
542 points
543 }
544}
545
546impl RegexPattern for String {
547 type Match = Range<usize>;
548
549 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
550 points
551 }
552}
553
554impl RegexPattern for &String {
555 type Match = Range<usize>;
556
557 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
558 points
559 }
560}
561
562impl RegexPattern for char {
563 type Match = Range<usize>;
564
565 fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
566 points
567 }
568}
569
570impl<const N: usize> RegexPattern for [&str; N] {
571 type Match = (usize, Range<usize>);
572
573 fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
574 (pattern.as_usize(), points)
575 }
576}
577
578impl RegexPattern for &[&str] {
579 type Match = (usize, Range<usize>);
580
581 fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
582 (pattern.as_usize(), points)
583 }
584}
585
/// Private half of [`RegexPattern`]: exposes a pattern as [`Patterns`].
trait InnerRegexPattern {
    /// Views `self` as [`Patterns`].
    ///
    /// `bytes` is scratch space that implementations may write into when
    /// the pattern text has to be produced on the fly (the `char`
    /// implementation encodes itself there).
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
589
590impl InnerRegexPattern for &str {
591 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
592 Patterns::One(self)
593 }
594}
595
596impl InnerRegexPattern for String {
597 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
598 Patterns::One(self)
599 }
600}
601
602impl InnerRegexPattern for &String {
603 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
604 Patterns::One(self)
605 }
606}
607
608impl InnerRegexPattern for char {
609 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
610 Patterns::One(self.encode_utf8(bytes) as &str)
611 }
612}
613
614impl<const N: usize> InnerRegexPattern for [&str; N] {
615 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
616 Patterns::Many(self)
617 }
618}
619
620impl InnerRegexPattern for &[&str] {
621 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
622 Patterns::Many(self)
623 }
624}
625
/// A [`Cursor`] over two byte slices.
///
/// The first field holds the two slices; the second is the index (0 or 1)
/// of the slice currently exposed as the cursor's chunk.
#[derive(Clone, Copy)]
struct SearchBytes<'a>([&'a [u8]; 2], usize);
628
629impl Cursor for SearchBytes<'_> {
630 fn chunk(&self) -> &[u8] {
631 self.0[self.1]
632 }
633
634 fn advance(&mut self) -> bool {
635 if self.1 == 0 {
636 self.1 += 1;
637 true
638 } else {
639 false
640 }
641 }
642
643 fn backtrack(&mut self) -> bool {
644 if self.1 == 1 {
645 self.1 -= 1;
646 true
647 } else {
648 false
649 }
650 }
651
652 fn total_bytes(&self) -> Option<usize> {
653 Some(self.0[0].len() + self.0[1].len())
654 }
655
656 fn offset(&self) -> usize {
657 match self.1 {
658 1 => self.0[0].len(),
659 _ => 0,
660 }
661 }
662}
663
664fn get_inputs(
665 bytes: &Bytes,
666 range: std::ops::Range<usize>,
667) -> (Input<SearchBytes<'_>>, Input<SearchBytes<'_>>) {
668 let haystack = SearchBytes(bytes.slices(range).to_array(), 0);
669 let fwd_input = Input::new(haystack);
670 let rev_input = Input::new(haystack);
671 (fwd_input, rev_input)
672}