1use std::{
13 collections::HashMap,
14 ops::RangeBounds,
15 sync::{LazyLock, RwLock, RwLockWriteGuard},
16};
17
18use regex_automata::{
19 Anchored, Input, PatternID,
20 hybrid::dfa::{Cache, DFA},
21 nfa::thompson,
22 util::syntax,
23};
24
25use super::{Bytes, Point, Text, TextRange};
26
27impl Text {
28 pub fn search_fwd<R: RegexPattern>(
40 &mut self,
41 pat: R,
42 range: impl TextRange,
43 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
44 self.bytes_mut().search_fwd(pat, range)
45 }
46
47 pub fn search_rev<R: RegexPattern>(
59 &mut self,
60 pat: R,
61 range: impl TextRange,
62 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
63 self.bytes_mut().search_rev(pat, range)
64 }
65
66 pub fn matches(
71 &mut self,
72 pat: impl RegexPattern,
73 range: impl TextRange,
74 ) -> Result<bool, Box<regex_syntax::Error>> {
75 let range = range.to_range(self.len().byte());
76 let dfas = dfas_from_pat(pat)?;
77
78 let haystack = self.contiguous(range);
79 let fwd_input = Input::new(haystack);
80
81 let mut fwd_cache = dfas.fwd.1.write().unwrap();
82 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
83 Ok(true)
84 } else {
85 Ok(false)
86 }
87 }
88}
89
90impl Bytes {
91 pub fn search_fwd<R: RegexPattern>(
103 &mut self,
104 pat: R,
105 range: impl TextRange,
106 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
107 let range = range.to_range(self.len().byte());
108 let dfas = dfas_from_pat(pat)?;
109 let haystack = {
110 self.make_contiguous(range.clone());
111 self.get_contiguous(range.clone()).unwrap()
112 };
113
114 let mut fwd_input = Input::new(haystack);
115 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
116 let mut fwd_cache = dfas.fwd.1.write().unwrap();
117 let mut rev_cache = dfas.rev.1.write().unwrap();
118
119 let bytes = self as &super::Bytes;
120 Ok(std::iter::from_fn(move || {
121 let init = fwd_input.start();
122 let end = loop {
123 if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
124 if half.offset() == init {
126 fwd_input.set_start(init + 1);
127 } else {
128 break half.offset();
129 }
130 } else {
131 return None;
132 }
133 };
134
135 fwd_input.set_start(end);
136 rev_input.set_end(end);
137
138 let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
139 return None;
140 };
141 let start = half.offset();
142
143 let p0 = bytes.point_at(start + range.start);
144 let p1 = bytes.point_at(end + range.start);
145
146 Some(R::get_match([p0, p1], half.pattern()))
147 }))
148 }
149
150 pub fn search_rev<R: RegexPattern>(
162 &mut self,
163 pat: R,
164 range: impl TextRange,
165 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
166 let range = range.to_range(self.len().byte());
167 let dfas = dfas_from_pat(pat)?;
168 let haystack = {
169 self.make_contiguous(range.clone());
170 self.get_contiguous(range.clone()).unwrap()
171 };
172
173 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
174 let mut rev_input = Input::new(haystack);
175 let mut fwd_cache = dfas.fwd.1.write().unwrap();
176 let mut rev_cache = dfas.rev.1.write().unwrap();
177
178 let bytes = self as &super::Bytes;
179 let gap = range.start;
180 Ok(std::iter::from_fn(move || {
181 let init = rev_input.end();
182 let start = loop {
183 if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
184 if half.offset() == init {
186 rev_input.set_end(init.checked_sub(1)?);
187 } else {
188 break half.offset();
189 }
190 } else {
191 return None;
192 }
193 };
194
195 rev_input.set_end(start);
196 fwd_input.set_start(start);
197
198 let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
199 return None;
200 };
201 let end = half.offset();
202
203 let p0 = bytes.point_at(start + gap);
204 let p1 = bytes.point_at(end + gap);
205
206 Some(R::get_match([p0, p1], half.pattern()))
207 }))
208 }
209
210 pub fn matches(
215 &mut self,
216 pat: impl RegexPattern,
217 range: impl TextRange,
218 ) -> Result<bool, Box<regex_syntax::Error>> {
219 let range = range.to_range(self.len().byte());
220 let dfas = dfas_from_pat(pat)?;
221
222 let haystack = self.contiguous(range);
223 let fwd_input = Input::new(haystack);
224
225 let mut fwd_cache = dfas.fwd.1.write().unwrap();
226 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
227 Ok(true)
228 } else {
229 Ok(false)
230 }
231 }
232}
233
/// Regex searching on string-like values.
///
/// This file implements the trait for every `S: AsRef<str>`.
pub trait Matcheable: Sized {
    /// Returns an iterator over the matches of `pat` found in `range`,
    /// searching forwards.
    ///
    /// Each item pairs two byte offsets, the start and the end of a
    /// match, with a string slice borrowed from `self`.
    ///
    /// # Errors
    ///
    /// The implementation in this file fails when the DFAs for `pat`
    /// cannot be built, e.g. because `pat` is not valid regex syntax.
    fn search_fwd(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>>;

    /// Returns an iterator over the matches of `pat` found in `range`,
    /// searching backwards.
    ///
    /// Each item pairs two byte offsets, the start and the end of a
    /// match, with a string slice borrowed from `self`.
    ///
    /// # Errors
    ///
    /// The implementation in this file fails when the DFAs for `pat`
    /// cannot be built, e.g. because `pat` is not valid regex syntax.
    fn search_rev(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>>;

    /// Returns whether `pat` matches somewhere inside `range`.
    ///
    /// # Errors
    ///
    /// The implementation in this file fails when the DFAs for `pat`
    /// cannot be built, e.g. because `pat` is not valid regex syntax.
    fn matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
257
258impl<S: AsRef<str>> Matcheable for S {
259 fn search_fwd(
260 &self,
261 pat: impl RegexPattern,
262 range: impl RangeBounds<usize> + Clone,
263 ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>> {
264 let str = self.as_ref();
265 let (start, end) = crate::get_ends(range, str.len());
266 let dfas = dfas_from_pat(pat)?;
267
268 let haystack = &str[start..end];
269
270 let mut fwd_input = Input::new(haystack);
271 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
272 let mut fwd_cache = dfas.fwd.1.write().unwrap();
273 let mut rev_cache = dfas.rev.1.write().unwrap();
274
275 Ok(std::iter::from_fn(move || {
276 let init = fwd_input.start();
277 let end = loop {
278 if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
279 if half.offset() == init {
281 fwd_input.set_start(init + 1);
282 } else {
283 break half.offset();
284 }
285 } else {
286 return None;
287 }
288 };
289
290 fwd_input.set_start(end);
291 rev_input.set_end(end);
292
293 let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
294 return None;
295 };
296 let start = half.offset();
297
298 Some(([start, end], &str[start..end]))
299 }))
300 }
301
302 fn search_rev(
303 &self,
304 pat: impl RegexPattern,
305 range: impl RangeBounds<usize> + Clone,
306 ) -> Result<impl Iterator<Item = ([usize; 2], &str)>, Box<regex_syntax::Error>> {
307 let str = self.as_ref();
308 let (start, end) = crate::get_ends(range, str.len());
309 let dfas = dfas_from_pat(pat)?;
310
311 let haystack = &str[start..end];
312
313 let mut fwd_input = Input::new(haystack);
314 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
315 let mut fwd_cache = dfas.fwd.1.write().unwrap();
316 let mut rev_cache = dfas.rev.1.write().unwrap();
317
318 Ok(std::iter::from_fn(move || {
319 let init = rev_input.end();
320 let start = loop {
321 if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
322 if half.offset() == init {
324 rev_input.set_end(init.checked_sub(1)?);
325 } else {
326 break half.offset();
327 }
328 } else {
329 return None;
330 }
331 };
332
333 rev_input.set_end(start);
334 fwd_input.set_start(start);
335
336 let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
337 return None;
338 };
339 let end = half.offset();
340
341 Some(([start, end], &str[start..end]))
342 }))
343 }
344
345 fn matches(
346 &self,
347 pat: impl RegexPattern,
348 range: impl RangeBounds<usize> + Clone,
349 ) -> Result<bool, Box<regex_syntax::Error>> {
350 let str = self.as_ref();
351 let (start, end) = crate::get_ends(range, str.len());
352 let dfas = dfas_from_pat(pat)?;
353 let fwd_input = Input::new(&str[start..end]);
354
355 let mut fwd_cache = dfas.fwd.1.write().unwrap();
356 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
357 Ok(true)
358 } else {
359 Ok(false)
360 }
361 }
362}
363
/// A regex searcher bound to a single pattern.
///
/// Besides the pattern text, it stores references to the forward and
/// reverse lazy DFAs of that pattern and write guards over their
/// caches. The guards are acquired in `Searcher::new` and are held for
/// as long as the value exists.
pub struct Searcher {
    /// The pattern this searcher was created from.
    pat: String,
    /// DFA used for forward scans.
    fwd_dfa: &'static DFA,
    /// DFA used for reverse scans.
    rev_dfa: &'static DFA,
    /// Exclusive access to the cache of `fwd_dfa`.
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    /// Exclusive access to the cache of `rev_dfa`.
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
374
375impl Searcher {
376 pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
378 let dfas = dfas_from_pat(&pat)?;
379 Ok(Self {
380 pat,
381 fwd_dfa: &dfas.fwd.0,
382 rev_dfa: &dfas.rev.0,
383 fwd_cache: dfas.fwd.1.write().unwrap(),
384 rev_cache: dfas.rev.1.write().unwrap(),
385 })
386 }
387
388 pub fn search_fwd<'b>(
392 &'b mut self,
393 as_mut_bytes: &'b mut impl AsMutBytes,
394 range: impl TextRange,
395 ) -> impl Iterator<Item = [Point; 2]> + 'b {
396 let bytes = as_mut_bytes.as_mut_bytes();
397 let range = range.to_range(bytes.len().byte());
398 let mut last_point = bytes.point_at(range.start);
399
400 let haystack = bytes.contiguous(range.clone());
401 let mut fwd_input = Input::new(haystack).anchored(Anchored::No);
402 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
403
404 let fwd_dfa = &self.fwd_dfa;
405 let rev_dfa = &self.rev_dfa;
406 let fwd_cache = &mut self.fwd_cache;
407 let rev_cache = &mut self.rev_cache;
408 let gap = range.start;
409 std::iter::from_fn(move || {
410 let init = fwd_input.start();
411 let end = loop {
412 if let Ok(Some(half)) = fwd_dfa.try_search_fwd(fwd_cache, &fwd_input) {
413 if half.offset() == init {
415 fwd_input.set_start(init + 1);
416 } else {
417 break half.offset();
418 }
419 } else {
420 return None;
421 }
422 };
423
424 fwd_input.set_start(end);
425 rev_input.set_end(end);
426
427 let half = unsafe {
428 rev_dfa
429 .try_search_rev(rev_cache, &rev_input)
430 .unwrap()
431 .unwrap_unchecked()
432 };
433 let start = half.offset();
434
435 let start = haystack[last_point.byte() - gap..start]
436 .chars()
437 .fold(last_point, |p, b| p.fwd(b));
438 let end = haystack[start.byte() - gap..end]
439 .chars()
440 .fold(start, |p, b| p.fwd(b));
441
442 last_point = end;
443
444 Some([start, end])
445 })
446 }
447
448 pub fn search_rev<'b>(
452 &'b mut self,
453 as_mut_bytes: &'b mut impl AsMutBytes,
454 range: impl TextRange,
455 ) -> impl Iterator<Item = [Point; 2]> + 'b {
456 let bytes = as_mut_bytes.as_mut_bytes();
457 let range = range.to_range(bytes.len().byte());
458 let mut last_point = bytes.point_at(range.end);
459
460 let haystack = bytes.contiguous(range.clone());
461 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
462 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
463
464 let fwd_dfa = &self.fwd_dfa;
465 let rev_dfa = &self.rev_dfa;
466 let fwd_cache = &mut self.fwd_cache;
467 let rev_cache = &mut self.rev_cache;
468 let gap = range.start;
469 std::iter::from_fn(move || {
470 let init = rev_input.end();
471 let start = loop {
472 if let Ok(Some(half)) = rev_dfa.try_search_rev(rev_cache, &rev_input) {
473 if half.offset() == init {
475 rev_input.set_end(init - 1);
476 } else {
477 break half.offset();
478 }
479 } else {
480 return None;
481 }
482 };
483
484 fwd_input.set_start(start);
485 rev_input.set_end(start);
486
487 let half = fwd_dfa
488 .try_search_fwd(fwd_cache, &fwd_input)
489 .unwrap()
490 .unwrap();
491
492 let end = haystack[half.offset()..(last_point.byte() - gap)]
493 .chars()
494 .fold(last_point, |p, b| p.rev(b));
495 let start = haystack[start..(end.byte() - gap)]
496 .chars()
497 .fold(end, |p, b| p.rev(b));
498
499 last_point = start;
500
501 Some([start, end])
502 })
503 }
504
505 pub fn matches(&mut self, query: impl AsRef<[u8]>) -> bool {
507 let input = Input::new(&query).anchored(Anchored::Yes);
508
509 let Ok(Some(half)) = self.fwd_dfa.try_search_fwd(&mut self.fwd_cache, &input) else {
510 return false;
511 };
512
513 half.offset() == query.as_ref().len()
514 }
515
516 pub fn is_empty(&self) -> bool {
518 self.pat.is_empty()
519 }
520}
521
/// The forward and reverse lazy DFAs built for one set of patterns,
/// each paired with the lock-protected [`Cache`] that searches with it
/// must use.
struct DFAs {
    /// DFA for forward scans and its cache.
    fwd: (DFA, RwLock<Cache>),
    /// DFA for reverse scans and its cache.
    rev: (DFA, RwLock<Cache>),
}
526
527fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
528 static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
529 LazyLock::new(RwLock::default);
530
531 let mut list = DFA_LIST.write().unwrap();
532
533 let mut bytes = [0; 4];
534 let pat = pat.as_patterns(&mut bytes);
535
536 if let Some(dfas) = list.get(&pat) {
537 Ok(*dfas)
538 } else {
539 let pat = pat.leak();
540 let (fwd, rev) = pat.dfas()?;
541
542 let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
543 let dfas = Box::leak(Box::new(DFAs {
544 fwd: (fwd, RwLock::new(fwd_cache)),
545 rev: (rev, RwLock::new(rev_cache)),
546 }));
547 let _ = list.insert(pat, dfas);
548 Ok(dfas)
549 }
550}
551
/// The pattern text a set of DFAs is built from, also used as the key
/// under which built DFAs are registered.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single pattern.
    One(&'a str),
    /// Several patterns that are compiled together.
    Many(&'a [&'a str]),
}
557
558impl Patterns<'_> {
559 fn leak(&self) -> Patterns<'static> {
560 match self {
561 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
562 Patterns::Many(strs) => Patterns::Many(
563 strs.iter()
564 .map(|s| {
565 let str: &'static str = s.to_string().leak();
566 str
567 })
568 .collect::<Vec<&'static str>>()
569 .leak(),
570 ),
571 }
572 }
573
574 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
575 let mut fwd_builder = DFA::builder();
576 fwd_builder.syntax(syntax::Config::new().multi_line(true));
577 let mut rev_builder = DFA::builder();
578 rev_builder
579 .syntax(syntax::Config::new().multi_line(true))
580 .thompson(thompson::Config::new().reverse(true));
581
582 match self {
583 Patterns::One(pat) => {
584 let pat = pat.replace("\\b", "(?-u:\\b)");
585 regex_syntax::Parser::new().parse(&pat)?;
586 let fwd = fwd_builder.build(&pat).unwrap();
587 let rev = rev_builder.build(&pat).unwrap();
588 Ok((fwd, rev))
589 }
590 Patterns::Many(pats) => {
591 let pats: Vec<String> =
592 pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
593 for pat in pats.iter() {
594 regex_syntax::Parser::new().parse(pat)?;
595 }
596 let fwd = fwd_builder.build_many(&pats).unwrap();
597 let rev = rev_builder.build_many(&pats).unwrap();
598 Ok((fwd, rev))
599 }
600 }
601 }
602}
603
/// A pattern, or group of patterns, that can be used in regex searches.
pub trait RegexPattern: InnerRegexPattern {
    /// What the search iterators yield for each match of this pattern.
    type Match: 'static;

    /// Builds a [`Self::Match`] from the start and end [`Point`]s of a
    /// match and the id of the pattern that produced it.
    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match;
}
615
616impl RegexPattern for &str {
617 type Match = [Point; 2];
618
619 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
620 points
621 }
622}
623
624impl RegexPattern for String {
625 type Match = [Point; 2];
626
627 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
628 points
629 }
630}
631
632impl RegexPattern for &String {
633 type Match = [Point; 2];
634
635 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
636 points
637 }
638}
639
640impl RegexPattern for char {
641 type Match = [Point; 2];
642
643 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
644 points
645 }
646}
647
648impl<const N: usize> RegexPattern for [&str; N] {
649 type Match = (usize, [Point; 2]);
650
651 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
652 (pattern.as_usize(), points)
653 }
654}
655
656impl RegexPattern for &[&str] {
657 type Match = (usize, [Point; 2]);
658
659 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
660 (pattern.as_usize(), points)
661 }
662}
663
/// Conversion of a pattern value into the [`Patterns`] representation
/// that DFAs are built from.
trait InnerRegexPattern {
    /// Borrows `self` as a [`Patterns`].
    ///
    /// `bytes` is scratch space for implementations that have no string
    /// of their own to borrow from; in this file only the `char`
    /// implementation writes to it.
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
667
668impl InnerRegexPattern for &str {
669 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
670 Patterns::One(self)
671 }
672}
673
674impl InnerRegexPattern for String {
675 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
676 Patterns::One(self)
677 }
678}
679
680impl InnerRegexPattern for &String {
681 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
682 Patterns::One(self)
683 }
684}
685
686impl InnerRegexPattern for char {
687 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
688 Patterns::One(self.encode_utf8(bytes) as &str)
689 }
690}
691
692impl<const N: usize> InnerRegexPattern for [&str; N] {
693 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
694 Patterns::Many(self)
695 }
696}
697
698impl InnerRegexPattern for &[&str] {
699 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
700 Patterns::Many(self)
701 }
702}
703
/// Types that can hand out a mutable reference to a [`Bytes`].
///
/// The search methods of `Searcher` accept any implementor of this
/// trait as the value to search in.
pub trait AsMutBytes {
    /// Returns the [`Bytes`] of this value.
    fn as_mut_bytes(&mut self) -> &mut Bytes;
}
708
709impl AsMutBytes for Bytes {
710 fn as_mut_bytes(&mut self) -> &mut Bytes {
711 self
712 }
713}
714
715impl AsMutBytes for Text {
716 fn as_mut_bytes(&mut self) -> &mut Bytes {
717 self.bytes_mut()
718 }
719}