1use std::{collections::HashMap, ops::RangeBounds, sync::LazyLock};
13
14use parking_lot::{RwLock, RwLockWriteGuard};
15use regex_automata::{
16 Anchored, Input, PatternID,
17 hybrid::dfa::{Cache, DFA},
18 nfa::thompson,
19 util::syntax,
20};
21
22use super::{Point, Text, TextRange};
23
24impl Text {
25 pub fn search_fwd<R: RegexPattern>(
26 &mut self,
27 pat: R,
28 range: impl TextRange,
29 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
30 let bytes = self.bytes_mut();
31 let range = range.to_range_fwd(bytes.len().byte());
32 let dfas = dfas_from_pat(pat)?;
33 let haystack = {
34 bytes.make_contiguous(range.clone());
35 bytes.get_contiguous(range.clone()).unwrap()
36 };
37
38 let mut fwd_input = Input::new(haystack);
39 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
40 let mut fwd_cache = dfas.fwd.1.write();
41 let mut rev_cache = dfas.rev.1.write();
42
43 let bytes = bytes as &super::Bytes;
44 Ok(std::iter::from_fn(move || {
45 let init = fwd_input.start();
46 let h_end = loop {
47 if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
48 if half.offset() == init {
50 fwd_input.set_start(init + 1);
51 } else {
52 break half.offset();
53 }
54 } else {
55 return None;
56 }
57 };
58
59 fwd_input.set_start(h_end);
60 rev_input.set_end(h_end);
61
62 let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
63 return None;
64 };
65 let h_start = half.offset();
66
67 let p0 = bytes.point_at(h_start + range.start);
68 let p1 = bytes.point_at(h_end + range.start);
69
70 Some(R::get_match([p0, p1], half.pattern()))
71 }))
72 }
73
74 pub fn search_rev<R: RegexPattern>(
76 &mut self,
77 pat: R,
78 range: impl TextRange,
79 ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
80 let bytes = self.bytes_mut();
81 let range = range.to_range_rev(bytes.len().byte());
82 let dfas = dfas_from_pat(pat)?;
83 let haystack = {
84 bytes.make_contiguous(range.clone());
85 bytes.get_contiguous(range.clone()).unwrap()
86 };
87
88 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
89 let mut rev_input = Input::new(haystack);
90 let mut fwd_cache = dfas.fwd.1.write();
91 let mut rev_cache = dfas.rev.1.write();
92
93 let bytes = bytes as &super::Bytes;
94 let gap = range.start;
95 Ok(std::iter::from_fn(move || {
96 let init = rev_input.end();
97 let start = loop {
98 if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
99 if half.offset() == init {
101 rev_input.set_end(init.checked_sub(1)?);
102 } else {
103 break half.offset();
104 }
105 } else {
106 return None;
107 }
108 };
109
110 rev_input.set_end(start);
111 fwd_input.set_start(start);
112
113 let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
114 return None;
115 };
116 let end = half.offset();
117
118 let p0 = bytes.point_at(start + gap);
119 let p1 = bytes.point_at(end + gap);
120
121 Some(R::get_match([p0, p1], half.pattern()))
122 }))
123 }
124
125 pub fn matches(
130 &mut self,
131 pat: impl RegexPattern,
132 range: impl TextRange,
133 ) -> Result<bool, Box<regex_syntax::Error>> {
134 let range = range.to_range_fwd(self.len().byte());
135 let dfas = dfas_from_pat(pat)?;
136
137 let haystack = self.contiguous(range);
138 let fwd_input = Input::new(haystack);
139
140 let mut fwd_cache = dfas.fwd.1.write();
141 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
142 Ok(true)
143 } else {
144 Ok(false)
145 }
146 }
147}
148
/// A value whose contents can be tested against a regex pattern.
pub trait Matcheable: Sized {
    /// Reports whether `pat` matches somewhere inside `range` of `self`.
    ///
    /// The implementations in this file interpret `range` as byte offsets
    /// into the value's string contents.
    ///
    /// # Errors
    ///
    /// The implementations in this file return the [`regex_syntax::Error`]
    /// produced when `pat` fails to parse.
    fn matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
156
157impl<const N: usize> Matcheable for std::array::IntoIter<&str, N> {
158 fn matches(
159 &self,
160 pat: impl RegexPattern,
161 range: impl RangeBounds<usize> + Clone,
162 ) -> Result<bool, Box<regex_syntax::Error>> {
163 let str: String = self.as_slice().iter().copied().collect();
164 let (start, end) = crate::get_ends(range, str.len());
165 let dfas = dfas_from_pat(pat)?;
166 let fwd_input =
167 Input::new(unsafe { std::str::from_utf8_unchecked(&str.as_bytes()[start..end]) });
168
169 let mut fwd_cache = dfas.fwd.1.write();
170 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
171 Ok(true)
172 } else {
173 Ok(false)
174 }
175 }
176}
177
178impl Matcheable for &'_ str {
179 fn matches(
180 &self,
181 pat: impl RegexPattern,
182 range: impl RangeBounds<usize> + Clone,
183 ) -> Result<bool, Box<regex_syntax::Error>> {
184 let (start, end) = crate::get_ends(range, self.len());
185 let dfas = dfas_from_pat(pat)?;
186 let fwd_input =
187 Input::new(unsafe { std::str::from_utf8_unchecked(&self.as_bytes()[start..end]) });
188
189 let mut fwd_cache = dfas.fwd.1.write();
190 if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
191 Ok(true)
192 } else {
193 Ok(false)
194 }
195 }
196}
197
/// A reusable searcher for one pattern.
///
/// It keeps references to the pattern's forward and reverse DFAs together
/// with write guards on their caches. The guards are held for as long as the
/// `Searcher` exists.
pub struct Searcher {
    // The pattern text this searcher was created from.
    pat: String,
    // DFA used for forward searches.
    fwd_dfa: &'static DFA,
    // DFA used for reverse searches.
    rev_dfa: &'static DFA,
    // Exclusive access to the forward DFA's cache.
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    // Exclusive access to the reverse DFA's cache.
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
205
206impl Searcher {
207 pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
208 let dfas = dfas_from_pat(&pat)?;
209 Ok(Self {
210 pat,
211 fwd_dfa: &dfas.fwd.0,
212 rev_dfa: &dfas.rev.0,
213 fwd_cache: dfas.fwd.1.write(),
214 rev_cache: dfas.rev.1.write(),
215 })
216 }
217
218 pub fn search_fwd<'b>(
219 &'b mut self,
220 text: &'b mut Text,
221 range: impl TextRange,
222 ) -> impl Iterator<Item = [Point; 2]> + 'b {
223 let range = range.to_range_fwd(text.len().byte());
224 let mut last_point = text.point_at(range.start);
225
226 let haystack = text.contiguous(range.clone());
227 let mut fwd_input = Input::new(haystack).anchored(Anchored::No);
228 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
229
230 let fwd_dfa = &self.fwd_dfa;
231 let rev_dfa = &self.rev_dfa;
232 let fwd_cache = &mut self.fwd_cache;
233 let rev_cache = &mut self.rev_cache;
234 let gap = range.start;
235 std::iter::from_fn(move || {
236 let init = fwd_input.start();
237 let end = loop {
238 if let Ok(Some(half)) = fwd_dfa.try_search_fwd(fwd_cache, &fwd_input) {
239 if half.offset() == init {
241 fwd_input.set_start(init + 1);
242 } else {
243 break half.offset();
244 }
245 } else {
246 return None;
247 }
248 };
249
250 fwd_input.set_start(end);
251 rev_input.set_end(end);
252
253 let half = unsafe {
254 rev_dfa
255 .try_search_rev(rev_cache, &rev_input)
256 .unwrap()
257 .unwrap_unchecked()
258 };
259 let start = half.offset();
260
261 let start = unsafe {
262 std::str::from_utf8_unchecked(&haystack.as_bytes()[last_point.byte() - gap..start])
263 }
264 .chars()
265 .fold(last_point, |p, b| p.fwd(b));
266 let end = unsafe {
267 std::str::from_utf8_unchecked(&haystack.as_bytes()[start.byte() - gap..end])
268 }
269 .chars()
270 .fold(start, |p, b| p.fwd(b));
271
272 last_point = end;
273
274 Some([start, end])
275 })
276 }
277
278 pub fn search_rev<'b>(
279 &'b mut self,
280 text: &'b mut Text,
281 range: impl TextRange,
282 ) -> impl Iterator<Item = [Point; 2]> + 'b {
283 let range = range.to_range_rev(text.len().byte());
284 let mut last_point = text.point_at(range.end);
285
286 let haystack = text.contiguous(range.clone());
287 let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
288 let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
289
290 let fwd_dfa = &self.fwd_dfa;
291 let rev_dfa = &self.rev_dfa;
292 let fwd_cache = &mut self.fwd_cache;
293 let rev_cache = &mut self.rev_cache;
294 let gap = range.start;
295 std::iter::from_fn(move || {
296 let init = rev_input.end();
297 let start = loop {
298 if let Ok(Some(half)) = rev_dfa.try_search_rev(rev_cache, &rev_input) {
299 if half.offset() == init {
301 rev_input.set_end(init - 1);
302 } else {
303 break half.offset();
304 }
305 } else {
306 return None;
307 }
308 };
309
310 fwd_input.set_start(start);
311 rev_input.set_end(start);
312
313 let half = fwd_dfa
314 .try_search_fwd(fwd_cache, &fwd_input)
315 .unwrap()
316 .unwrap();
317
318 let end = unsafe {
319 std::str::from_utf8_unchecked(
320 &haystack.as_bytes()[half.offset()..(last_point.byte() - gap)],
321 )
322 }
323 .chars()
324 .fold(last_point, |p, b| p.rev(b));
325 let start = unsafe {
326 std::str::from_utf8_unchecked(&haystack.as_bytes()[start..(end.byte() - gap)])
327 }
328 .chars()
329 .fold(end, |p, b| p.rev(b));
330
331 last_point = start;
332
333 Some([start, end])
334 })
335 }
336
337 pub fn matches(&mut self, query: impl AsRef<[u8]>) -> bool {
339 let input = Input::new(&query).anchored(Anchored::Yes);
340
341 let Ok(Some(half)) = self.fwd_dfa.try_search_fwd(&mut self.fwd_cache, &input) else {
342 return false;
343 };
344
345 half.offset() == query.as_ref().len()
346 }
347
348 pub fn is_empty(&self) -> bool {
350 self.pat.is_empty()
351 }
352}
353
/// The forward and reverse lazy DFAs compiled for one set of patterns, each
/// paired with the lock-protected cache that its searches use.
struct DFAs {
    // Forward DFA and its cache.
    fwd: (DFA, RwLock<Cache>),
    // Reverse DFA and its cache.
    rev: (DFA, RwLock<Cache>),
}
358
359fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
360 static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
361 LazyLock::new(RwLock::default);
362
363 let mut list = DFA_LIST.write();
364
365 let mut bytes = [0; 4];
366 let pat = pat.as_patterns(&mut bytes);
367
368 if let Some(dfas) = list.get(&pat) {
369 Ok(*dfas)
370 } else {
371 let pat = pat.leak();
372 let (fwd, rev) = pat.dfas()?;
373
374 let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
375 let dfas = Box::leak(Box::new(DFAs {
376 fwd: (fwd, RwLock::new(fwd_cache)),
377 rev: (rev, RwLock::new(rev_cache)),
378 }));
379 let _ = list.insert(pat, dfas);
380 Ok(dfas)
381 }
382}
383
/// The pattern text behind a search, used as the key of the DFA table.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    // A single pattern.
    One(&'a str),
    // Several patterns compiled together into one multi-pattern DFA.
    Many(&'a [&'a str]),
}
389
390impl Patterns<'_> {
391 fn leak(&self) -> Patterns<'static> {
392 match self {
393 Patterns::One(str) => Patterns::One(String::from(*str).leak()),
394 Patterns::Many(strs) => Patterns::Many(
395 strs.iter()
396 .map(|s| {
397 let str: &'static str = s.to_string().leak();
398 str
399 })
400 .collect::<Vec<&'static str>>()
401 .leak(),
402 ),
403 }
404 }
405
406 fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
407 let mut fwd_builder = DFA::builder();
408 fwd_builder.syntax(syntax::Config::new().multi_line(true));
409 let mut rev_builder = DFA::builder();
410 rev_builder
411 .syntax(syntax::Config::new().multi_line(true))
412 .thompson(thompson::Config::new().reverse(true));
413
414 match self {
415 Patterns::One(pat) => {
416 let pat = pat.replace("\\b", "(?-u:\\b)");
417 regex_syntax::Parser::new().parse(&pat)?;
418 let fwd = fwd_builder.build(&pat).unwrap();
419 let rev = rev_builder.build(&pat).unwrap();
420 Ok((fwd, rev))
421 }
422 Patterns::Many(pats) => {
423 let pats: Vec<String> =
424 pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
425 for pat in pats.iter() {
426 regex_syntax::Parser::new().parse(pat)?;
427 }
428 let fwd = fwd_builder.build_many(&pats).unwrap();
429 let rev = rev_builder.build_many(&pats).unwrap();
430 Ok((fwd, rev))
431 }
432 }
433 }
434}
435
/// A value that can be used as the pattern of a regex search.
pub trait RegexPattern: InnerRegexPattern {
    /// What a search with this kind of pattern yields for each match.
    type Match: 'static;

    /// Builds a [`Self::Match`] from the start and end points of a match and
    /// the id of the pattern that produced it.
    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match;
}
441
442impl RegexPattern for &str {
443 type Match = [Point; 2];
444
445 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
446 points
447 }
448}
449
450impl RegexPattern for String {
451 type Match = [Point; 2];
452
453 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
454 points
455 }
456}
457
458impl RegexPattern for &String {
459 type Match = [Point; 2];
460
461 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
462 points
463 }
464}
465
466impl RegexPattern for char {
467 type Match = [Point; 2];
468
469 fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
470 points
471 }
472}
473
474impl<const N: usize> RegexPattern for [&str; N] {
475 type Match = (usize, [Point; 2]);
476
477 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
478 (pattern.as_usize(), points)
479 }
480}
481
482impl RegexPattern for &[&str] {
483 type Match = (usize, [Point; 2]);
484
485 fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
486 (pattern.as_usize(), points)
487 }
488}
489
/// Private half of [`RegexPattern`]: exposes a value as borrowed
/// [`Patterns`].
trait InnerRegexPattern {
    /// Borrows `self` as [`Patterns`].
    ///
    /// `bytes` is scratch space for implementors that have no string storage
    /// of their own; in this file only `char` writes to it.
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
493
494impl InnerRegexPattern for &str {
495 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
496 Patterns::One(self)
497 }
498}
499
500impl InnerRegexPattern for String {
501 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
502 Patterns::One(self)
503 }
504}
505
506impl InnerRegexPattern for &String {
507 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
508 Patterns::One(self)
509 }
510}
511
512impl InnerRegexPattern for char {
513 fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
514 Patterns::One(self.encode_utf8(bytes) as &str)
515 }
516}
517
518impl<const N: usize> InnerRegexPattern for [&str; N] {
519 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
520 Patterns::Many(self)
521 }
522}
523
524impl InnerRegexPattern for &[&str] {
525 fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
526 Patterns::Many(self)
527 }
528}