Skip to main content

duat_core/text/
search.rs

1//! Utilities for searching the [`Text`]
2//!
//! This includes some methods for the [`Text`] itself, meant for
//! general use when editing it. It also has the [`Matches`]
//! [`Iterator`], which can be used when doing [incremental search] in
//! the [`PromptLine`]. This iterator is then used in an
//! [`IncSearcher`], which can decide what to do with the results.
8//!
9//! [incremental search]: https://docs.rs/duat/latest/duat/modes/struct.IncSearcher.html
10//! [`PromptLine`]: https://docs.rs/duat/latest/duat/widgets/struct.PromptLine.html
11//! [`IncSearcher`]: https://docs.rs/duat/latest/duat/modes/trait.IncSearcher.html
12//! [`Text`]: super::Text
13use std::{
14    collections::HashMap,
15    marker::PhantomData,
16    ops::Range,
17    sync::{LazyLock, Mutex},
18};
19
20use regex_cursor::{
21    Cursor, Input,
22    engines::hybrid::{try_search_fwd, try_search_rev},
23    regex_automata::{
24        PatternID,
25        hybrid::dfa::{Cache, DFA},
26        nfa::thompson,
27        util::syntax,
28    },
29};
30
31use super::TextRange;
32use crate::text::Strs;
33
34/// An [`Iterator`] over the matches returned by a search on a
35/// [haystack]
36///
37/// This is most commonly used with the [`Strs`] structs, although it
38/// is also available with `&str` and any type implementing
39/// [`Deref<Target = str>`]
40///
41/// [haystack]: RegexHaystack
42/// [`Deref<Target = str>`]: std::ops::Deref
43#[derive(Clone)]
44pub struct Matches<'m, R> {
45    haystack: [&'m [u8]; 2],
46    b_start: usize,
47    dfas: &'static DFAs,
48    fwd_input: Input<SearchBytes<'m>>,
49    rev_input: Input<SearchBytes<'m>>,
50    rev_match: Option<Range<usize>>,
51    _ghost: PhantomData<R>,
52}
53
54impl<'m, R> Matches<'m, R> {
55    /// Changes the [`TextRange`] to search on
56    ///
57    /// This _will_ reset the [`Iterator`], if it was returning
58    /// [`None`] before, it might start returning [`Some`] again if
59    /// the pattern exists in the specified [`Range`]
60    ///
61    /// # Note
62    ///
63    /// There is a crucial difference between
64    /// `text.search(pat).range(...)` and `text[...].search(pat)`, in
65    /// that the latter will return matches relative to the `...`
66    /// range, while the former will return matches relative to the
67    /// start of the `text`.
68    #[track_caller]
69    pub fn range(self, range: impl TextRange) -> Self {
70        let [s0, s1] = self.haystack;
71        let range = range.to_range(s0.len() + s1.len());
72        let i0 = &s0[range.start.min(s0.len())..range.end.min(s0.len())];
73        let i1 = &s1[range.start.saturating_sub(s0.len())..range.end.saturating_sub(s0.len())];
74
75        Self {
76            fwd_input: Input::new(SearchBytes([i0, i1], 0)),
77            rev_input: Input::new(SearchBytes([i0, i1], 0)),
78            b_start: range.start,
79            ..self
80        }
81    }
82}
83
84impl<'m, R: RegexPattern> Iterator for Matches<'m, R> {
85    type Item = R::Match;
86
87    fn next(&mut self) -> Option<Self::Item> {
88        let mut fwd_cache = match self.dfas.fwd_cache.lock() {
89            Ok(cache) => cache,
90            Err(err) => err.into_inner(),
91        };
92
93        let (fwd_input, rev_input) = (&mut self.fwd_input, &mut self.rev_input);
94        let h_end = try_search_fwd(&self.dfas.fwd_dfa, &mut fwd_cache, fwd_input)
95            .ok()
96            .flatten()?
97            .offset();
98
99        let mut rev_cache = match self.dfas.rev_cache.lock() {
100            Ok(cache) => cache,
101            Err(err) => err.into_inner(),
102        };
103
104        rev_input.set_range(fwd_input.start()..h_end);
105        let half = try_search_rev(&self.dfas.rev_dfa, &mut rev_cache, rev_input)
106            .ok()
107            .flatten()?;
108        let h_start = half.offset();
109
110        fwd_input.set_start(h_end);
111
112        // To not repeatedly match the same empty thing over and over.
113        if h_start == h_end {
114            fwd_input.set_start(h_end + 1);
115        }
116
117        Some(R::get_match(
118            self.b_start + h_start..self.b_start + h_end,
119            half.pattern(),
120        ))
121    }
122}
123
124impl<'m, R: RegexPattern> DoubleEndedIterator for Matches<'m, R> {
125    fn next_back(&mut self) -> Option<Self::Item> {
126        let mut rev_cache = match self.dfas.rev_cache.lock() {
127            Ok(cache) => cache,
128            Err(err) => err.into_inner(),
129        };
130
131        let (fwd_input, rev_input) = (&mut self.fwd_input, &mut self.rev_input);
132
133        let h_start = try_search_rev(&self.dfas.rev_dfa, &mut rev_cache, fwd_input)
134            .ok()
135            .flatten()?
136            .offset();
137
138        let mut fwd_cache = match self.dfas.fwd_cache.lock() {
139            Ok(cache) => cache,
140            Err(err) => err.into_inner(),
141        };
142
143        rev_input.set_range(h_start..fwd_input.end());
144        let half = try_search_fwd(&self.dfas.fwd_dfa, &mut fwd_cache, rev_input)
145            .ok()
146            .flatten()?;
147        let h_end = half.offset();
148
149        fwd_input.set_end(h_start);
150
151        // To not repeatedly match the same empty thing over and over.
152        if h_start == h_end {
153            if self.rev_match == Some(self.b_start + h_start..self.b_start + h_end) {
154                return None;
155            } else if h_start > 0 {
156                fwd_input.set_end(h_start - 1);
157            }
158        }
159
160        self.rev_match = Some(self.b_start + h_start..self.b_start + h_end);
161
162        Some(R::get_match(
163            self.b_start + h_start..self.b_start + h_end,
164            half.pattern(),
165        ))
166    }
167}
168
169/// A type searcheable by [`DFA`]s
170///
171/// This type is used to create the [`Matches`] [`Iterator`], a useful
172/// and configurable iterator over the matches in the `Haystack`,
173/// primarily on the [`Strs`] type.
174pub trait RegexHaystack<'h> {
175    /// An [`Iterator`] over the matches for a given [`RegexPattern`]
176    ///
177    /// This `Iterator` will search through the entire range of the
178    /// haystack. If the haystack is [`Strs`], for example, then it
179    /// will search through the [`Strs::byte_range`]. You can also set
180    /// a custom range for search through the [`Matches::range`]
181    /// method, which will reset the search to encompass the part of a
182    /// [`TextRange`] that is clipped by the haystack.
183    ///
184    /// This `Iterator` also implements [`DoubleEndedIterator`], which
185    /// means that you can get the elements in reverse order.
186    ///
187    /// # Panics
188    ///
189    /// This function will panic if the [`RegexPattern`] isn't valid.
190    /// If you want a non panicking variety, check out
191    /// [`RegexHaystack::try_search`]
192    #[track_caller]
193    fn search<R: RegexPattern>(&'h self, pat: R) -> Matches<'h, R> {
194        match self.try_search(pat) {
195            Ok(matches) => matches,
196            Err(err) => panic!("{err}"),
197        }
198    }
199
200    /// An [`Iterator`] over the matches for a given [`RegexPattern`]
201    ///
202    /// This `Iterator` will search through the entire range of the
203    /// haystack. If the haystack is [`Strs`], for example, then it
204    /// will search through the [`Strs::byte_range`]. You can also set
205    /// a custom range for search through the [`Matches::range`]
206    /// method, which will reset the search to encompass the part of a
207    /// [`TextRange`] that is clipped by the haystack.
208    ///
209    /// This `Iterator` also implements [`DoubleEndedIterator`], which
210    /// means that you can get the elements in reverse order.
211    ///
212    /// This function will return [`Err`] if the regex pattern is not
213    /// valid. If you want a panicking variety, check out
214    /// [`RegexHaystack::search`]
215    fn try_search<R: RegexPattern>(
216        &'h self,
217        pat: R,
218    ) -> Result<Matches<'h, R>, Box<regex_syntax::Error>>;
219
220    /// Wether this haystack contains a match for a [`RegexPattern`]
221    ///
222    /// This is equivalent to calling `self.search().map(|iter|
223    /// iter.next().is_some())`.
224    ///
225    /// This function will return [`Err`] if the regex pattern is not
226    /// valid.
227    fn contains_pat(&'h self, pat: impl RegexPattern) -> Result<bool, Box<regex_syntax::Error>> {
228        Ok(self.try_search(pat)?.next().is_some())
229    }
230
231    /// Wether this haystack matches the [`RegexPattern`] exactly
232    ///
233    /// This function will return [`Err`] if the regex pattern is not
234    /// valid.
235    fn matches_pat(&'h self, pat: impl RegexPattern) -> Result<bool, Box<regex_syntax::Error>> {
236        let mut matches = self.try_search(pat)?;
237        Ok(matches
238            .next()
239            .is_some_and(|_| matches.fwd_input.start() == matches.fwd_input.end()))
240    }
241}
242
243impl<'b> RegexHaystack<'b> for Strs {
244    fn try_search<R: RegexPattern>(
245        &'b self,
246        pat: R,
247    ) -> Result<Matches<'b, R>, Box<regex_syntax::Error>> {
248        let dfas = dfas_from_pat(pat)?;
249
250        let haystack = self.to_array().map(str::as_bytes);
251
252        Ok(Matches {
253            haystack,
254            b_start: 0,
255            dfas,
256            fwd_input: Input::new(SearchBytes(haystack, 0)),
257            rev_input: Input::new(SearchBytes(haystack, 0)),
258            rev_match: None,
259            _ghost: PhantomData,
260        })
261    }
262}
263
264impl<'s, S: std::ops::Deref<Target = str>> RegexHaystack<'s> for S {
265    fn try_search<R: RegexPattern>(
266        &'s self,
267        pat: R,
268    ) -> Result<Matches<'s, R>, Box<regex_syntax::Error>> {
269        let dfas = dfas_from_pat(pat)?;
270
271        let haystack = [self.deref().as_bytes(), &[]];
272
273        Ok(Matches {
274            haystack,
275            b_start: 0,
276            dfas,
277            fwd_input: Input::new(SearchBytes(haystack, 0)),
278            rev_input: Input::new(SearchBytes(haystack, 0)),
279            rev_match: None,
280            _ghost: PhantomData,
281        })
282    }
283}
284
285struct DFAs {
286    fwd_dfa: DFA,
287    fwd_cache: Mutex<Cache>,
288    rev_dfa: DFA,
289    rev_cache: Mutex<Cache>,
290}
291
292fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
293    static DFA_LIST: LazyLock<Mutex<HashMap<Patterns<'static>, &'static DFAs>>> =
294        LazyLock::new(Mutex::default);
295
296    let mut list = DFA_LIST.lock().unwrap();
297
298    let mut bytes = [0; 4];
299    let pat = pat.as_patterns(&mut bytes);
300
301    if let Some(dfas) = list.get(&pat) {
302        Ok(*dfas)
303    } else {
304        let pat = pat.leak();
305        let (fwd_dfa, rev_dfa) = pat.dfas()?;
306
307        let (fwd_cache, rev_cache) = (Cache::new(&fwd_dfa), Cache::new(&rev_dfa));
308        let dfas = Box::leak(Box::new(DFAs {
309            fwd_dfa,
310            fwd_cache: Mutex::new(fwd_cache),
311            rev_dfa,
312            rev_cache: Mutex::new(rev_cache),
313        }));
314        let _ = list.insert(pat, dfas);
315        Ok(dfas)
316    }
317}
318
/// One or many regex patterns, in a hashable form that is used as the
/// key of the global [`DFAs`] cache
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
enum Patterns<'a> {
    /// A single pattern.
    One(&'a str),
    /// A list of patterns, compiled together into the same automata.
    Many(&'a [&'a str]),
}
324
325impl Patterns<'_> {
326    fn leak(&self) -> Patterns<'static> {
327        match self {
328            Patterns::One(str) => Patterns::One(String::from(*str).leak()),
329            Patterns::Many(strs) => Patterns::Many(
330                strs.iter()
331                    .map(|s| {
332                        let str: &'static str = s.to_string().leak();
333                        str
334                    })
335                    .collect::<Vec<&'static str>>()
336                    .leak(),
337            ),
338        }
339    }
340
341    fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
342        let mut fwd_builder = DFA::builder();
343        fwd_builder.syntax(syntax::Config::new().multi_line(true));
344        let mut rev_builder = DFA::builder();
345        rev_builder
346            .syntax(syntax::Config::new().multi_line(true))
347            .thompson(thompson::Config::new().reverse(true));
348
349        match self {
350            Patterns::One(pat) => {
351                let pat = pat.replace("\\b", "(?-u:\\b)");
352                syntax::parse(&pat)?;
353                let fwd = fwd_builder.build(&pat).unwrap();
354                let rev = rev_builder.build(&pat).unwrap();
355                Ok((fwd, rev))
356            }
357            Patterns::Many(pats) => {
358                let pats: Vec<String> =
359                    pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
360                for pat in pats.iter() {
361                    regex_syntax::Parser::new().parse(pat)?;
362                }
363                let fwd = fwd_builder.build_many(&pats).unwrap();
364                let rev = rev_builder.build_many(&pats).unwrap();
365                Ok((fwd, rev))
366            }
367        }
368    }
369}
370
371/// A regex pattern to search for
372///
373/// It can either be a single `&str`, or a list of `&str`s, in which
374/// case the matched pattern will be specified.
375pub trait RegexPattern: InnerRegexPattern {
376    /// Eiter a [`Range<usize>`] or `(usize, Range<usize>)`
377    type Match;
378
379    /// transforms a matched pattern into [`RegexPattern::Match`]
380    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match;
381}
382
383impl RegexPattern for &str {
384    type Match = Range<usize>;
385
386    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
387        points
388    }
389}
390
391impl RegexPattern for String {
392    type Match = Range<usize>;
393
394    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
395        points
396    }
397}
398
399impl RegexPattern for &String {
400    type Match = Range<usize>;
401
402    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
403        points
404    }
405}
406
407impl RegexPattern for char {
408    type Match = Range<usize>;
409
410    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
411        points
412    }
413}
414
415impl<const N: usize> RegexPattern for [&str; N] {
416    type Match = (usize, Range<usize>);
417
418    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
419        (pattern.as_usize(), points)
420    }
421}
422
423impl RegexPattern for &[&str] {
424    type Match = (usize, Range<usize>);
425
426    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
427        (pattern.as_usize(), points)
428    }
429}
430
431trait InnerRegexPattern {
432    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
433}
434
435impl InnerRegexPattern for &str {
436    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
437        Patterns::One(self)
438    }
439}
440
441impl InnerRegexPattern for String {
442    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
443        Patterns::One(self)
444    }
445}
446
447impl InnerRegexPattern for &String {
448    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
449        Patterns::One(self)
450    }
451}
452
453impl InnerRegexPattern for char {
454    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
455        Patterns::One(self.encode_utf8(bytes) as &str)
456    }
457}
458
459impl<const N: usize> InnerRegexPattern for [&str; N] {
460    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
461        Patterns::Many(self)
462    }
463}
464
465impl InnerRegexPattern for &[&str] {
466    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
467        Patterns::Many(self)
468    }
469}
470
/// A cursor over a haystack made of two byte chunks
///
/// The first field holds the chunks, and the second one is the index
/// (`0` or `1`) of the chunk the cursor is currently on.
#[derive(Clone, Copy)]
struct SearchBytes<'a>([&'a [u8]; 2], usize);
473
474impl Cursor for SearchBytes<'_> {
475    fn chunk(&self) -> &[u8] {
476        self.0[self.1]
477    }
478
479    fn advance(&mut self) -> bool {
480        if self.1 == 0 {
481            self.1 += 1;
482            true
483        } else {
484            false
485        }
486    }
487
488    fn backtrack(&mut self) -> bool {
489        if self.1 == 1 {
490            self.1 -= 1;
491            true
492        } else {
493            false
494        }
495    }
496
497    fn total_bytes(&self) -> Option<usize> {
498        Some(self.0[0].len() + self.0[1].len())
499    }
500
501    fn offset(&self) -> usize {
502        match self.1 {
503            1 => self.0[0].len(),
504            _ => 0,
505        }
506    }
507}