duat_core/text/
search.rs

1//! Utilities for searching the [`Text`]
2//!
3//! This includes some methods for the [`Text`] itself, meant for
4//! general use when editing it. It also has the [`Searcher`] struct,
5//! which is used when doing [incremental search] in the
//! [`PromptLine`]. Its results are then handed to an [`IncSearcher`],
//! which can decide what to do with them.
8//!
9//! [incremental search]: https://docs.rs/duat/latest/duat/modes/struct.IncSearcher.html
10//! [`PromptLine`]: https://docs.rs/duat/latest/duat/widgets/struct.PromptLine.html
11//! [`IncSearcher`]: https://docs.rs/duat/latest/duat/modes/trait.IncSearcher.html
12//! [`Text`]: super::Text
13use std::{
14    collections::HashMap,
15    ops::{Range, RangeBounds},
16    sync::{LazyLock, RwLock, RwLockWriteGuard},
17};
18
19use regex_cursor::{
20    Cursor, Input,
21    engines::hybrid::{try_search_fwd, try_search_rev},
22    regex_automata::{
23        Anchored, PatternID,
24        hybrid::dfa::{Cache, DFA},
25        nfa::thompson,
26        util::syntax,
27    },
28};
29
30use super::{Bytes, TextRange};
31
32impl Bytes {
33    /// Searches forward for a [`RegexPattern`] in a [range]
34    ///
35    /// A [`RegexPattern`] can either be a single regex string, an
36    /// array of strings, or a slice of strings. When there are more
37    /// than one pattern, The return value will include which pattern
38    /// matched.
39    ///
40    /// The patterns will also automatically be cached, so you don't
41    /// need to do that.
42    ///
43    /// [range]: TextRange
44    pub fn search_fwd<R: RegexPattern>(
45        &self,
46        pat: R,
47        range: impl TextRange,
48    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
49        let range = range.to_range(self.len().byte());
50        let dfas = dfas_from_pat(pat)?;
51
52        let b_start = self.point_at_byte(range.start).byte();
53
54        let (mut fwd_input, mut rev_input) = get_inputs(self, range.clone());
55        rev_input.anchored(Anchored::Yes);
56
57        let mut fwd_cache = dfas.fwd.1.write().unwrap();
58        let mut rev_cache = dfas.rev.1.write().unwrap();
59
60        Ok(std::iter::from_fn(move || {
61            let init = fwd_input.start();
62            let h_end = loop {
63                if let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input)
64                {
65                    // Ignore empty matches at the start of the input.
66                    if half.offset() == init {
67                        fwd_input.set_start(init + 1);
68                    } else {
69                        break half.offset();
70                    }
71                } else {
72                    return None;
73                }
74            };
75
76            fwd_input.set_start(h_end);
77            rev_input.set_end(h_end);
78
79            let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input) else {
80                return None;
81            };
82            let h_start = half.offset();
83
84            Some(R::get_match(
85                b_start + h_start..b_start + h_end,
86                half.pattern(),
87            ))
88        }))
89    }
90
91    /// Searches in reverse for a [`RegexPattern`] in a [range]
92    ///
93    /// A [`RegexPattern`] can either be a single regex string, an
94    /// array of strings, or a slice of strings. When there are more
95    /// than one pattern, The return value will include which pattern
96    /// matched.
97    ///
98    /// The patterns will also automatically be cached, so you don't
99    /// need to do that.
100    ///
101    /// [range]: TextRange
102    pub fn search_rev<R: RegexPattern>(
103        &self,
104        pat: R,
105        range: impl TextRange,
106    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
107        let range = range.to_range(self.len().byte());
108        let dfas = dfas_from_pat(pat)?;
109
110        let (mut fwd_input, mut rev_input) = get_inputs(self, range.clone());
111        fwd_input.anchored(Anchored::Yes);
112
113        let mut fwd_cache = dfas.fwd.1.write().unwrap();
114        let mut rev_cache = dfas.rev.1.write().unwrap();
115
116        Ok(std::iter::from_fn(move || {
117            let init = rev_input.end();
118            let start = loop {
119                if let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input)
120                {
121                    // Ignore empty matches at the end of the input.
122                    if half.offset() == init {
123                        rev_input.set_end(init.checked_sub(1)?);
124                    } else {
125                        break half.offset();
126                    }
127                } else {
128                    return None;
129                }
130            };
131
132            rev_input.set_end(start);
133            fwd_input.set_start(start);
134
135            let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) else {
136                return None;
137            };
138            let end = half.offset();
139
140            Some(R::get_match(
141                range.start + start..range.start + end,
142                half.pattern(),
143            ))
144        }))
145    }
146
147    /// Returns true if the pattern is found in the given range
148    ///
149    /// This is unanchored by default, if you want an anchored search,
150    /// use the `"^$"` characters.
151    pub fn matches(
152        &self,
153        pat: impl RegexPattern,
154        range: impl TextRange,
155    ) -> Result<bool, Box<regex_syntax::Error>> {
156        let range = range.to_range(self.len().byte());
157        let dfas = dfas_from_pat(pat)?;
158
159        let (mut fwd_input, _) = get_inputs(self, range.clone());
160        fwd_input.anchored(Anchored::Yes);
161
162        let mut fwd_cache = dfas.fwd.1.write().unwrap();
163        if let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) {
164            Ok(hm.offset() == range.end)
165        } else {
166            Ok(false)
167        }
168    }
169}
170
/// A trait to match regexes on `&str`s
///
/// It is implemented for every type that implements [`AsRef<str>`].
/// The `range` arguments are byte ranges into the whole string, and
/// the returned [`Range`]s are relative to the start of the whole
/// string as well.
pub trait Matcheable: Sized {
    /// Returns a forward [`Iterator`] over matches of a given regex
    ///
    /// Each item is the byte range of a match, along with the matched
    /// `&str`.
    ///
    /// # Errors
    ///
    /// Returns an error if the pattern is not a valid regex.
    fn search_fwd(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>>;

    /// Returns a backwards [`Iterator`] over matches of a given regex
    ///
    /// Each item is the byte range of a match, along with the matched
    /// `&str`.
    ///
    /// # Errors
    ///
    /// Returns an error if the pattern is not a valid regex.
    fn search_rev(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>>;

    /// Checks if a type matches a [`RegexPattern`]
    ///
    /// The search is anchored at the start of `range`, and the match
    /// must end exactly at the end of `range`.
    ///
    /// # Errors
    ///
    /// Returns an error if the pattern is not a valid regex.
    fn reg_matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
194
195impl<S: AsRef<str>> Matcheable for S {
196    fn search_fwd(
197        &self,
198        pat: impl RegexPattern,
199        range: impl RangeBounds<usize> + Clone,
200    ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>> {
201        let (start, end) = crate::utils::get_ends(range, self.as_ref().len());
202        let str = &self.as_ref()[start..end];
203        let dfas = dfas_from_pat(pat)?;
204
205        let mut fwd_input = Input::new(str);
206        let mut rev_input = Input::new(str);
207        rev_input.anchored(Anchored::Yes);
208
209        let mut fwd_cache = dfas.fwd.1.write().unwrap();
210        let mut rev_cache = dfas.rev.1.write().unwrap();
211
212        Ok(std::iter::from_fn(move || {
213            let init = fwd_input.start();
214            let h_end = loop {
215                if let Ok(Some(half)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input)
216                {
217                    // Ignore empty matches at the start of the input.
218                    if half.offset() == init {
219                        fwd_input.set_start(init + 1);
220                    } else {
221                        break half.offset();
222                    }
223                } else {
224                    return None;
225                }
226            };
227
228            fwd_input.set_start(h_end);
229            rev_input.set_end(h_end);
230
231            let Ok(Some(hm)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input) else {
232                return None;
233            };
234            let h_start = hm.offset();
235
236            Some((start + h_start..start + h_end, &str[h_start..h_end]))
237        }))
238    }
239
240    fn search_rev(
241        &self,
242        pat: impl RegexPattern,
243        range: impl RangeBounds<usize> + Clone,
244    ) -> Result<impl Iterator<Item = (Range<usize>, &str)>, Box<regex_syntax::Error>> {
245        let (start, end) = crate::utils::get_ends(range, self.as_ref().len());
246        let str = &self.as_ref()[start..end];
247        let dfas = dfas_from_pat(pat)?;
248
249        let mut fwd_input = Input::new(str);
250        fwd_input.anchored(Anchored::Yes);
251        let mut rev_input = Input::new(str);
252
253        let mut fwd_cache = dfas.fwd.1.write().unwrap();
254        let mut rev_cache = dfas.rev.1.write().unwrap();
255
256        Ok(std::iter::from_fn(move || {
257            let init = rev_input.end();
258            let h_start = loop {
259                if let Ok(Some(half)) = try_search_rev(&dfas.rev.0, &mut rev_cache, &mut rev_input)
260                {
261                    // Ignore empty matches at the end of the input.
262                    if half.offset() == init {
263                        rev_input.set_end(init.checked_sub(1)?);
264                    } else {
265                        break half.offset();
266                    }
267                } else {
268                    return None;
269                }
270            };
271
272            rev_input.set_end(h_start);
273            fwd_input.set_start(h_start);
274
275            let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) else {
276                return None;
277            };
278            let h_end = hm.offset();
279
280            Some((start + h_start..start + h_end, &str[h_start..h_end]))
281        }))
282    }
283
284    fn reg_matches(
285        &self,
286        pat: impl RegexPattern,
287        range: impl RangeBounds<usize> + Clone,
288    ) -> Result<bool, Box<regex_syntax::Error>> {
289        let (start, end) = crate::utils::get_ends(range, self.as_ref().len());
290        let str = &self.as_ref()[start..end];
291        let dfas = dfas_from_pat(pat)?;
292
293        let mut fwd_input = Input::new(str);
294        fwd_input.anchored(Anchored::Yes);
295
296        let mut fwd_cache = dfas.fwd.1.write().unwrap();
297        if let Ok(Some(hm)) = try_search_fwd(&dfas.fwd.0, &mut fwd_cache, &mut fwd_input) {
298            Ok(start + hm.offset() == end)
299        } else {
300            Ok(false)
301        }
302    }
303}
304
/// A struct for incremental searching in [`IncSearch`]
///
/// A [`Searcher`] holds the write locks of its pattern's search caches
/// for as long as it is alive. Any other search for the same pattern
/// will block until it is dropped; on the same thread, it may deadlock
/// or panic instead (see [`RwLock::write`]).
///
/// [`IncSearch`]: https://docs.rs/duat/latest/duat/modes/struct.IncSearch.html
/// [`RwLock::write`]: std::sync::RwLock::write
pub struct Searcher {
    /// The pattern being searched for
    pat: String,
    /// DFA used to find where matches end
    fwd_dfa: &'static DFA,
    /// DFA built from the reversed pattern, used to find where matches start
    rev_dfa: &'static DFA,
    /// Exclusively held cache of `fwd_dfa`
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    /// Exclusively held cache of `rev_dfa`
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
315
316impl Searcher {
317    /// Returns a new [`Searcher`]
318    pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
319        let dfas = dfas_from_pat(&pat)?;
320        Ok(Self {
321            pat,
322            fwd_dfa: &dfas.fwd.0,
323            rev_dfa: &dfas.rev.0,
324            fwd_cache: dfas.fwd.1.write().unwrap(),
325            rev_cache: dfas.rev.1.write().unwrap(),
326        })
327    }
328
329    /// Searches forward for the required regex in a [range]
330    ///
331    /// [range]: TextRange
332    pub fn search_fwd<'b>(
333        &'b mut self,
334        ref_bytes: &'b impl AsRef<Bytes>,
335        range: impl TextRange,
336    ) -> impl Iterator<Item = Range<usize>> + 'b {
337        let bytes = ref_bytes.as_ref();
338        let range = range.to_range(bytes.len().byte());
339
340        let (mut fwd_input, mut rev_input) = get_inputs(bytes, range.clone());
341        rev_input.set_anchored(Anchored::Yes);
342
343        let fwd_dfa = &self.fwd_dfa;
344        let rev_dfa = &self.rev_dfa;
345        let fwd_cache = &mut self.fwd_cache;
346        let rev_cache = &mut self.rev_cache;
347
348        std::iter::from_fn(move || {
349            let init = fwd_input.start();
350            let h_end = loop {
351                if let Ok(Some(half)) = try_search_fwd(fwd_dfa, fwd_cache, &mut fwd_input) {
352                    // Ignore empty matches at the start of the input.
353                    if half.offset() == init {
354                        fwd_input.set_start(init + 1);
355                    } else {
356                        break half.offset();
357                    }
358                } else {
359                    return None;
360                }
361            };
362
363            fwd_input.set_start(h_end);
364            rev_input.set_end(h_end);
365
366            let h_start = unsafe {
367                try_search_rev(rev_dfa, rev_cache, &mut rev_input)
368                    .unwrap()
369                    .unwrap_unchecked()
370                    .offset()
371            };
372
373            Some(range.start + h_start..range.start + h_end)
374        })
375    }
376
377    /// Searches in reverse for the required regex in a range[range]
378    ///
379    /// [range]: TextRange
380    pub fn search_rev<'b>(
381        &'b mut self,
382        ref_bytes: &'b impl AsRef<Bytes>,
383        range: impl TextRange,
384    ) -> impl Iterator<Item = Range<usize>> + 'b {
385        let bytes = ref_bytes.as_ref();
386        let range = range.to_range(bytes.len().byte());
387
388        let (mut fwd_input, mut rev_input) = get_inputs(bytes, range.clone());
389        fwd_input.anchored(Anchored::Yes);
390
391        let fwd_dfa = &self.fwd_dfa;
392        let rev_dfa = &self.rev_dfa;
393        let fwd_cache = &mut self.fwd_cache;
394        let rev_cache = &mut self.rev_cache;
395        std::iter::from_fn(move || {
396            let init = rev_input.end();
397            let h_start = loop {
398                if let Ok(Some(half)) = try_search_rev(rev_dfa, rev_cache, &mut rev_input) {
399                    // Ignore empty matches at the end of the input.
400                    if half.offset() == init {
401                        rev_input.set_end(init - 1);
402                    } else {
403                        break half.offset();
404                    }
405                } else {
406                    return None;
407                }
408            };
409
410            fwd_input.set_start(h_start);
411            rev_input.set_end(h_start);
412
413            let h_end = unsafe {
414                try_search_fwd(fwd_dfa, fwd_cache, &mut fwd_input)
415                    .unwrap()
416                    .unwrap_unchecked()
417                    .offset()
418            };
419
420            Some(range.start + h_start..range.start + h_end)
421        })
422    }
423
424    /// Whether or not the regex matches a specific pattern
425    pub fn matches(&mut self, cursor: impl Cursor) -> bool {
426        let total_bytes = cursor.total_bytes();
427
428        let mut input = Input::new(cursor);
429        input.anchored(Anchored::Yes);
430
431        let Ok(Some(half)) = try_search_fwd(self.fwd_dfa, &mut self.fwd_cache, &mut input) else {
432            return false;
433        };
434
435        total_bytes.is_some_and(|len| len == half.offset())
436    }
437
438    /// Whether or not there even is a pattern to search for
439    pub fn is_empty(&self) -> bool {
440        self.pat.is_empty()
441    }
442}
443
/// The compiled lazy DFAs of a set of patterns, each paired with its
/// search cache
///
/// `fwd` is used to find where matches end. `rev` is built from the
/// reversed pattern and is used to find where they start.
struct DFAs {
    fwd: (DFA, RwLock<Cache>),
    rev: (DFA, RwLock<Cache>),
}
448
449fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
450    static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
451        LazyLock::new(RwLock::default);
452
453    let mut list = DFA_LIST.write().unwrap();
454
455    let mut bytes = [0; 4];
456    let pat = pat.as_patterns(&mut bytes);
457
458    if let Some(dfas) = list.get(&pat) {
459        Ok(*dfas)
460    } else {
461        let pat = pat.leak();
462        let (fwd, rev) = pat.dfas()?;
463
464        let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
465        let dfas = Box::leak(Box::new(DFAs {
466            fwd: (fwd, RwLock::new(fwd_cache)),
467            rev: (rev, RwLock::new(rev_cache)),
468        }));
469        let _ = list.insert(pat, dfas);
470        Ok(dfas)
471    }
472}
473
/// One or more regex patterns, as borrowed strings
///
/// This is the key of the global DFA cache, so searches with identical
/// patterns share the same compiled DFAs.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single pattern
    One(&'a str),
    /// Several patterns, compiled together into the same DFAs
    Many(&'a [&'a str]),
}
479
480impl Patterns<'_> {
481    fn leak(&self) -> Patterns<'static> {
482        match self {
483            Patterns::One(str) => Patterns::One(String::from(*str).leak()),
484            Patterns::Many(strs) => Patterns::Many(
485                strs.iter()
486                    .map(|s| {
487                        let str: &'static str = s.to_string().leak();
488                        str
489                    })
490                    .collect::<Vec<&'static str>>()
491                    .leak(),
492            ),
493        }
494    }
495
496    fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
497        let mut fwd_builder = DFA::builder();
498        fwd_builder.syntax(syntax::Config::new().multi_line(true));
499        let mut rev_builder = DFA::builder();
500        rev_builder
501            .syntax(syntax::Config::new().multi_line(true))
502            .thompson(thompson::Config::new().reverse(true));
503
504        match self {
505            Patterns::One(pat) => {
506                let pat = pat.replace("\\b", "(?-u:\\b)");
507                syntax::parse(&pat)?;
508                let fwd = fwd_builder.build(&pat).unwrap();
509                let rev = rev_builder.build(&pat).unwrap();
510                Ok((fwd, rev))
511            }
512            Patterns::Many(pats) => {
513                let pats: Vec<String> =
514                    pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
515                for pat in pats.iter() {
516                    regex_syntax::Parser::new().parse(pat)?;
517                }
518                let fwd = fwd_builder.build_many(&pats).unwrap();
519                let rev = rev_builder.build_many(&pats).unwrap();
520                Ok((fwd, rev))
521            }
522        }
523    }
524}
525
/// A regex pattern to search for
///
/// It can either be a single pattern (`&str`, `String`, `&String` or
/// `char`), or a list of `&str`s. For a list, each match also reports
/// which pattern matched.
pub trait RegexPattern: InnerRegexPattern {
    /// Either a [`Range<usize>`] or `(usize, Range<usize>)`
    ///
    /// In the tuple form, the `usize` is the index of the pattern that
    /// matched.
    type Match: 'static;

    /// Transforms a matched byte range, along with the id of the
    /// pattern that matched, into [`RegexPattern::Match`]
    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match;
}
537
538impl RegexPattern for &str {
539    type Match = Range<usize>;
540
541    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
542        points
543    }
544}
545
546impl RegexPattern for String {
547    type Match = Range<usize>;
548
549    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
550        points
551    }
552}
553
554impl RegexPattern for &String {
555    type Match = Range<usize>;
556
557    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
558        points
559    }
560}
561
562impl RegexPattern for char {
563    type Match = Range<usize>;
564
565    fn get_match(points: Range<usize>, _pattern: PatternID) -> Self::Match {
566        points
567    }
568}
569
570impl<const N: usize> RegexPattern for [&str; N] {
571    type Match = (usize, Range<usize>);
572
573    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
574        (pattern.as_usize(), points)
575    }
576}
577
578impl RegexPattern for &[&str] {
579    type Match = (usize, Range<usize>);
580
581    fn get_match(points: Range<usize>, pattern: PatternID) -> Self::Match {
582        (pattern.as_usize(), points)
583    }
584}
585
586trait InnerRegexPattern {
587    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
588}
589
590impl InnerRegexPattern for &str {
591    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
592        Patterns::One(self)
593    }
594}
595
596impl InnerRegexPattern for String {
597    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
598        Patterns::One(self)
599    }
600}
601
602impl InnerRegexPattern for &String {
603    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
604        Patterns::One(self)
605    }
606}
607
608impl InnerRegexPattern for char {
609    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
610        Patterns::One(self.encode_utf8(bytes) as &str)
611    }
612}
613
614impl<const N: usize> InnerRegexPattern for [&str; N] {
615    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
616        Patterns::Many(self)
617    }
618}
619
620impl InnerRegexPattern for &[&str] {
621    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
622        Patterns::Many(self)
623    }
624}
625
/// A [`Cursor`] over the two byte slices that make up a range of
/// [`Bytes`]
///
/// The first field holds the two slices. The second is the index of
/// the slice currently being read (`0` or `1`).
#[derive(Clone, Copy)]
struct SearchBytes<'a>([&'a [u8]; 2], usize);
628
629impl Cursor for SearchBytes<'_> {
630    fn chunk(&self) -> &[u8] {
631        self.0[self.1]
632    }
633
634    fn advance(&mut self) -> bool {
635        if self.1 == 0 {
636            self.1 += 1;
637            true
638        } else {
639            false
640        }
641    }
642
643    fn backtrack(&mut self) -> bool {
644        if self.1 == 1 {
645            self.1 -= 1;
646            true
647        } else {
648            false
649        }
650    }
651
652    fn total_bytes(&self) -> Option<usize> {
653        Some(self.0[0].len() + self.0[1].len())
654    }
655
656    fn offset(&self) -> usize {
657        match self.1 {
658            1 => self.0[0].len(),
659            _ => 0,
660        }
661    }
662}
663
664fn get_inputs(
665    bytes: &Bytes,
666    range: std::ops::Range<usize>,
667) -> (Input<SearchBytes<'_>>, Input<SearchBytes<'_>>) {
668    let haystack = SearchBytes(bytes.slices(range).to_array(), 0);
669    let fwd_input = Input::new(haystack);
670    let rev_input = Input::new(haystack);
671    (fwd_input, rev_input)
672}