duat_core/text/
search.rs

1//! Utilities for searching the [`Text`]
2//!
3//! This includes some methods for the [`Text`] itself, meant for
4//! general use when editing it. It also has the [`Searcher`] struct,
5//! which is used when doing [incremental search] in the [`CmdLine`].
//! That struct is then used by an [`IncSearcher`], which can decide
//! what to do with the results.
8//!
9//! [incremental search]: crate::widget::IncSearch
10//! [`CmdLine`]: crate::widget::CmdLine
11//! [`IncSearcher`]: crate::mode::IncSearcher
12use std::{
13    collections::HashMap,
14    ops::RangeBounds,
15    sync::{LazyLock, RwLock, RwLockWriteGuard},
16};
17
18use regex_automata::{
19    Anchored, Input, PatternID,
20    hybrid::dfa::{Cache, DFA},
21    nfa::thompson,
22    util::syntax,
23};
24
25use super::{Point, Text, TextRange};
26
27impl Text {
28    /// Searches forward for a [`RegexPattern`] in a [range]
29    ///
30    /// A [`RegexPattern`] can either be a single regex string, an
31    /// array of strings, or a slice of strings. When there are more
32    /// than one pattern, The return value will include which pattern
33    /// matched.
34    ///
35    /// The patterns will also automatically be cached, so you don't
36    /// need to do that.
37    ///
38    /// [range]: TextRange
39    pub fn search_fwd<R: RegexPattern>(
40        &mut self,
41        pat: R,
42        range: impl TextRange,
43    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
44        let bytes = self.bytes_mut();
45        let range = range.to_range(bytes.len().byte());
46        let dfas = dfas_from_pat(pat)?;
47        let haystack = {
48            bytes.make_contiguous(range.clone());
49            bytes.get_contiguous(range.clone()).unwrap()
50        };
51
52        let mut fwd_input = Input::new(haystack);
53        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
54        let mut fwd_cache = dfas.fwd.1.write().unwrap();
55        let mut rev_cache = dfas.rev.1.write().unwrap();
56
57        let bytes = bytes as &super::Bytes;
58        Ok(std::iter::from_fn(move || {
59            let init = fwd_input.start();
60            let h_end = loop {
61                if let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
62                    // Ignore empty matches at the start of the input.
63                    if half.offset() == init {
64                        fwd_input.set_start(init + 1);
65                    } else {
66                        break half.offset();
67                    }
68                } else {
69                    return None;
70                }
71            };
72
73            fwd_input.set_start(h_end);
74            rev_input.set_end(h_end);
75
76            let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) else {
77                return None;
78            };
79            let h_start = half.offset();
80
81            let p0 = bytes.point_at(h_start + range.start);
82            let p1 = bytes.point_at(h_end + range.start);
83
84            Some(R::get_match([p0, p1], half.pattern()))
85        }))
86    }
87
88    /// Searches in reverse for a [`RegexPattern`] in a [range]
89    ///
90    /// A [`RegexPattern`] can either be a single regex string, an
91    /// array of strings, or a slice of strings. When there are more
92    /// than one pattern, The return value will include which pattern
93    /// matched.
94    ///
95    /// The patterns will also automatically be cached, so you don't
96    /// need to do that.
97    ///
98    /// [range]: TextRange
99    pub fn search_rev<R: RegexPattern>(
100        &mut self,
101        pat: R,
102        range: impl TextRange,
103    ) -> Result<impl Iterator<Item = R::Match> + '_, Box<regex_syntax::Error>> {
104        let bytes = self.bytes_mut();
105        let range = range.to_range(bytes.len().byte());
106        let dfas = dfas_from_pat(pat)?;
107        let haystack = {
108            bytes.make_contiguous(range.clone());
109            bytes.get_contiguous(range.clone()).unwrap()
110        };
111
112        let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
113        let mut rev_input = Input::new(haystack);
114        let mut fwd_cache = dfas.fwd.1.write().unwrap();
115        let mut rev_cache = dfas.rev.1.write().unwrap();
116
117        let bytes = bytes as &super::Bytes;
118        let gap = range.start;
119        Ok(std::iter::from_fn(move || {
120            let init = rev_input.end();
121            let start = loop {
122                if let Ok(Some(half)) = dfas.rev.0.try_search_rev(&mut rev_cache, &rev_input) {
123                    // Ignore empty matches at the end of the input.
124                    if half.offset() == init {
125                        rev_input.set_end(init.checked_sub(1)?);
126                    } else {
127                        break half.offset();
128                    }
129                } else {
130                    return None;
131                }
132            };
133
134            rev_input.set_end(start);
135            fwd_input.set_start(start);
136
137            let Ok(Some(half)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) else {
138                return None;
139            };
140            let end = half.offset();
141
142            let p0 = bytes.point_at(start + gap);
143            let p1 = bytes.point_at(end + gap);
144
145            Some(R::get_match([p0, p1], half.pattern()))
146        }))
147    }
148
149    /// Returns true if the pattern is found in the given range
150    ///
151    /// This is unanchored by default, if you want an anchored search,
152    /// use the `"^$"` characters.
153    pub fn matches(
154        &mut self,
155        pat: impl RegexPattern,
156        range: impl TextRange,
157    ) -> Result<bool, Box<regex_syntax::Error>> {
158        let range = range.to_range(self.len().byte());
159        let dfas = dfas_from_pat(pat)?;
160
161        let haystack = self.contiguous(range);
162        let fwd_input = Input::new(haystack);
163
164        let mut fwd_cache = dfas.fwd.1.write().unwrap();
165        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
166            Ok(true)
167        } else {
168            Ok(false)
169        }
170    }
171}
172
/// A trait to match regexes on `&str`s
///
/// The implementations in this file search only the bytes of the
/// string that fall inside of the given range.
pub trait Matcheable: Sized {
    /// Checks if a type matches a [`RegexPattern`]
    ///
    /// Returns `Ok(true)` if `pat` is found inside of `range`, and an
    /// error if `pat` fails to parse.
    fn matches(
        &self,
        pat: impl RegexPattern,
        range: impl RangeBounds<usize> + Clone,
    ) -> Result<bool, Box<regex_syntax::Error>>;
}
182
183impl<const N: usize> Matcheable for std::array::IntoIter<&str, N> {
184    fn matches(
185        &self,
186        pat: impl RegexPattern,
187        range: impl RangeBounds<usize> + Clone,
188    ) -> Result<bool, Box<regex_syntax::Error>> {
189        let str: String = self.as_slice().iter().copied().collect();
190        let (start, end) = crate::get_ends(range, str.len());
191        let dfas = dfas_from_pat(pat)?;
192        let fwd_input =
193            Input::new(unsafe { std::str::from_utf8_unchecked(&str.as_bytes()[start..end]) });
194
195        let mut fwd_cache = dfas.fwd.1.write().unwrap();
196        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
197            Ok(true)
198        } else {
199            Ok(false)
200        }
201    }
202}
203
204impl Matcheable for &'_ str {
205    fn matches(
206        &self,
207        pat: impl RegexPattern,
208        range: impl RangeBounds<usize> + Clone,
209    ) -> Result<bool, Box<regex_syntax::Error>> {
210        let (start, end) = crate::get_ends(range, self.len());
211        let dfas = dfas_from_pat(pat)?;
212        let fwd_input =
213            Input::new(unsafe { std::str::from_utf8_unchecked(&self.as_bytes()[start..end]) });
214
215        let mut fwd_cache = dfas.fwd.1.write().unwrap();
216        if let Ok(Some(_)) = dfas.fwd.0.try_search_fwd(&mut fwd_cache, &fwd_input) {
217            Ok(true)
218        } else {
219            Ok(false)
220        }
221    }
222}
223
/// A struct for incremental searching in [`IncSearch`]
///
/// It refers to the forward and reverse [`DFA`]s of a pattern, and
/// holds write guards over their [`Cache`]s for as long as it lives.
///
/// [`IncSearch`]: https://docs.rs/duat-utils/latest/duat_utils/modes/struct.IncSearch.html
pub struct Searcher {
    // The pattern, as it was given to `Searcher::new`.
    pat: String,
    // Automaton for forward searches.
    fwd_dfa: &'static DFA,
    // Automaton for reverse searches.
    rev_dfa: &'static DFA,
    // Write guard over the cache of the forward automaton.
    fwd_cache: RwLockWriteGuard<'static, Cache>,
    // Write guard over the cache of the reverse automaton.
    rev_cache: RwLockWriteGuard<'static, Cache>,
}
234
235impl Searcher {
236    /// Returns a new [`Searcher`]
237    pub fn new(pat: String) -> Result<Self, Box<regex_syntax::Error>> {
238        let dfas = dfas_from_pat(&pat)?;
239        Ok(Self {
240            pat,
241            fwd_dfa: &dfas.fwd.0,
242            rev_dfa: &dfas.rev.0,
243            fwd_cache: dfas.fwd.1.write().unwrap(),
244            rev_cache: dfas.rev.1.write().unwrap(),
245        })
246    }
247
248    /// Searches forward for the required regex in a [range]
249    ///
250    /// [range]: TextRange
251    pub fn search_fwd<'b>(
252        &'b mut self,
253        text: &'b mut Text,
254        range: impl TextRange,
255    ) -> impl Iterator<Item = [Point; 2]> + 'b {
256        let range = range.to_range(text.len().byte());
257        let mut last_point = text.point_at(range.start);
258
259        let haystack = text.contiguous(range.clone());
260        let mut fwd_input = Input::new(haystack).anchored(Anchored::No);
261        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
262
263        let fwd_dfa = &self.fwd_dfa;
264        let rev_dfa = &self.rev_dfa;
265        let fwd_cache = &mut self.fwd_cache;
266        let rev_cache = &mut self.rev_cache;
267        let gap = range.start;
268        std::iter::from_fn(move || {
269            let init = fwd_input.start();
270            let end = loop {
271                if let Ok(Some(half)) = fwd_dfa.try_search_fwd(fwd_cache, &fwd_input) {
272                    // Ignore empty matches at the start of the input.
273                    if half.offset() == init {
274                        fwd_input.set_start(init + 1);
275                    } else {
276                        break half.offset();
277                    }
278                } else {
279                    return None;
280                }
281            };
282
283            fwd_input.set_start(end);
284            rev_input.set_end(end);
285
286            let half = unsafe {
287                rev_dfa
288                    .try_search_rev(rev_cache, &rev_input)
289                    .unwrap()
290                    .unwrap_unchecked()
291            };
292            let start = half.offset();
293
294            let start = unsafe {
295                std::str::from_utf8_unchecked(&haystack.as_bytes()[last_point.byte() - gap..start])
296            }
297            .chars()
298            .fold(last_point, |p, b| p.fwd(b));
299            let end = unsafe {
300                std::str::from_utf8_unchecked(&haystack.as_bytes()[start.byte() - gap..end])
301            }
302            .chars()
303            .fold(start, |p, b| p.fwd(b));
304
305            last_point = end;
306
307            Some([start, end])
308        })
309    }
310
311    /// Searches in reverse for the required regex in a range[range]
312    ///
313    /// [range]: TextRange
314    pub fn search_rev<'b>(
315        &'b mut self,
316        text: &'b mut Text,
317        range: impl TextRange,
318    ) -> impl Iterator<Item = [Point; 2]> + 'b {
319        let range = range.to_range(text.len().byte());
320        let mut last_point = text.point_at(range.end);
321
322        let haystack = text.contiguous(range.clone());
323        let mut fwd_input = Input::new(haystack).anchored(Anchored::Yes);
324        let mut rev_input = Input::new(haystack).anchored(Anchored::Yes);
325
326        let fwd_dfa = &self.fwd_dfa;
327        let rev_dfa = &self.rev_dfa;
328        let fwd_cache = &mut self.fwd_cache;
329        let rev_cache = &mut self.rev_cache;
330        let gap = range.start;
331        std::iter::from_fn(move || {
332            let init = rev_input.end();
333            let start = loop {
334                if let Ok(Some(half)) = rev_dfa.try_search_rev(rev_cache, &rev_input) {
335                    // Ignore empty matches at the end of the input.
336                    if half.offset() == init {
337                        rev_input.set_end(init - 1);
338                    } else {
339                        break half.offset();
340                    }
341                } else {
342                    return None;
343                }
344            };
345
346            fwd_input.set_start(start);
347            rev_input.set_end(start);
348
349            let half = fwd_dfa
350                .try_search_fwd(fwd_cache, &fwd_input)
351                .unwrap()
352                .unwrap();
353
354            let end = unsafe {
355                std::str::from_utf8_unchecked(
356                    &haystack.as_bytes()[half.offset()..(last_point.byte() - gap)],
357                )
358            }
359            .chars()
360            .fold(last_point, |p, b| p.rev(b));
361            let start = unsafe {
362                std::str::from_utf8_unchecked(&haystack.as_bytes()[start..(end.byte() - gap)])
363            }
364            .chars()
365            .fold(end, |p, b| p.rev(b));
366
367            last_point = start;
368
369            Some([start, end])
370        })
371    }
372
373    /// Whether or not the regex matches a specific pattern
374    pub fn matches(&mut self, query: impl AsRef<[u8]>) -> bool {
375        let input = Input::new(&query).anchored(Anchored::Yes);
376
377        let Ok(Some(half)) = self.fwd_dfa.try_search_fwd(&mut self.fwd_cache, &input) else {
378            return false;
379        };
380
381        half.offset() == query.as_ref().len()
382    }
383
384    /// Whether or not there even is a pattern to search for
385    pub fn is_empty(&self) -> bool {
386        self.pat.is_empty()
387    }
388}
389
/// The compiled automata of a pattern
///
/// Each [`DFA`] is paired with the [`Cache`] that searches with it
/// need, behind a lock.
struct DFAs {
    /// The [`DFA`] for forward searches, and its [`Cache`]
    fwd: (DFA, RwLock<Cache>),
    /// The [`DFA`] for reverse searches, and its [`Cache`]
    rev: (DFA, RwLock<Cache>),
}
394
395fn dfas_from_pat(pat: impl RegexPattern) -> Result<&'static DFAs, Box<regex_syntax::Error>> {
396    static DFA_LIST: LazyLock<RwLock<HashMap<Patterns<'static>, &'static DFAs>>> =
397        LazyLock::new(RwLock::default);
398
399    let mut list = DFA_LIST.write().unwrap();
400
401    let mut bytes = [0; 4];
402    let pat = pat.as_patterns(&mut bytes);
403
404    if let Some(dfas) = list.get(&pat) {
405        Ok(*dfas)
406    } else {
407        let pat = pat.leak();
408        let (fwd, rev) = pat.dfas()?;
409
410        let (fwd_cache, rev_cache) = (Cache::new(&fwd), Cache::new(&rev));
411        let dfas = Box::leak(Box::new(DFAs {
412            fwd: (fwd, RwLock::new(fwd_cache)),
413            rev: (rev, RwLock::new(rev_cache)),
414        }));
415        let _ = list.insert(pat, dfas);
416        Ok(dfas)
417    }
418}
419
/// The pattern or patterns behind a [`RegexPattern`]
///
/// This is the key by which compiled [`DFAs`] are cached.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Patterns<'a> {
    /// A single pattern
    One(&'a str),
    /// A list of patterns
    Many(&'a [&'a str]),
}
425
426impl Patterns<'_> {
427    fn leak(&self) -> Patterns<'static> {
428        match self {
429            Patterns::One(str) => Patterns::One(String::from(*str).leak()),
430            Patterns::Many(strs) => Patterns::Many(
431                strs.iter()
432                    .map(|s| {
433                        let str: &'static str = s.to_string().leak();
434                        str
435                    })
436                    .collect::<Vec<&'static str>>()
437                    .leak(),
438            ),
439        }
440    }
441
442    fn dfas(&self) -> Result<(DFA, DFA), Box<regex_syntax::Error>> {
443        let mut fwd_builder = DFA::builder();
444        fwd_builder.syntax(syntax::Config::new().multi_line(true));
445        let mut rev_builder = DFA::builder();
446        rev_builder
447            .syntax(syntax::Config::new().multi_line(true))
448            .thompson(thompson::Config::new().reverse(true));
449
450        match self {
451            Patterns::One(pat) => {
452                let pat = pat.replace("\\b", "(?-u:\\b)");
453                regex_syntax::Parser::new().parse(&pat)?;
454                let fwd = fwd_builder.build(&pat).unwrap();
455                let rev = rev_builder.build(&pat).unwrap();
456                Ok((fwd, rev))
457            }
458            Patterns::Many(pats) => {
459                let pats: Vec<String> =
460                    pats.iter().map(|p| p.replace("\\b", "(?-u:\\b)")).collect();
461                for pat in pats.iter() {
462                    regex_syntax::Parser::new().parse(pat)?;
463                }
464                let fwd = fwd_builder.build_many(&pats).unwrap();
465                let rev = rev_builder.build_many(&pats).unwrap();
466                Ok((fwd, rev))
467            }
468        }
469    }
470}
471
/// A regex pattern to search for
///
/// It can either be a single `&str`, or a list of `&str`s, in which
/// case the matched pattern will be specified.
pub trait RegexPattern: InnerRegexPattern {
    /// Either two [`Point`]s, or two [`Point`]s and a match index
    type Match: 'static;

    /// Transforms a matched pattern into [`RegexPattern::Match`]
    ///
    /// `points` are the two ends of the match, and `pattern` is the
    /// id of the pattern that was matched.
    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match;
}
483
484impl RegexPattern for &str {
485    type Match = [Point; 2];
486
487    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
488        points
489    }
490}
491
492impl RegexPattern for String {
493    type Match = [Point; 2];
494
495    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
496        points
497    }
498}
499
500impl RegexPattern for &String {
501    type Match = [Point; 2];
502
503    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
504        points
505    }
506}
507
508impl RegexPattern for char {
509    type Match = [Point; 2];
510
511    fn get_match(points: [Point; 2], _pattern: PatternID) -> Self::Match {
512        points
513    }
514}
515
516impl<const N: usize> RegexPattern for [&str; N] {
517    type Match = (usize, [Point; 2]);
518
519    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
520        (pattern.as_usize(), points)
521    }
522}
523
524impl RegexPattern for &[&str] {
525    type Match = (usize, [Point; 2]);
526
527    fn get_match(points: [Point; 2], pattern: PatternID) -> Self::Match {
528        (pattern.as_usize(), points)
529    }
530}
531
/// The private side of [`RegexPattern`], which exposes a type as
/// [`Patterns`]
trait InnerRegexPattern {
    /// Borrows `self` as [`Patterns`]
    ///
    /// `bytes` is scratch space, used by the `char` implementation to
    /// hold its UTF-8 encoding. Other implementations ignore it.
    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b>;
}
535
536impl InnerRegexPattern for &str {
537    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
538        Patterns::One(self)
539    }
540}
541
542impl InnerRegexPattern for String {
543    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
544        Patterns::One(self)
545    }
546}
547
548impl InnerRegexPattern for &String {
549    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
550        Patterns::One(self)
551    }
552}
553
554impl InnerRegexPattern for char {
555    fn as_patterns<'b>(&'b self, bytes: &'b mut [u8; 4]) -> Patterns<'b> {
556        Patterns::One(self.encode_utf8(bytes) as &str)
557    }
558}
559
560impl<const N: usize> InnerRegexPattern for [&str; N] {
561    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
562        Patterns::Many(self)
563    }
564}
565
566impl InnerRegexPattern for &[&str] {
567    fn as_patterns<'b>(&'b self, _bytes: &'b mut [u8; 4]) -> Patterns<'b> {
568        Patterns::Many(self)
569    }
570}