1use std::{
13 collections::HashMap,
14 ops::RangeBounds,
15 sync::{LazyLock, RwLock, RwLockWriteGuard},
16};
17
18use regex_automata::{
19 Anchored, Input, PatternID,
20 hybrid::dfa::{Cache, DFA},
21 nfa::thompson,
22 util::syntax,
23};
24
25use super::{Point, Text, TextRange};
26
27impl Text {
28 pub fn search_fwd<R: RegexPattern>(
40 &mut self,
41 pat: R,
42 range: impl TextRange,
43 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
44 let bytes = self.bytes_mut();
45 let range = range.to_range(bytes.len().byte());
46 let dfas = dfas_from_pat(pat)?;
47 let haystack = {
48 bytes.make_contiguous(range.clone());
49 bytes.get_contiguous(range.clone()).unwrap()
50 };
51
52 let mut fwd_input = Input::new(haystack);
53 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
54 let mut fwd_cache = dfas.fwd.1.write().unwrap();
55 let mut rev_cache = dfas.rev.1.write().unwrap();
56
57 let bytes = bytes as &super::Bytes;
58 Ok(std::iter::from_fn(move || {
59 let init = fwd_input.start();
60 let h_end = loop {
61 if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
62 if half.offset() == init {
64 fwd_input.set_start(init + 1);
65 } else {
66 break half.offset();
67 }
68 } else {
69 return None;
70 }
71 };
72
73 fwd_input.set_start(h_end);
74 rev_input.set_end(h_end);
75
76 let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
77 return None;
78 };
79 let h_start = half.offset();
80
81 let p0 = bytes.point_at(h_start + range.start);
82 let p1 = bytes.point_at(h_end + range.start);
83
84 Some(R::get_match([p0, p1], half.pattern()))
85 }))
86 }
87
88 pub fn search_rev<R: RegexPattern>(
100 &mut self,
101 pat: R,
102 range: impl TextRange,
103 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
104 let bytes = self.bytes_mut();
105 let range = range.to_range(bytes.len().byte());
106 let dfas = dfas_from_pat(pat)?;
107 let haystack = {
108 bytes.make_contiguous(range.clone());
109 bytes.get_contiguous(range.clone()).unwrap()
110 };
111
112 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
113 let mut rev_input = Input::new(haystack);
114 let mut fwd_cache = dfas.fwd.1.write().unwrap();
115 let mut rev_cache = dfas.rev.1.write().unwrap();
116
117 let bytes = bytes as &super::Bytes;
118 let gap = range.start;
119 Ok(std::iter::from_fn(move || {
120 let init = rev_input.end();
121 let start = loop {
122 if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
123 if half.offset() == init {
125 rev_input.set_end(init.checked_sub(1)?);
126 } else {
127 break half.offset();
128 }
129 } else {
130 return None;
131 }
132 };
133
134 rev_input.set_end(start);
135 fwd_input.set_start(start);
136
137 let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
138 return None;
139 };
140 let end = half.offset();
141
142 let p0 = bytes.point_at(start + gap);
143 let p1 = bytes.point_at(end + gap);
144
145 Some(R::get_match([p0, p1], half.pattern()))
146 }))
147 }
148
149 pub fn matches(
154 &mut self,
155 pat: impl RegexPattern,
156 range: impl TextRange,
157 ) -> Result<bool, Box<regex_syntax::Error>> {
158 let range = range.to_range(self.len().byte());
159 let dfas = dfas_from_pat(pat)?;
160
161 let haystack = self.contiguous(range);
162 let fwd_input = Input::new(haystack);
163
164 let mut fwd_cache = dfas.fwd.1.write().unwrap();
165 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
166 Ok(true)
167 } else {
168 Ok(false)
169 }
170 }
171}
172
173pub trait Matcheable: Sized {
175 fn matches(
177 &self,
178 pat: impl RegexPattern,
179 range: impl RangeBounds<usize> + Clone,
180 ) -> Result<bool, Box<regex_syntax::Error>>;
181}
182
183impl<const N: usize> Matcheable for std::array::IntoIter<&str, N> {
184 fn matches(
185 &self,
186 pat: impl RegexPattern,
187 range: impl RangeBounds<usize> + Clone,
188 ) -> Result<bool, Box<regex_syntax::Error>> {
189 let str: String = self.as_slice().iter().copied().collect();
190 let (start, end) = crate::get_ends(range, str.len());
191 let dfas = dfas_from_pat(pat)?;
192 let fwd_input =
193 Input::new(unsafe { std::str::from_utf8_unchecked(&str.as_bytes()[start..end]) });
194
195 let mut fwd_cache = dfas.fwd.1.write().unwrap();
196 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
197 Ok(true)
198 } else {
199 Ok(false)
200 }
201 }
202}
203
204impl Matcheable for &'_ str {
205 fn matches(
206 &self,
207 pat: impl RegexPattern,
208 range: impl RangeBounds<usize> + Clone,
209 ) -> Result<bool, Box<regex_syntax::Error>> {
210 let (start, end) = crate::get_ends(range, self.len());
211 let dfas = dfas_from_pat(pat)?;
212 let fwd_input =
213 Input::new(unsafe { std::str::from_utf8_unchecked(&self.as_bytes()[start..end]) });
214
215 let mut fwd_cache = dfas.fwd.1.write().unwrap();
216 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
217 Ok(true)
218 } else {
219 Ok(false)
220 }
221 }
222}
223
224pub struct Searcher {
228 pat: String,
229 fwd_dfa: &'static DFA,
230 rev_dfa: &'static DFA,
231 fwd_cache: RwLockWriteGuard<'static, Cache>,
232 rev_cache: RwLockWriteGuard<'static, Cache>,
233}
234
235impl Searcher {
236 pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
238 let dfas = dfas_from_pat(&pat)?;
239 Ok(Self {
240 pat,
241 fwd_dfa: &dfas.fwd.0,
242 rev_dfa: &dfas.rev.0,
243 fwd_cache: dfas.fwd.1.write().unwrap(),
244 rev_cache: dfas.rev.1.write().unwrap(),
245 })
246 }
247
248 pub fn search_fwd<'b>(
252 &'b mut self,
253 text: &'b mut Text,
254 range: impl TextRange,
255 ) -> impl Iterator<Item = [Point; 2]> + 'b {
256 let range = range.to_range(text.len().byte());
257 let mut last_point = text.point_at(range.start);
258
259 let haystack = text.contiguous(range.clone());
260 let mut fwd_input = Input::new(haystack).anchored(Anchored::No);
261 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
262
263 let fwd_dfa = &self.fwd_dfa;
264 let rev_dfa = &self.rev_dfa;
265 let fwd_cache = &mut self.fwd_cache;
266 let rev_cache = &mut self.rev_cache;
267 let gap = range.start;
268 std::iter::from_fn(move || {
269 let init = fwd_input.start();
270 let end = loop {
271 if let Ok(Some(half)) = fwd_dfa.try_search_fwd(fwd_cache, &fwd_input) {
272 if half.offset() == init {
274 fwd_input.set_start(init + 1);
275 } else {
276 break half.offset();
277 }
278 } else {
279 return None;
280 }
281 };
282
283 fwd_input.set_start(end);
284 rev_input.set_end(end);
285
286 let half = unsafe {
287 rev_dfa
288 .try_search_rev(rev_cache, &rev_input)
289 .unwrap()
290 .unwrap_unchecked()
291 };
292 let start = half.offset();
293
294 let start = unsafe {
295 std::str::from_utf8_unchecked(&haystack.as_bytes()[last_point.byte() - gap..start])
296 }
297 .chars()
298 .fold(last_point, |p, b| p.fwd(b));
299 let end = unsafe {
300 std::str::from_utf8_unchecked(&haystack.as_bytes()[start.byte() - gap..end])
301 }
302 .chars()
303 .fold(start, |p, b| p.fwd(b));
304
305 last_point = end;
306
307 Some([start, end])
308 })
309 }
310
311 pub fn search_rev<'b>(
315 &'b mut self,
316 text: &'b mut Text,
317 range: impl TextRange,
318 ) -> impl Iterator<Item = [Point; 2]> + 'b {
319 let range = range.to_range(text.len().byte());
320 let mut last_point = text.point_at(range.end);
321
322 let haystack = text.contiguous(range.clone());
323 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
324 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
325
326 let fwd_dfa = &self.fwd_dfa;
327 let rev_dfa = &self.rev_dfa;
328 let fwd_cache = &mut self.fwd_cache;
329 let rev_cache = &mut self.rev_cache;
330 let gap = range.start;
331 std::iter::from_fn(move || {
332 let init = rev_input.end();
333 let start = loop {
334 if let Ok(Some(half)) = rev_dfa.try_search_rev(rev_cache, &rev_input) {
335 if half.offset() == init {
337 rev_input.set_end(init - 1);
338 } else {
339 break half.offset();
340 }
341 } else {
342 return None;
343 }
344 };
345
346 fwd_input.set_start(start);
347 rev_input.set_end(start);
348
349 let half = fwd_dfa
350 .try_search_fwd(fwd_cache, &fwd_input)
351 .unwrap()
352 .unwrap();
353
354 let end = unsafe {
355 std::str::from_utf8_unchecked(
356 &haystack.as_bytes()[half.offset()..(last_point.byte() - gap)],
357 )
358 }
359 .chars()
360 .fold(last_point, |p, b| p.rev(b));
361 let start = unsafe {
362 std::str::from_utf8_unchecked(&haystack.as_bytes()[start..(end.byte() - gap)])
363 }
364 .chars()
365 .fold(end, |p, b| p.rev(b));
366
367 last_point = start;
368
369 Some([start, end])
370 })
371 }
372
373 pub fn matches(&mut self, query: impl AsRef<[u8]>) -> bool {
375 let input = Input::new(&query).anchored(Anchored::Yes);
376
377 let Ok(Some(half)) = self.fwd_dfa.try_search_fwd(&mut self.fwd_cache, &input) else {
378 return false;
379 };
380
381 half.offset() == query.as_ref().len()
382 }
383
384 pub fn is_empty(&self) -> bool {
386 self.pat.is_empty()
387 }
388}
389
390struct DFAs {
391 fwd: (DFA, RwLock<Cache>),
392 rev: (DFA, RwLock<Cache>),
393}
394
395fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
396 static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
397 LazyLock::new(RwLock::default);
398
399 let mut list = DFA_LIST.write().unwrap();
400
401 let mut bytes = [0; 4];
402 let pat = pat.as_patterns(&mut bytes);
403
404 if let Some(dfas) = list.get(&pat) {
405 Ok(*dfas)
406 } else {
407 let pat = pat.leak();
408 let (fwd, rev) = pat.dfas()?;
409
410 let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
411 let dfas = Box::leak(Box::new(DFAs {
412 fwd: (fwd, RwLock::new(fwd_cache)),
413 rev: (rev, RwLock::new(rev_cache)),
414 }));
415 let _ = list.insert(pat, dfas);
416 Ok(dfas)
417 }
418}
419
/// Borrowed form of a search pattern, also used as the DFA cache key.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single regex.
    One(&'a str),
    /// Several regexes, compiled together as a multi-pattern DFA.
    Many(&'a [&'a str]),
}
425
426impl Patterns<'_> {
427 fn leak(&self) -> Patterns<'static> {
428 match self {
429 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
430 Patterns::Many(strs) => Patterns::Many(
431 strs.iter()
432 .map(|s| {
433 let str: &'static str = s.to_string().leak();
434 str
435 })
436 .collect::<Vec<&'static str>>()
437 .leak(),
438 ),
439 }
440 }
441
442 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
443 let mut fwd_builder = DFA::builder();
444 fwd_builder.syntax(syntax::Config::new().multi_line(true));
445 let mut rev_builder = DFA::builder();
446 rev_builder
447 .syntax(syntax::Config::new().multi_line(true))
448 .thompson(thompson::Config::new().reverse(true));
449
450 match self {
451 Patterns::One(pat) => {
452 let pat = pat.replace("\\b", "(?-u:\\b)");
453 regex_syntax::Parser::new().parse(&pat)?;
454 let fwd = fwd_builder.build(&pat).unwrap();
455 let rev = rev_builder.build(&pat).unwrap();
456 Ok((fwd, rev))
457 }
458 Patterns::Many(pats) => {
459 let pats: Vec<String> =
460 pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
461 for pat in pats.iter() {
462 regex_syntax::Parser::new().parse(pat)?;
463 }
464 let fwd = fwd_builder.build_many(&pats).unwrap();
465 let rev = rev_builder.build_many(&pats).unwrap();
466 Ok((fwd, rev))
467 }
468 }
469 }
470}
471
472pub trait RegexPattern: InnerRegexPattern {
477 type Match: 'static;
479
480 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match;
482}
483
484impl RegexPattern for &str {
485 type Match = [Point; 2];
486
487 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
488 points
489 }
490}
491
492impl RegexPattern for String {
493 type Match = [Point; 2];
494
495 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
496 points
497 }
498}
499
500impl RegexPattern for &String {
501 type Match = [Point; 2];
502
503 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
504 points
505 }
506}
507
508impl RegexPattern for char {
509 type Match = [Point; 2];
510
511 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
512 points
513 }
514}
515
516impl<const N: usize> RegexPattern for [&str; N] {
517 type Match = (usize, [Point; 2]);
518
519 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
520 (pattern.as_usize(), points)
521 }
522}
523
524impl RegexPattern for &[&str] {
525 type Match = (usize, [Point; 2]);
526
527 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
528 (pattern.as_usize(), points)
529 }
530}
531
532trait InnerRegexPattern {
533 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
534}
535
536impl InnerRegexPattern for &str {
537 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
538 Patterns::One(self)
539 }
540}
541
542impl InnerRegexPattern for String {
543 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
544 Patterns::One(self)
545 }
546}
547
548impl InnerRegexPattern for &String {
549 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
550 Patterns::One(self)
551 }
552}
553
554impl InnerRegexPattern for char {
555 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
556 Patterns::One(self.encode_utf8(bytes) as &str)
557 }
558}
559
560impl<const N: usize> InnerRegexPattern for [&str; N] {
561 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
562 Patterns::Many(self)
563 }
564}
565
566impl InnerRegexPattern for &[&str] {
567 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
568 Patterns::Many(self)
569 }
570}