duat_core/text/
search.rs

1//! Utilities for searching the [`Text`]
2//!
3//! This includes some methods for the [`Text`] itself, meant for
4//! general use when editing it. It also has the [`Searcher`] struct,
5//! which is used when doing [incremental search] in the [`CmdLine`].
//! This struct is then used by an [`IncSearcher`], which decides
//! what to do with the results.
8//!
9//! [incremental search]: crate::widgets::IncSearch
10//! [`CmdLine`]: crate::widgets::CmdLine
11//! [`IncSearcher`]: crate::mode::IncSearcher
12use std::{collections::HashMap, ops::RangeBounds, sync::LazyLock};
13
14use parking_lot::{RwLock, RwLockWriteGuard};
15use regex_automata::{
16    Anchored, Input, PatternID,
17    hybrid::dfa::{Cache, DFA},
18    nfa::thompson::Config,
19};
20
21use super::{Point, Text, TextRange};
22
23impl Text {
24    pub fn search_fwd<R: RegexPattern>(
25        &mut self,
26        pat: R,
27        range: impl TextRange,
28    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
29        let range = range.to_range_fwd(self.len().byte());
30        let dfas = dfas_from_pat(pat)?;
31        let haystack = unsafe {
32            self.make_contiguous_in(range.clone());
33            self.continuous_in_unchecked(range.clone())
34        };
35
36        let mut fwd_input = Input::new(haystack);
37        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
38        let mut fwd_cache = dfas.fwd.1.write();
39        let mut rev_cache = dfas.rev.1.write();
40
41        let ref_self = self as &Text;
42        Ok(std::iter::from_fn(move || {
43            let init = fwd_input.start();
44            let h_end = loop {
45                if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
46                    // Ignore empty matches at the start of the input.
47                    if half.offset() == init {
48                        fwd_input.set_start(init + 1);
49                    } else {
50                        break half.offset();
51                    }
52                } else {
53                    return None;
54                }
55            };
56
57            fwd_input.set_start(h_end);
58            rev_input.set_end(h_end);
59
60            let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
61                return None;
62            };
63            let h_start = half.offset();
64
65            let p0 = ref_self.point_at(h_start + range.start);
66            let p1 = ref_self.point_at(h_end + range.start);
67
68            Some(R::get_match((p0, p1), half.pattern()))
69        }))
70    }
71
72    /// Returns an iterator over the reverse matches of the regex
73    pub fn search_rev<R: RegexPattern>(
74        &mut self,
75        pat: R,
76        range: impl TextRange,
77    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
78        let range = range.to_range_rev(self.len().byte());
79        let dfas = dfas_from_pat(pat)?;
80        let haystack = unsafe {
81            self.make_contiguous_in(range.clone());
82            self.continuous_in_unchecked(range.clone())
83        };
84
85        let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
86        let mut rev_input = Input::new(haystack);
87        let mut fwd_cache = dfas.fwd.1.write();
88        let mut rev_cache = dfas.rev.1.write();
89
90        let ref_self = self as &Text;
91        let gap = range.start;
92        Ok(std::iter::from_fn(move || {
93            let init = rev_input.end();
94            let start = loop {
95                if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
96                    // Ignore empty matches at the end of the input.
97                    if half.offset() == init {
98                        rev_input.set_end(init.checked_sub(1)?);
99                    } else {
100                        break half.offset();
101                    }
102                } else {
103                    return None;
104                }
105            };
106
107            rev_input.set_end(start);
108            fwd_input.set_start(start);
109
110            let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
111                return None;
112            };
113            let end = half.offset();
114
115            let p0 = ref_self.point_at(start + gap);
116            let p1 = ref_self.point_at(end + gap);
117
118            Some(R::get_match((p0, p1), half.pattern()))
119        }))
120    }
121
122    /// Returns true if the pattern is found in the given range
123    ///
124    /// This is unanchored by default, if you want an anchored search,
125    /// use the `"^$"` characters.
126    pub fn matches(
127        &mut self,
128        pat: impl RegexPattern,
129        range: impl TextRange,
130    ) -> Result<bool, Box<regex_syntax::Error>> {
131        let range = range.to_range_fwd(self.len().byte());
132        let dfas = dfas_from_pat(pat)?;
133
134        let haystack = unsafe {
135            self.make_contiguous_in(range.clone());
136            self.continuous_in_unchecked(range)
137        };
138        let fwd_input = Input::new(haystack);
139
140        let mut fwd_cache = dfas.fwd.1.write();
141        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
142            Ok(true)
143        } else {
144            Ok(false)
145        }
146    }
147}
148
/// A type that can be checked against a [`RegexPattern`]
///
/// In this file it is implemented for every type that is
/// [`AsRef<str>`].
pub trait Matcheable: Sized {
    /// Returns true if `pat` is found within `range` of `self`
    ///
    /// # Errors
    ///
    /// Returns a [`regex_syntax::Error`] if `pat` could not be turned
    /// into a regex.
    fn matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
156
157impl<S: AsRef<str>> Matcheable for S {
158    fn matches(
159        &self,
160        pat: impl RegexPattern,
161        range: impl RangeBounds<usize> + Clone,
162    ) -> Result<bool, Box<regex_syntax::Error>> {
163        let s = self.as_ref();
164        let (start, end) = crate::get_ends(range, s.len());
165        let dfas = dfas_from_pat(pat)?;
166        let fwd_input =
167            Input::new(unsafe { std::str::from_utf8_unchecked(&s.as_bytes()[start..end]) });
168
169        let mut fwd_cache = dfas.fwd.1.write();
170        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
171            Ok(true)
172        } else {
173            Ok(false)
174        }
175    }
176}
177
/// A regex pattern prepared for searching in a [`Text`]
///
/// Besides the pattern itself, it keeps references to the cached
/// forward and reverse [`DFA`]s of the pattern, as well as write
/// guards over their [`Cache`]s, which are held for as long as the
/// [`Searcher`] is alive.
pub struct Searcher {
    /// The pattern string that this [`Searcher`] was created with
    pat: String,
    /// The [`DFA`] used for forward searches
    fwd_dfa: &'static DFA,
    /// The [`DFA`] used for reverse searches
    rev_dfa: &'static DFA,
    /// Exclusive access to the [`Cache`] of the forward [`DFA`]
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    /// Exclusive access to the [`Cache`] of the reverse [`DFA`]
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
185
186impl Searcher {
187    pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
188        let dfas = dfas_from_pat(&pat)?;
189        Ok(Self {
190            pat,
191            fwd_dfa: &dfas.fwd.0,
192            rev_dfa: &dfas.rev.0,
193            fwd_cache: dfas.fwd.1.write(),
194            rev_cache: dfas.rev.1.write(),
195        })
196    }
197
198    pub fn search_fwd<'b>(
199        &'b mut self,
200        text: &'b mut Text,
201        range: impl TextRange,
202    ) -> impl Iterator<Item = (Point, Point)> + 'b {
203        let range = range.to_range_fwd(text.len().byte());
204        let haystack = unsafe {
205            text.make_contiguous_in(range.clone());
206            text.continuous_in_unchecked(range.clone())
207        };
208        let mut fwd_input = Input::new(haystack).anchored(Anchored::No);
209        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
210        let mut last_point = text.point_at(range.start);
211
212        let fwd_dfa = &self.fwd_dfa;
213        let rev_dfa = &self.rev_dfa;
214        let fwd_cache = &mut self.fwd_cache;
215        let rev_cache = &mut self.rev_cache;
216        let gap = range.start;
217        std::iter::from_fn(move || {
218            let init = fwd_input.start();
219            let end = loop {
220                if let Ok(Some(half)) = fwd_dfa.try_search_fwd(fwd_cache, &fwd_input) {
221                    // Ignore empty matches at the start of the input.
222                    if half.offset() == init {
223                        fwd_input.set_start(init + 1);
224                    } else {
225                        break half.offset();
226                    }
227                } else {
228                    return None;
229                }
230            };
231
232            fwd_input.set_start(end);
233            rev_input.set_end(end);
234
235            let half = unsafe {
236                rev_dfa
237                    .try_search_rev(rev_cache, &rev_input)
238                    .unwrap()
239                    .unwrap_unchecked()
240            };
241            let start = half.offset();
242
243            let start = unsafe {
244                std::str::from_utf8_unchecked(&haystack.as_bytes()[last_point.byte() - gap..start])
245            }
246            .chars()
247            .fold(last_point, |p, b| p.fwd(b));
248            let end = unsafe {
249                std::str::from_utf8_unchecked(&haystack.as_bytes()[start.byte() - gap..end])
250            }
251            .chars()
252            .fold(start, |p, b| p.fwd(b));
253
254            last_point = end;
255
256            Some((start, end))
257        })
258    }
259
260    pub fn search_rev<'b>(
261        &'b mut self,
262        text: &'b mut Text,
263        range: impl TextRange,
264    ) -> impl Iterator<Item = (Point, Point)> + 'b {
265        let range = range.to_range_rev(text.len().byte());
266        let haystack = unsafe {
267            text.make_contiguous_in(range.clone());
268            text.continuous_in_unchecked(range.clone())
269        };
270        let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
271        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
272        let mut last_point = text.point_at(range.end);
273
274        let fwd_dfa = &self.fwd_dfa;
275        let rev_dfa = &self.rev_dfa;
276        let fwd_cache = &mut self.fwd_cache;
277        let rev_cache = &mut self.rev_cache;
278        let gap = range.start;
279        std::iter::from_fn(move || {
280            let init = rev_input.end();
281            let start = loop {
282                if let Ok(Some(half)) = rev_dfa.try_search_rev(rev_cache, &rev_input) {
283                    // Ignore empty matches at the end of the input.
284                    if half.offset() == init {
285                        rev_input.set_end(init - 1);
286                    } else {
287                        break half.offset();
288                    }
289                } else {
290                    return None;
291                }
292            };
293
294            fwd_input.set_start(start);
295            rev_input.set_end(start);
296
297            let half = fwd_dfa
298                .try_search_fwd(fwd_cache, &fwd_input)
299                .unwrap()
300                .unwrap();
301
302            let end = unsafe {
303                std::str::from_utf8_unchecked(
304                    &haystack.as_bytes()[half.offset()..(last_point.byte() - gap)],
305                )
306            }
307            .chars()
308            .fold(last_point, |p, b| p.rev(b));
309            let start = unsafe {
310                std::str::from_utf8_unchecked(&haystack.as_bytes()[start..(end.byte() - gap)])
311            }
312            .chars()
313            .fold(end, |p, b| p.rev(b));
314
315            last_point = start;
316
317            Some((start, end))
318        })
319    }
320
321    /// Whether or not the regex matches a specific pattern
322    pub fn matches(&mut self, query: impl AsRef<[u8]>) -> bool {
323        let input = Input::new(&query).anchored(Anchored::Yes);
324
325        let Ok(Some(half)) = self.fwd_dfa.try_search_fwd(&mut self.fwd_cache, &input) else {
326            return false;
327        };
328
329        half.offset() == query.as_ref().len()
330    }
331
332    /// Whether or not there even is a pattern to search for
333    pub fn is_empty(&self) -> bool {
334        self.pat.is_empty()
335    }
336}
337
/// The forward and reverse [`DFA`]s built from a pattern
///
/// Each [`DFA`] is stored alongside the [`Cache`] that must be passed
/// to it when searching, behind an [`RwLock`].
struct DFAs {
    /// The forward [`DFA`] and its [`Cache`]
    fwd: (DFA, RwLock<Cache>),
    /// The reverse [`DFA`] and its [`Cache`]
    rev: (DFA, RwLock<Cache>),
}
342
343fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
344    static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
345        LazyLock::new(RwLock::default);
346
347    let mut list = DFA_LIST.write();
348
349    let mut bytes = [0; 4];
350    let pat = pat.as_patterns(&mut bytes);
351
352    if let Some(dfas) = list.get(&pat) {
353        Ok(*dfas)
354    } else {
355        let pat = pat.leak();
356        let (fwd, rev) = pat.dfas()?;
357
358        let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
359        let dfas = Box::leak(Box::new(DFAs {
360            fwd: (fwd, RwLock::new(fwd_cache)),
361            rev: (rev, RwLock::new(rev_cache)),
362        }));
363        let _ = list.insert(pat, dfas);
364        Ok(dfas)
365    }
366}
367
/// One or more regex patterns, in their string form
///
/// This is the key by which [`DFAs`] are looked up in the list of
/// `dfas_from_pat`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single pattern
    One(&'a str),
    /// A list of patterns, to be searched for at the same time
    Many(&'a [&'static str]),
}
373
374impl Patterns<'_> {
375    fn leak(&self) -> Patterns<'static> {
376        match self {
377            Patterns::One(str) => Patterns::One(String::from(*str).leak()),
378            Patterns::Many(strs) => Patterns::Many(Vec::from(*strs).leak()),
379        }
380    }
381
382    fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
383        let mut fwd_builder = DFA::builder();
384        fwd_builder.thompson(Config::new().utf8(false));
385        let mut rev_builder = DFA::builder();
386        rev_builder.thompson(Config::new().reverse(true).utf8(false));
387
388        match self {
389            Patterns::One(pat) => {
390                regex_syntax::Parser::new().parse(pat)?;
391                let fwd = fwd_builder.build(pat).unwrap();
392                let rev = rev_builder.build(pat).unwrap();
393                Ok((fwd, rev))
394            }
395            Patterns::Many(pats) => {
396                for pat in *pats {
397                    regex_syntax::Parser::new().parse(pat)?;
398                }
399                let fwd = fwd_builder.build_many(pats).unwrap();
400                let rev = rev_builder.build_many(pats).unwrap();
401                Ok((fwd, rev))
402            }
403        }
404    }
405}
406
/// A type that can be used as the pattern of a regex search
///
/// In this file it is implemented for string-like types and
/// [`char`], which match as a single pattern, and for lists of
/// `&'static str`, which match as many patterns at once.
pub trait RegexPattern: InnerRegexPattern {
    /// What a search with this kind of pattern yields per match
    type Match: 'static;

    /// Builds a [`Self::Match`] from the start and end of a match, as
    /// well as the [`PatternID`] reported for it
    fn get_match(points: (Point, Point), pattern: PatternID) -> Self::Match;
}
412
413impl RegexPattern for &str {
414    type Match = (Point, Point);
415
416    fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
417        points
418    }
419}
420
421impl RegexPattern for String {
422    type Match = (Point, Point);
423
424    fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
425        points
426    }
427}
428
429impl RegexPattern for &String {
430    type Match = (Point, Point);
431
432    fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
433        points
434    }
435}
436
437impl RegexPattern for char {
438    type Match = (Point, Point);
439
440    fn get_match(points: (Point, Point), _pattern: PatternID) -> Self::Match {
441        points
442    }
443}
444
445impl<const N: usize> RegexPattern for [&'static str; N] {
446    type Match = (Point, Point, usize);
447
448    fn get_match(points: (Point, Point), pattern: PatternID) -> Self::Match {
449        (points.0, points.1, pattern.as_usize())
450    }
451}
452
453impl RegexPattern for &[&'static str] {
454    type Match = (Point, Point, usize);
455
456    fn get_match(points: (Point, Point), pattern: PatternID) -> Self::Match {
457        (points.0, points.1, pattern.as_usize())
458    }
459}
460
/// The private half of [`RegexPattern`]
///
/// Converts a pattern type into the [`Patterns`] that the DFAs are
/// built and looked up from.
trait InnerRegexPattern {
    /// Returns `self` as [`Patterns`]
    ///
    /// `bytes` is scratch space that the returned [`Patterns`] may
    /// borrow from; in this file only the [`char`] implementation
    /// uses it, in order to encode itself as a `&str`.
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
464
465impl InnerRegexPattern for &str {
466    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
467        Patterns::One(self)
468    }
469}
470
471impl InnerRegexPattern for String {
472    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
473        Patterns::One(self)
474    }
475}
476
477impl InnerRegexPattern for &String {
478    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
479        Patterns::One(self)
480    }
481}
482
483impl InnerRegexPattern for char {
484    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
485        Patterns::One(self.encode_utf8(bytes) as &str)
486    }
487}
488
489impl<const N: usize> InnerRegexPattern for [&'static str; N] {
490    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
491        Patterns::Many(self)
492    }
493}
494
495impl InnerRegexPattern for &[&'static str] {
496    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
497        Patterns::Many(self)
498    }
499}