1use std::{
13 collections::HashMap,
14 ops::RangeBounds,
15 sync::{LazyLock, RwLock, RwLockWriteGuard},
16};
17
18use regex_cursor::{
19 Cursor, Input,
20 engines::hybrid::{try_search_fwd, try_search_rev},
21 regex_automata::{
22 Anchored, PatternID,
23 hybrid::dfa::{Cache, DFA},
24 nfa::thompson,
25 util::syntax,
26 },
27};
28
29use super::{Bytes, Point, Text, TextRange};
30
31impl Text {
32 pub fn search_fwd<R: RegexPattern>(
44 &self,
45 pat: R,
46 range: impl TextRange,
47 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
48 self.0.bytes.search_fwd(pat, range)
49 }
50
51 pub fn search_rev<R: RegexPattern>(
63 &self,
64 pat: R,
65 range: impl TextRange,
66 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
67 self.0.bytes.search_rev(pat, range)
68 }
69
70 pub fn matches(
75 &self,
76 pat: impl RegexPattern,
77 range: impl TextRange,
78 ) -> Result<bool, Box<regex_syntax::Error>> {
79 let range = range.to_range(self.len().byte());
80 let dfas = dfas_from_pat(pat)?;
81
82 let (mut fwd_input, _) = get_inputs(self, range.clone());
83 fwd_input.anchored(Anchored::Yes);
84
85 let mut fwd_cache = dfas.fwd.1.write().unwrap();
86 if let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) {
87 Ok(hm.offset() + range.start == range.end)
88 } else {
89 Ok(false)
90 }
91 }
92}
93
94impl Bytes {
95 pub fn search_fwd<R: RegexPattern>(
107 &self,
108 pat: R,
109 range: impl TextRange,
110 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
111 let range = range.to_range(self.len().byte());
112 let dfas = dfas_from_pat(pat)?;
113
114 let b_start = self.point_at_byte(range.start).byte();
115
116 let (mut fwd_input, mut rev_input) = get_inputs(self, range.clone());
117 rev_input.anchored(Anchored::Yes);
118
119 let mut fwd_cache = dfas.fwd.1.write().unwrap();
120 let mut rev_cache = dfas.rev.1.write().unwrap();
121
122 Ok(std::iter::from_fn(move || {
123 let init = fwd_input.start();
124 let h_end = loop {
125 if let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input)
126 {
127 if half.offset() == init {
129 fwd_input.set_start(init + 1);
130 } else {
131 break half.offset();
132 }
133 } else {
134 return None;
135 }
136 };
137
138 fwd_input.set_start(h_end);
139 rev_input.set_end(h_end);
140
141 let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input) else {
142 return None;
143 };
144 let h_start = half.offset();
145
146 let p0 = self.point_at_byte(b_start + h_start);
147 let p1 = self.point_at_byte(b_start + h_end);
148
149 Some(R::get_match([p0, p1], half.pattern()))
150 }))
151 }
152
153 pub fn search_rev<R: RegexPattern>(
165 &self,
166 pat: R,
167 range: impl TextRange,
168 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
169 let range = range.to_range(self.len().byte());
170 let dfas = dfas_from_pat(pat)?;
171
172 let (mut fwd_input, mut rev_input) = get_inputs(self, range.clone());
173 fwd_input.anchored(Anchored::Yes);
174
175 let mut fwd_cache = dfas.fwd.1.write().unwrap();
176 let mut rev_cache = dfas.rev.1.write().unwrap();
177
178 Ok(std::iter::from_fn(move || {
179 let init = rev_input.end();
180 let start = loop {
181 if let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input)
182 {
183 if half.offset() == init {
185 rev_input.set_end(init.checked_sub(1)?);
186 } else {
187 break half.offset();
188 }
189 } else {
190 return None;
191 }
192 };
193
194 rev_input.set_end(start);
195 fwd_input.set_start(start);
196
197 let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) else {
198 return None;
199 };
200 let end = half.offset();
201
202 let p0 = self.point_at_byte(range.start + start);
203 let p1 = self.point_at_byte(range.start + end);
204
205 Some(R::get_match([p0, p1], half.pattern()))
206 }))
207 }
208
209 pub fn matches(
214 &self,
215 pat: impl RegexPattern,
216 range: impl TextRange,
217 ) -> Result<bool, Box<regex_syntax::Error>> {
218 let range = range.to_range(self.len().byte());
219 let dfas = dfas_from_pat(pat)?;
220
221 let (mut fwd_input, _) = get_inputs(self, range.clone());
222 fwd_input.anchored(Anchored::Yes);
223
224 let mut fwd_cache = dfas.fwd.1.write().unwrap();
225 if let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) {
226 Ok(hm.offset() == range.end)
227 } else {
228 Ok(false)
229 }
230 }
231}
232
/// Regex searching on string types.
///
/// In this module it is implemented for every type that is `AsRef<str>`.
pub trait Matcheable: Sized {
    /// Returns an iterator over the matches of `pat` within `range`, going
    /// from its start towards its end.
    ///
    /// Each item holds two things:
    /// - the match's `[start, end]` byte range, measured from the start of
    ///   the whole string rather than from the start of `range`;
    /// - the matched text.
    ///
    /// # Errors
    ///
    /// Returns an error if `pat` fails to parse as a regex.
    fn search_fwd(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>>;

    /// Like `search_fwd`, but yields the matches going from the end of
    /// `range` towards its start.
    ///
    /// # Errors
    ///
    /// Returns an error if `pat` fails to parse as a regex.
    fn search_rev(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>>;

    /// Returns `true` if `pat` matches all of `range`.
    ///
    /// The search is anchored at the start of `range`. It succeeds only if
    /// the match it reports ends exactly at the end of `range`.
    ///
    /// # Errors
    ///
    /// Returns an error if `pat` fails to parse as a regex.
    fn reg_matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
256
257impl<S: AsRef<str>> Matcheable for S {
258 fn search_fwd(
259 &self,
260 pat: impl RegexPattern,
261 range: impl RangeBounds<usize> + Clone,
262 ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>> {
263 let (start, end) = crate::get_ends(range, self.as_ref().len());
264 let str = &self.as_ref()[start..end];
265 let dfas = dfas_from_pat(pat)?;
266
267 let mut fwd_input = Input::new(str);
268 let mut rev_input = Input::new(str);
269 rev_input.anchored(Anchored::Yes);
270
271 let mut fwd_cache = dfas.fwd.1.write().unwrap();
272 let mut rev_cache = dfas.rev.1.write().unwrap();
273
274 Ok(std::iter::from_fn(move || {
275 let init = fwd_input.start();
276 let h_end = loop {
277 if let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input)
278 {
279 if half.offset() == init {
281 fwd_input.set_start(init + 1);
282 } else {
283 break half.offset();
284 }
285 } else {
286 return None;
287 }
288 };
289
290 fwd_input.set_start(h_end);
291 rev_input.set_end(h_end);
292
293 let Ok(Some(hm)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input) else {
294 return None;
295 };
296 let h_start = hm.offset();
297
298 Some(([start + h_start, start + h_end], &str[h_start..h_end]))
299 }))
300 }
301
302 fn search_rev(
303 &self,
304 pat: impl RegexPattern,
305 range: impl RangeBounds<usize> + Clone,
306 ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>> {
307 let (start, end) = crate::get_ends(range, self.as_ref().len());
308 let str = &self.as_ref()[start..end];
309 let dfas = dfas_from_pat(pat)?;
310
311 let mut fwd_input = Input::new(str);
312 fwd_input.anchored(Anchored::Yes);
313 let mut rev_input = Input::new(str);
314
315 let mut fwd_cache = dfas.fwd.1.write().unwrap();
316 let mut rev_cache = dfas.rev.1.write().unwrap();
317
318 Ok(std::iter::from_fn(move || {
319 let init = rev_input.end();
320 let h_start = loop {
321 if let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input)
322 {
323 if half.offset() == init {
325 rev_input.set_end(init.checked_sub(1)?);
326 } else {
327 break half.offset();
328 }
329 } else {
330 return None;
331 }
332 };
333
334 rev_input.set_end(h_start);
335 fwd_input.set_start(h_start);
336
337 let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) else {
338 return None;
339 };
340 let h_end = hm.offset();
341
342 Some(([start + h_start, start + h_end], &str[h_start..h_end]))
343 }))
344 }
345
346 fn reg_matches(
347 &self,
348 pat: impl RegexPattern,
349 range: impl RangeBounds<usize> + Clone,
350 ) -> Result<bool, Box<regex_syntax::Error>> {
351 let (start, end) = crate::get_ends(range, self.as_ref().len());
352 let str = &self.as_ref()[start..end];
353 let dfas = dfas_from_pat(pat)?;
354
355 let mut fwd_input = Input::new(str);
356 fwd_input.anchored(Anchored::Yes);
357
358 let mut fwd_cache = dfas.fwd.1.write().unwrap();
359 if let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) {
360 Ok(start + hm.offset() == end)
361 } else {
362 Ok(false)
363 }
364 }
365}
366
/// A searcher for a single regex pattern that keeps exclusive access to the
/// pattern's shared DFA caches for as long as it lives.
///
/// The caches are held through `RwLockWriteGuard`s, so other searches for
/// the same pattern wait until the `Searcher` is dropped. On the same thread,
/// per `RwLock::write`'s documentation, such a search may instead deadlock or
/// panic.
pub struct Searcher {
    /// The pattern this searcher was created from.
    pat: String,
    /// Forward DFA, used to find match ends.
    fwd_dfa: &'static DFA,
    /// Reverse DFA, used to find match starts.
    rev_dfa: &'static DFA,
    /// Exclusive access to the forward DFA's cache.
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    /// Exclusive access to the reverse DFA's cache.
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
377
378impl Searcher {
379 pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
381 let dfas = dfas_from_pat(&pat)?;
382 Ok(Self {
383 pat,
384 fwd_dfa: &dfas.fwd.0,
385 rev_dfa: &dfas.rev.0,
386 fwd_cache: dfas.fwd.1.write().unwrap(),
387 rev_cache: dfas.rev.1.write().unwrap(),
388 })
389 }
390
391 pub fn search_fwd<'b>(
395 &'b mut self,
396 ref_bytes: &'b impl AsRef<Bytes>,
397 range: impl TextRange,
398 ) -> impl Iterator<Item = [Point; 2]> + 'b {
399 let bytes = ref_bytes.as_ref();
400 let range = range.to_range(bytes.len().byte());
401 let mut last_point = bytes.point_at_byte(range.start);
402
403 let (mut fwd_input, mut rev_input) = get_inputs(bytes, range.clone());
404 rev_input.set_anchored(Anchored::Yes);
405
406 let fwd_dfa = &self.fwd_dfa;
407 let rev_dfa = &self.rev_dfa;
408 let fwd_cache = &mut self.fwd_cache;
409 let rev_cache = &mut self.rev_cache;
410
411 std::iter::from_fn(move || {
412 let init = fwd_input.start();
413 let h_end = loop {
414 if let Ok(Some(half)) = try_search_fwd(fwd_dfa, fwd_cache, &mut fwd_input) {
415 if half.offset() == init {
417 fwd_input.set_start(init + 1);
418 } else {
419 break half.offset();
420 }
421 } else {
422 return None;
423 }
424 };
425
426 fwd_input.set_start(h_end);
427 rev_input.set_end(h_end);
428
429 let h_start = unsafe {
430 try_search_rev(rev_dfa, rev_cache, &mut rev_input)
431 .unwrap()
432 .unwrap_unchecked()
433 .offset()
434 };
435
436 let start = unsafe {
440 bytes
441 .buffers(last_point.byte()..range.start + h_start)
442 .chars_unchecked()
443 .fold(last_point, |p, b| p.fwd(b))
444 };
445 let end = unsafe {
446 bytes
447 .buffers(start.byte()..range.start + h_end)
448 .chars_unchecked()
449 .fold(start, |p, b| p.fwd(b))
450 };
451
452 last_point = end;
453
454 Some([start, end])
455 })
456 }
457
458 pub fn search_rev<'b>(
462 &'b mut self,
463 ref_bytes: &'b impl AsRef<Bytes>,
464 range: impl TextRange,
465 ) -> impl Iterator<Item = [Point; 2]> + 'b {
466 let bytes = ref_bytes.as_ref();
467 let range = range.to_range(bytes.len().byte());
468 let mut last_point = bytes.point_at_byte(range.end);
469
470 let (mut fwd_input, mut rev_input) = get_inputs(bytes, range.clone());
471 fwd_input.anchored(Anchored::Yes);
472
473 let fwd_dfa = &self.fwd_dfa;
474 let rev_dfa = &self.rev_dfa;
475 let fwd_cache = &mut self.fwd_cache;
476 let rev_cache = &mut self.rev_cache;
477 std::iter::from_fn(move || {
478 let init = rev_input.end();
479 let h_start = loop {
480 if let Ok(Some(half)) = try_search_rev(rev_dfa, rev_cache, &mut rev_input) {
481 if half.offset() == init {
483 rev_input.set_end(init - 1);
484 } else {
485 break half.offset();
486 }
487 } else {
488 return None;
489 }
490 };
491
492 fwd_input.set_start(h_start);
493 rev_input.set_end(h_start);
494
495 let h_end = unsafe {
496 try_search_fwd(fwd_dfa, fwd_cache, &mut fwd_input)
497 .unwrap()
498 .unwrap_unchecked()
499 .offset()
500 };
501
502 let end = unsafe {
506 bytes
507 .buffers(range.start + h_end..last_point.byte())
508 .chars_unchecked()
509 .fold(last_point, |p, b| p.rev(b))
510 };
511 let start = unsafe {
512 bytes
513 .buffers(range.start + h_start..end.byte())
514 .chars_unchecked()
515 .fold(end, |p, b| p.rev(b))
516 };
517
518 last_point = start;
519
520 Some([start, end])
521 })
522 }
523
524 pub fn matches(&mut self, cursor: impl Cursor) -> bool {
526 let total_bytes = cursor.total_bytes();
527
528 let mut input = Input::new(cursor);
529 input.anchored(Anchored::Yes);
530
531 let Ok(Some(half)) = try_search_fwd(self.fwd_dfa, &mut self.fwd_cache, &mut input) else {
532 return false;
533 };
534
535 total_bytes.is_some_and(|len| len == half.offset())
536 }
537
538 pub fn is_empty(&self) -> bool {
540 self.pat.is_empty()
541 }
542}
543
/// The forward and reverse lazy DFAs compiled for one set of patterns, each
/// paired with the cache it searches with.
///
/// Instances are created and leaked by `dfas_from_pat`, so every search for
/// the same patterns shares these DFAs and caches.
struct DFAs {
    /// Forward DFA (finds match ends) and its cache.
    fwd: (DFA, RwLock<Cache>),
    /// Reverse DFA (built with `thompson::Config::reverse`, finds match
    /// starts) and its cache.
    rev: (DFA, RwLock<Cache>),
}
548
549fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
550 static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
551 LazyLock::new(RwLock::default);
552
553 let mut list = DFA_LIST.write().unwrap();
554
555 let mut bytes = [0; 4];
556 let pat = pat.as_patterns(&mut bytes);
557
558 if let Some(dfas) = list.get(&pat) {
559 Ok(*dfas)
560 } else {
561 let pat = pat.leak();
562 let (fwd, rev) = pat.dfas()?;
563
564 let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
565 let dfas = Box::leak(Box::new(DFAs {
566 fwd: (fwd, RwLock::new(fwd_cache)),
567 rev: (rev, RwLock::new(rev_cache)),
568 }));
569 let _ = list.insert(pat, dfas);
570 Ok(dfas)
571 }
572}
573
/// The key that identifies a compiled regex in the DFA cache. It is either a
/// single pattern or an ordered list of patterns compiled together.
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum Patterns<'p> {
    /// A single pattern.
    One(&'p str),
    /// Several patterns, compiled together in this order.
    Many(&'p [&'p str]),
}
579
580impl Patterns<'_> {
581 fn leak(&self) -> Patterns<'static> {
582 match self {
583 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
584 Patterns::Many(strs) => Patterns::Many(
585 strs.iter()
586 .map(|s| {
587 let str: &'static str = s.to_string().leak();
588 str
589 })
590 .collect::<Vec<&'static str>>()
591 .leak(),
592 ),
593 }
594 }
595
596 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
597 let mut fwd_builder = DFA::builder();
598 fwd_builder.syntax(syntax::Config::new().multi_line(true));
599 let mut rev_builder = DFA::builder();
600 rev_builder
601 .syntax(syntax::Config::new().multi_line(true))
602 .thompson(thompson::Config::new().reverse(true));
603
604 match self {
605 Patterns::One(pat) => {
606 let pat = pat.replace("\\b", "(?-u:\\b)");
607 syntax::parse(&pat)?;
608 let fwd = fwd_builder.build(&pat).unwrap();
609 let rev = rev_builder.build(&pat).unwrap();
610 Ok((fwd, rev))
611 }
612 Patterns::Many(pats) => {
613 let pats: Vec<String> =
614 pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
615 for pat in pats.iter() {
616 regex_syntax::Parser::new().parse(pat)?;
617 }
618 let fwd = fwd_builder.build_many(&pats).unwrap();
619 let rev = rev_builder.build_many(&pats).unwrap();
620 Ok((fwd, rev))
621 }
622 }
623 }
624}
625
/// A regex pattern, or list of patterns, that can be searched for.
///
/// In this module it is implemented for `&str`, `String`, `&String`, `char`,
/// `[&str; N]` and `&[&str]`. A `char` is used as a one-character regex and
/// is not escaped, so `'.'` matches any character.
pub trait RegexPattern: InnerRegexPattern {
    /// The item yielded for each match. Single patterns yield the match's
    /// `[start, end]` points. Lists of patterns yield
    /// `(pattern_index, [start, end])`.
    type Match: 'static;

    /// Builds a match item from the match's start and end points and the ID
    /// of the pattern that matched.
    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match;
}
637
638impl RegexPattern for &str {
639 type Match = [Point; 2];
640
641 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
642 points
643 }
644}
645
646impl RegexPattern for String {
647 type Match = [Point; 2];
648
649 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
650 points
651 }
652}
653
654impl RegexPattern for &String {
655 type Match = [Point; 2];
656
657 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
658 points
659 }
660}
661
662impl RegexPattern for char {
663 type Match = [Point; 2];
664
665 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
666 points
667 }
668}
669
670impl<const N: usize> RegexPattern for [&str; N] {
671 type Match = (usize, [Point; 2]);
672
673 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
674 (pattern.as_usize(), points)
675 }
676}
677
678impl RegexPattern for &[&str] {
679 type Match = (usize, [Point; 2]);
680
681 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
682 (pattern.as_usize(), points)
683 }
684}
685
/// Conversion of a pattern argument into the `Patterns` key used by the DFA
/// cache.
///
/// This trait is private and `RegexPattern` requires it, so `RegexPattern`
/// cannot be implemented outside this module.
trait InnerRegexPattern {
    /// Borrows `self` as `Patterns`.
    ///
    /// `bytes` is scratch space for types such as `char` that must be
    /// encoded before they can be viewed as a `&str`.
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
689
690impl InnerRegexPattern for &str {
691 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
692 Patterns::One(self)
693 }
694}
695
696impl InnerRegexPattern for String {
697 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
698 Patterns::One(self)
699 }
700}
701
702impl InnerRegexPattern for &String {
703 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
704 Patterns::One(self)
705 }
706}
707
708impl InnerRegexPattern for char {
709 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
710 Patterns::One(self.encode_utf8(bytes) as &str)
711 }
712}
713
714impl<const N: usize> InnerRegexPattern for [&str; N] {
715 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
716 Patterns::Many(self)
717 }
718}
719
720impl InnerRegexPattern for &[&str] {
721 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
722 Patterns::Many(self)
723 }
724}
725
726#[derive(Clone, Copy)]
727struct SearchBytes<'a>([&'a [u8]; 2], usize);
728
729impl Cursor for SearchBytes<'_> {
730 fn chunk(&self) -> &[u8] {
731 self.0[self.1]
732 }
733
734 fn advance(&mut self) -> bool {
735 if self.1 == 0 {
736 self.1 += 1;
737 true
738 } else {
739 false
740 }
741 }
742
743 fn backtrack(&mut self) -> bool {
744 if self.1 == 1 {
745 self.1 -= 1;
746 true
747 } else {
748 false
749 }
750 }
751
752 fn total_bytes(&self) -> Option<usize> {
753 Some(self.0[0].len() + self.0[1].len())
754 }
755
756 fn offset(&self) -> usize {
757 match self.1 {
758 1 => self.0[0].len(),
759 _ => 0,
760 }
761 }
762}
763
764fn get_inputs(
765 bytes: &Bytes,
766 range: std::ops::Range<usize>,
767) -> (Input<SearchBytes<'_>>, Input<SearchBytes<'_>>) {
768 let haystack = SearchBytes(bytes.buffers(range).to_array(), 0);
769 let fwd_input = Input::new(haystack);
770 let rev_input = Input::new(haystack);
771 (fwd_input, rev_input)
772}