duat_core/text/
search.rs

//! Utilities for searching the [`Text`]
//!
//! This includes some methods for the [`Text`] itself, meant for
//! general use when editing it. It also has the [`Searcher`] struct,
//! which is used when doing [incremental search] in the [`CmdLine`].
//! The [`Searcher`] is then used by an [`IncSearcher`], which can
//! decide what to do with the results.
//!
//! [incremental search]: crate::widgets::IncSearch
//! [`CmdLine`]: crate::widgets::CmdLine
//! [`IncSearcher`]: crate::mode::IncSearcher
12use std::{collections::HashMap, ops::RangeBounds, sync::LazyLock};
13
14use parking_lot::{RwLock, RwLockWriteGuard};
15use regex_automata::{
16    Anchored, Input, PatternID,
17    hybrid::dfa::{Cache, DFA},
18    nfa::thompson,
19    util::syntax,
20};
21
22use super::{Point, Text, TextRange};
23
24impl Text {
25    pub fn search_fwd<R: RegexPattern>(
26        &mut self,
27        pat: R,
28        range: impl TextRange,
29    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
30        let bytes = self.bytes_mut();
31        let range = range.to_range_fwd(bytes.len().byte());
32        let dfas = dfas_from_pat(pat)?;
33        let haystack = {
34            bytes.make_contiguous(range.clone());
35            bytes.get_contiguous(range.clone()).unwrap()
36        };
37
38        let mut fwd_input = Input::new(haystack);
39        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
40        let mut fwd_cache = dfas.fwd.1.write();
41        let mut rev_cache = dfas.rev.1.write();
42
43        let bytes = bytes as &super::Bytes;
44        Ok(std::iter::from_fn(move || {
45            let init = fwd_input.start();
46            let h_end = loop {
47                if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
48                    // Ignore empty matches at the start of the input.
49                    if half.offset() == init {
50                        fwd_input.set_start(init + 1);
51                    } else {
52                        break half.offset();
53                    }
54                } else {
55                    return None;
56                }
57            };
58
59            fwd_input.set_start(h_end);
60            rev_input.set_end(h_end);
61
62            let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
63                return None;
64            };
65            let h_start = half.offset();
66
67            let p0 = bytes.point_at(h_start + range.start);
68            let p1 = bytes.point_at(h_end + range.start);
69
70            Some(R::get_match([p0, p1], half.pattern()))
71        }))
72    }
73
74    /// Returns an iterator over the reverse matches of the regex
75    pub fn search_rev<R: RegexPattern>(
76        &mut self,
77        pat: R,
78        range: impl TextRange,
79    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
80        let bytes = self.bytes_mut();
81        let range = range.to_range_rev(bytes.len().byte());
82        let dfas = dfas_from_pat(pat)?;
83        let haystack = {
84            bytes.make_contiguous(range.clone());
85            bytes.get_contiguous(range.clone()).unwrap()
86        };
87
88        let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
89        let mut rev_input = Input::new(haystack);
90        let mut fwd_cache = dfas.fwd.1.write();
91        let mut rev_cache = dfas.rev.1.write();
92
93        let bytes = bytes as &super::Bytes;
94        let gap = range.start;
95        Ok(std::iter::from_fn(move || {
96            let init = rev_input.end();
97            let start = loop {
98                if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
99                    // Ignore empty matches at the end of the input.
100                    if half.offset() == init {
101                        rev_input.set_end(init.checked_sub(1)?);
102                    } else {
103                        break half.offset();
104                    }
105                } else {
106                    return None;
107                }
108            };
109
110            rev_input.set_end(start);
111            fwd_input.set_start(start);
112
113            let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
114                return None;
115            };
116            let end = half.offset();
117
118            let p0 = bytes.point_at(start + gap);
119            let p1 = bytes.point_at(end + gap);
120
121            Some(R::get_match([p0, p1], half.pattern()))
122        }))
123    }
124
125    /// Returns true if the pattern is found in the given range
126    ///
127    /// This is unanchored by default, if you want an anchored search,
128    /// use the `"^$"` characters.
129    pub fn matches(
130        &mut self,
131        pat: impl RegexPattern,
132        range: impl TextRange,
133    ) -> Result<bool, Box<regex_syntax::Error>> {
134        let range = range.to_range_fwd(self.len().byte());
135        let dfas = dfas_from_pat(pat)?;
136
137        let haystack = self.contiguous(range);
138        let fwd_input = Input::new(haystack);
139
140        let mut fwd_cache = dfas.fwd.1.write();
141        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
142            Ok(true)
143        } else {
144            Ok(false)
145        }
146    }
147}
148
149pub trait Matcheable: Sized {
150    fn matches(
151        &self,
152        pat: impl RegexPattern,
153        range: impl RangeBounds<usize> + Clone,
154    ) -> Result<bool, Box<regex_syntax::Error>>;
155}
156
157impl<const N: usize> Matcheable for std::array::IntoIter<&str, N> {
158    fn matches(
159        &self,
160        pat: impl RegexPattern,
161        range: impl RangeBounds<usize> + Clone,
162    ) -> Result<bool, Box<regex_syntax::Error>> {
163        let str: String = self.as_slice().iter().copied().collect();
164        let (start, end) = crate::get_ends(range, str.len());
165        let dfas = dfas_from_pat(pat)?;
166        let fwd_input =
167            Input::new(unsafe { std::str::from_utf8_unchecked(&str.as_bytes()[start..end]) });
168
169        let mut fwd_cache = dfas.fwd.1.write();
170        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
171            Ok(true)
172        } else {
173            Ok(false)
174        }
175    }
176}
177
178impl Matcheable for &'_ str {
179    fn matches(
180        &self,
181        pat: impl RegexPattern,
182        range: impl RangeBounds<usize> + Clone,
183    ) -> Result<bool, Box<regex_syntax::Error>> {
184        let (start, end) = crate::get_ends(range, self.len());
185        let dfas = dfas_from_pat(pat)?;
186        let fwd_input =
187            Input::new(unsafe { std::str::from_utf8_unchecked(&self.as_bytes()[start..end]) });
188
189        let mut fwd_cache = dfas.fwd.1.write();
190        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
191            Ok(true)
192        } else {
193            Ok(false)
194        }
195    }
196}
197
198pub struct Searcher {
199    pat: String,
200    fwd_dfa: &'static DFA,
201    rev_dfa: &'static DFA,
202    fwd_cache: RwLockWriteGuard<'static, Cache>,
203    rev_cache: RwLockWriteGuard<'static, Cache>,
204}
205
206impl Searcher {
207    pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
208        let dfas = dfas_from_pat(&pat)?;
209        Ok(Self {
210            pat,
211            fwd_dfa: &dfas.fwd.0,
212            rev_dfa: &dfas.rev.0,
213            fwd_cache: dfas.fwd.1.write(),
214            rev_cache: dfas.rev.1.write(),
215        })
216    }
217
218    pub fn search_fwd<'b>(
219        &'b mut self,
220        text: &'b mut Text,
221        range: impl TextRange,
222    ) -> impl Iterator<Item = [Point; 2]> + 'b {
223        let range = range.to_range_fwd(text.len().byte());
224        let mut last_point = text.point_at(range.start);
225
226        let haystack = text.contiguous(range.clone());
227        let mut fwd_input = Input::new(haystack).anchored(Anchored::No);
228        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
229
230        let fwd_dfa = &self.fwd_dfa;
231        let rev_dfa = &self.rev_dfa;
232        let fwd_cache = &mut self.fwd_cache;
233        let rev_cache = &mut self.rev_cache;
234        let gap = range.start;
235        std::iter::from_fn(move || {
236            let init = fwd_input.start();
237            let end = loop {
238                if let Ok(Some(half)) = fwd_dfa.try_search_fwd(fwd_cache, &fwd_input) {
239                    // Ignore empty matches at the start of the input.
240                    if half.offset() == init {
241                        fwd_input.set_start(init + 1);
242                    } else {
243                        break half.offset();
244                    }
245                } else {
246                    return None;
247                }
248            };
249
250            fwd_input.set_start(end);
251            rev_input.set_end(end);
252
253            let half = unsafe {
254                rev_dfa
255                    .try_search_rev(rev_cache, &rev_input)
256                    .unwrap()
257                    .unwrap_unchecked()
258            };
259            let start = half.offset();
260
261            let start = unsafe {
262                std::str::from_utf8_unchecked(&haystack.as_bytes()[last_point.byte() - gap..start])
263            }
264            .chars()
265            .fold(last_point, |p, b| p.fwd(b));
266            let end = unsafe {
267                std::str::from_utf8_unchecked(&haystack.as_bytes()[start.byte() - gap..end])
268            }
269            .chars()
270            .fold(start, |p, b| p.fwd(b));
271
272            last_point = end;
273
274            Some([start, end])
275        })
276    }
277
278    pub fn search_rev<'b>(
279        &'b mut self,
280        text: &'b mut Text,
281        range: impl TextRange,
282    ) -> impl Iterator<Item = [Point; 2]> + 'b {
283        let range = range.to_range_rev(text.len().byte());
284        let mut last_point = text.point_at(range.end);
285
286        let haystack = text.contiguous(range.clone());
287        let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
288        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
289
290        let fwd_dfa = &self.fwd_dfa;
291        let rev_dfa = &self.rev_dfa;
292        let fwd_cache = &mut self.fwd_cache;
293        let rev_cache = &mut self.rev_cache;
294        let gap = range.start;
295        std::iter::from_fn(move || {
296            let init = rev_input.end();
297            let start = loop {
298                if let Ok(Some(half)) = rev_dfa.try_search_rev(rev_cache, &rev_input) {
299                    // Ignore empty matches at the end of the input.
300                    if half.offset() == init {
301                        rev_input.set_end(init - 1);
302                    } else {
303                        break half.offset();
304                    }
305                } else {
306                    return None;
307                }
308            };
309
310            fwd_input.set_start(start);
311            rev_input.set_end(start);
312
313            let half = fwd_dfa
314                .try_search_fwd(fwd_cache, &fwd_input)
315                .unwrap()
316                .unwrap();
317
318            let end = unsafe {
319                std::str::from_utf8_unchecked(
320                    &haystack.as_bytes()[half.offset()..(last_point.byte() - gap)],
321                )
322            }
323            .chars()
324            .fold(last_point, |p, b| p.rev(b));
325            let start = unsafe {
326                std::str::from_utf8_unchecked(&haystack.as_bytes()[start..(end.byte() - gap)])
327            }
328            .chars()
329            .fold(end, |p, b| p.rev(b));
330
331            last_point = start;
332
333            Some([start, end])
334        })
335    }
336
337    /// Whether or not the regex matches a specific pattern
338    pub fn matches(&mut self, query: impl AsRef<[u8]>) -> bool {
339        let input = Input::new(&query).anchored(Anchored::Yes);
340
341        let Ok(Some(half)) = self.fwd_dfa.try_search_fwd(&mut self.fwd_cache, &input) else {
342            return false;
343        };
344
345        half.offset() == query.as_ref().len()
346    }
347
348    /// Whether or not there even is a pattern to search for
349    pub fn is_empty(&self) -> bool {
350        self.pat.is_empty()
351    }
352}
353
354struct DFAs {
355    fwd: (DFA, RwLock<Cache>),
356    rev: (DFA, RwLock<Cache>),
357}
358
359fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
360    static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
361        LazyLock::new(RwLock::default);
362
363    let mut list = DFA_LIST.write();
364
365    let mut bytes = [0; 4];
366    let pat = pat.as_patterns(&mut bytes);
367
368    if let Some(dfas) = list.get(&pat) {
369        Ok(*dfas)
370    } else {
371        let pat = pat.leak();
372        let (fwd, rev) = pat.dfas()?;
373
374        let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
375        let dfas = Box::leak(Box::new(DFAs {
376            fwd: (fwd, RwLock::new(fwd_cache)),
377            rev: (rev, RwLock::new(rev_cache)),
378        }));
379        let _ = list.insert(pat, dfas);
380        Ok(dfas)
381    }
382}
383
384#[derive(Clone, Copy, PartialEq, Eq, Hash)]
385enum Patterns<'a> {
386    One(&'a str),
387    Many(&'a [&'a str]),
388}
389
390impl Patterns<'_> {
391    fn leak(&self) -> Patterns<'static> {
392        match self {
393            Patterns::One(str) => Patterns::One(String::from(*str).leak()),
394            Patterns::Many(strs) => Patterns::Many(
395                strs.iter()
396                    .map(|s| {
397                        let str: &'static str = s.to_string().leak();
398                        str
399                    })
400                    .collect::<Vec<&'static str>>()
401                    .leak(),
402            ),
403        }
404    }
405
406    fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
407        let mut fwd_builder = DFA::builder();
408        fwd_builder.syntax(syntax::Config::new().multi_line(true));
409        let mut rev_builder = DFA::builder();
410        rev_builder
411            .syntax(syntax::Config::new().multi_line(true))
412            .thompson(thompson::Config::new().reverse(true));
413
414        match self {
415            Patterns::One(pat) => {
416                let pat = pat.replace("\\b", "(?-u:\\b)");
417                regex_syntax::Parser::new().parse(&pat)?;
418                let fwd = fwd_builder.build(&pat).unwrap();
419                let rev = rev_builder.build(&pat).unwrap();
420                Ok((fwd, rev))
421            }
422            Patterns::Many(pats) => {
423                let pats: Vec<String> =
424                    pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
425                for pat in pats.iter() {
426                    regex_syntax::Parser::new().parse(pat)?;
427                }
428                let fwd = fwd_builder.build_many(&pats).unwrap();
429                let rev = rev_builder.build_many(&pats).unwrap();
430                Ok((fwd, rev))
431            }
432        }
433    }
434}
435
436pub trait RegexPattern: InnerRegexPattern {
437    type Match: 'static;
438
439    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match;
440}
441
442impl RegexPattern for &str {
443    type Match = [Point; 2];
444
445    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
446        points
447    }
448}
449
450impl RegexPattern for String {
451    type Match = [Point; 2];
452
453    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
454        points
455    }
456}
457
458impl RegexPattern for &String {
459    type Match = [Point; 2];
460
461    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
462        points
463    }
464}
465
466impl RegexPattern for char {
467    type Match = [Point; 2];
468
469    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
470        points
471    }
472}
473
474impl<const N: usize> RegexPattern for [&str; N] {
475    type Match = (usize, [Point; 2]);
476
477    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
478        (pattern.as_usize(), points)
479    }
480}
481
482impl RegexPattern for &[&str] {
483    type Match = (usize, [Point; 2]);
484
485    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
486        (pattern.as_usize(), points)
487    }
488}
489
490trait InnerRegexPattern {
491    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
492}
493
494impl InnerRegexPattern for &str {
495    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
496        Patterns::One(self)
497    }
498}
499
500impl InnerRegexPattern for String {
501    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
502        Patterns::One(self)
503    }
504}
505
506impl InnerRegexPattern for &String {
507    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
508        Patterns::One(self)
509    }
510}
511
512impl InnerRegexPattern for char {
513    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
514        Patterns::One(self.encode_utf8(bytes) as &str)
515    }
516}
517
518impl<const N: usize> InnerRegexPattern for [&str; N] {
519    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
520        Patterns::Many(self)
521    }
522}
523
524impl InnerRegexPattern for &[&str] {
525    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
526        Patterns::Many(self)
527    }
528}