cstree/syntax/
text.rs

1//! Efficient representation of the source text that is covered by a [`SyntaxNode`].
2
3use std::fmt;
4
5use crate::{
6    Syntax,
7    interning::{Resolver, TokenKey},
8    syntax::{SyntaxNode, SyntaxToken},
9    text::{TextRange, TextSize},
10};
11
12/// An efficient representation of the text that is covered by a [`SyntaxNode`], i.e. the combined
13/// source text of all tokens that are descendants of the node.
14///
15/// Offers methods to work with the text distributed across multiple [`SyntaxToken`]s while avoiding
16/// the construction of intermediate strings where possible.
17/// This includes efficient comparisons with itself and with strings and conversion `to_string()`.
18///
19/// # Example
20/// ```
21/// # use cstree::testing::*;
22/// # use cstree::syntax::ResolvedNode;
23/// #
24/// fn parse_float_literal(s: &str) -> ResolvedNode<MySyntax> {
25///     // parsing...
26/// #     let mut builder: GreenNodeBuilder<MySyntax> = GreenNodeBuilder::new();
27/// #     builder.start_node(Float);
28/// #     builder.token(Float, s);
29/// #     builder.finish_node();
30/// #     let (root, cache) = builder.finish();
31/// #     let resolver = cache.unwrap().into_interner().unwrap();
32/// #     SyntaxNode::new_root_with_resolver(root, resolver)
33/// }
34/// let float_node = parse_float_literal("2.748E2");
35/// let text = float_node.text();
36/// assert_eq!(text.len(), 7.into());
37/// assert!(text.contains_char('E'));
38/// assert_eq!(text.find_char('E'), Some(5.into()));
39/// assert_eq!(text.char_at(1.into()), Some('.'));
40/// let sub = text.slice(2.into()..5.into());
41/// assert_eq!(sub, "748");
42/// ```
43pub struct SyntaxText<'n, 'i, I: ?Sized, S: Syntax, D: 'static = ()> {
44    node:     &'n SyntaxNode<S, D>,
45    range:    TextRange,
46    resolver: &'i I,
47}
48
49impl<I: ?Sized, S: Syntax, D> Clone for SyntaxText<'_, '_, I, S, D> {
50    fn clone(&self) -> Self {
51        *self
52    }
53}
54
55impl<I: ?Sized, S: Syntax, D> Copy for SyntaxText<'_, '_, I, S, D> {}
56
57impl<'n, 'i, I: Resolver<TokenKey> + ?Sized, S: Syntax, D> SyntaxText<'n, 'i, I, S, D> {
58    pub(crate) fn new(node: &'n SyntaxNode<S, D>, resolver: &'i I) -> Self {
59        let range = node.text_range();
60        SyntaxText { node, range, resolver }
61    }
62
63    /// The combined length of this text, in bytes.
64    pub fn len(&self) -> TextSize {
65        self.range.len()
66    }
67
68    /// Returns `true` if [`self.len()`](SyntaxText::len) is zero.
69    pub fn is_empty(&self) -> bool {
70        self.range.is_empty()
71    }
72
73    /// Returns `true` if `c` appears anywhere in this text.
74    pub fn contains_char(&self, c: char) -> bool {
75        self.try_for_each_chunk(|chunk| if chunk.contains(c) { Err(()) } else { Ok(()) })
76            .is_err()
77    }
78
79    /// If `self.contains_char(c)`, returns `Some(pos)`, where `pos` is the byte position of the
80    /// first appearance of `c`. Otherwise, returns `None`.
81    pub fn find_char(&self, c: char) -> Option<TextSize> {
82        let mut acc: TextSize = 0.into();
83        let res = self.try_for_each_chunk(|chunk| {
84            if let Some(pos) = chunk.find(c) {
85                let pos: TextSize = (pos as u32).into();
86                return Err(acc + pos);
87            }
88            acc += TextSize::of(chunk);
89            Ok(())
90        });
91        found(res)
92    }
93
94    /// If `offset < self.len()`, returns `Some(c)`, where `c` is the first `char` at or after
95    /// `offset` (in bytes). Otherwise, returns `None`.
96    pub fn char_at(&self, offset: TextSize) -> Option<char> {
97        let mut start: TextSize = 0.into();
98        let res = self.try_for_each_chunk(|chunk| {
99            let end = start + TextSize::of(chunk);
100            if start <= offset && offset < end {
101                let off: usize = u32::from(offset - start) as usize;
102                return Err(chunk[off..].chars().next().unwrap());
103            }
104            start = end;
105            Ok(())
106        });
107        found(res)
108    }
109
110    /// Indexes this text by the given `range` and returns a `SyntaxText` that represents the
111    /// corresponding slice of this text.
112    ///
113    /// # Panics
114    /// The end of `range` must be equal of higher than its start.
115    /// Further, `range` must be contained within `0..self.len()`.
116    pub fn slice<Ra: private::SyntaxTextRange>(&self, range: Ra) -> Self {
117        let start = range.start().unwrap_or_default();
118        let end = range.end().unwrap_or_else(|| self.len());
119        assert!(start <= end);
120        let len = end - start;
121        let start = self.range.start() + start;
122        let end = start + len;
123        assert!(
124            start <= end,
125            "invalid slice, range: {:?}, slice: {:?}",
126            self.range,
127            (range.start(), range.end()),
128        );
129        let range = TextRange::new(start, end);
130        assert!(
131            self.range.contains_range(range),
132            "invalid slice, range: {:?}, slice: {:?}",
133            self.range,
134            range,
135        );
136        SyntaxText {
137            node: self.node,
138            range,
139            resolver: self.resolver,
140        }
141    }
142
143    /// Applies the given function to text chunks (from [`SyntaxToken`]s) that are part of this text
144    /// as long as it returns `Ok`, starting from the initial value `init`.
145    ///
146    /// If `f` returns `Err`, the error is propagated immediately.
147    /// Otherwise, the result of the current call to `f` will be passed to the invocation of `f` on
148    /// the next token, producing a final value if `f` succeeds on all chunks.
149    ///
150    /// See also [`fold_chunks`](SyntaxText::fold_chunks) for folds that always succeed.
151    pub fn try_fold_chunks<T, F, E>(&self, init: T, mut f: F) -> Result<T, E>
152    where
153        F: FnMut(T, &'i str) -> Result<T, E>,
154    {
155        self.tokens_with_ranges().try_fold(init, move |acc, (token, range)| {
156            f(acc, &token.resolve_text(self.resolver)[range])
157        })
158    }
159
160    /// Applies the given function to all text chunks (from [`SyntaxToken`]s) that are part of this
161    /// text, starting from the initial value `init`.
162    ///
163    /// The result of the current call to `f` will be passed to the invocation of `f` on the next
164    /// token, producing a final value after `f` was called on all chunks.
165    ///
166    /// See also [`try_fold_chunks`](SyntaxText::try_fold_chunks), which performs the same operation
167    /// for fallible functions `f`.
168    pub fn fold_chunks<T, F>(&self, init: T, mut f: F) -> T
169    where
170        F: FnMut(T, &str) -> T,
171    {
172        enum Void {}
173        match self.try_fold_chunks(init, |acc, chunk| Ok::<T, Void>(f(acc, chunk))) {
174            Ok(t) => t,
175            Err(void) => match void {},
176        }
177    }
178
179    /// Applies the given function to all text chunks that this text is comprised of, in order,
180    /// as long as `f` completes successfully.
181    ///
182    /// If `f` returns `Err`, this method returns immediately and will not apply `f` to any further
183    /// chunks.
184    ///
185    /// See also [`try_fold_chunks`](SyntaxText::try_fold_chunks).
186    pub fn try_for_each_chunk<F: FnMut(&str) -> Result<(), E>, E>(&self, mut f: F) -> Result<(), E> {
187        self.try_fold_chunks((), move |(), chunk| f(chunk))
188    }
189
190    /// Applies the given function to all text chunks that this text is comprised of, in order.
191    ///
192    /// See also [`fold_chunks`](SyntaxText::fold_chunks),
193    /// [`try_for_each_chunk`](SyntaxText::try_for_each_chunk).
194    pub fn for_each_chunk<F: FnMut(&str)>(&self, mut f: F) {
195        self.fold_chunks((), |(), chunk| f(chunk))
196    }
197
198    fn tokens_with_ranges(&self) -> impl Iterator<Item = (&'n SyntaxToken<S, D>, TextRange)> + use<'i, 'n, I, S, D> {
199        let text_range = self.range;
200        self.node
201            .descendants_with_tokens()
202            .filter_map(|element| element.into_token())
203            .filter_map(move |token| {
204                let token_range = token.text_range();
205                let range = text_range.intersect(token_range)?;
206                Some((token, range - token_range.start()))
207            })
208    }
209}
210
/// Turns the outcome of a short-circuiting chunk traversal into the value that stopped it:
/// `Err(value)` becomes `Some(value)`, while a traversal that ran to completion (`Ok(())`) becomes
/// `None`.
#[inline]
fn found<T>(res: Result<(), T>) -> Option<T> {
    match res {
        Ok(()) => None,
        Err(value) => Some(value),
    }
}
215
216impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> fmt::Debug for SyntaxText<'_, '_, I, S, D> {
217    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
218        fmt::Debug::fmt(&self.to_string(), f)
219    }
220}
221
222impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> fmt::Display for SyntaxText<'_, '_, I, S, D> {
223    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
224        self.try_for_each_chunk(|chunk| fmt::Display::fmt(chunk, f))
225    }
226}
227
228impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> From<SyntaxText<'_, '_, I, S, D>> for String {
229    fn from(text: SyntaxText<'_, '_, I, S, D>) -> String {
230        text.to_string()
231    }
232}
233
234impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> PartialEq<str> for SyntaxText<'_, '_, I, S, D> {
235    fn eq(&self, mut rhs: &str) -> bool {
236        self.try_for_each_chunk(|chunk| {
237            if !rhs.starts_with(chunk) {
238                return Err(());
239            }
240            rhs = &rhs[chunk.len()..];
241            Ok(())
242        })
243        .is_ok()
244            && rhs.is_empty()
245    }
246}
247
248impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> PartialEq<SyntaxText<'_, '_, I, S, D>> for str {
249    fn eq(&self, rhs: &SyntaxText<'_, '_, I, S, D>) -> bool {
250        rhs == self
251    }
252}
253
254impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> PartialEq<&'_ str> for SyntaxText<'_, '_, I, S, D> {
255    fn eq(&self, rhs: &&str) -> bool {
256        self == *rhs
257    }
258}
259
260impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> PartialEq<SyntaxText<'_, '_, I, S, D>> for &'_ str {
261    fn eq(&self, rhs: &SyntaxText<'_, '_, I, S, D>) -> bool {
262        rhs == self
263    }
264}
265
266impl<I1, I2, S1, S2, D1, D2> PartialEq<SyntaxText<'_, '_, I2, S2, D2>> for SyntaxText<'_, '_, I1, S1, D1>
267where
268    S1: Syntax,
269    S2: Syntax,
270    I1: Resolver<TokenKey> + ?Sized,
271    I2: Resolver<TokenKey> + ?Sized,
272{
273    fn eq(&self, other: &SyntaxText<'_, '_, I2, S2, D2>) -> bool {
274        if self.range.len() != other.range.len() {
275            return false;
276        }
277        let mut lhs = self.tokens_with_ranges();
278        let mut rhs = other.tokens_with_ranges();
279        zip_texts(&mut lhs, &mut rhs, self.resolver, other.resolver).is_none()
280            && lhs.all(|it| it.1.is_empty())
281            && rhs.all(|it| it.1.is_empty())
282    }
283}
284
285fn zip_texts<'it1, 'it2, It1, It2, I1, I2, S1, S2, D1, D2>(
286    xs: &mut It1,
287    ys: &mut It2,
288    resolver_x: &I1,
289    resolver_y: &I2,
290) -> Option<()>
291where
292    It1: Iterator<Item = (&'it1 SyntaxToken<S1, D1>, TextRange)>,
293    It2: Iterator<Item = (&'it2 SyntaxToken<S2, D2>, TextRange)>,
294    I1: Resolver<TokenKey> + ?Sized,
295    I2: Resolver<TokenKey> + ?Sized,
296    D1: 'static,
297    D2: 'static,
298    S1: Syntax + 'it1,
299    S2: Syntax + 'it2,
300{
301    let mut x = xs.next()?;
302    let mut y = ys.next()?;
303    loop {
304        while x.1.is_empty() {
305            x = xs.next()?;
306        }
307        while y.1.is_empty() {
308            y = ys.next()?;
309        }
310        let x_text = &x.0.resolve_text(resolver_x)[x.1];
311        let y_text = &y.0.resolve_text(resolver_y)[y.1];
312        if !(x_text.starts_with(y_text) || y_text.starts_with(x_text)) {
313            return Some(());
314        }
315        let advance = std::cmp::min(x.1.len(), y.1.len());
316        x.1 = TextRange::new(x.1.start() + advance, x.1.end());
317        y.1 = TextRange::new(y.1.start() + advance, y.1.end());
318    }
319}
320
321impl<I: Resolver<TokenKey> + ?Sized, S: Syntax, D> Eq for SyntaxText<'_, '_, I, S, D> {}
322
323mod private {
324    use std::ops;
325
326    use crate::text::{TextRange, TextSize};
327
328    pub trait SyntaxTextRange {
329        fn start(&self) -> Option<TextSize>;
330        fn end(&self) -> Option<TextSize>;
331    }
332
333    impl SyntaxTextRange for TextRange {
334        fn start(&self) -> Option<TextSize> {
335            Some(TextRange::start(*self))
336        }
337
338        fn end(&self) -> Option<TextSize> {
339            Some(TextRange::end(*self))
340        }
341    }
342
343    impl SyntaxTextRange for ops::Range<TextSize> {
344        fn start(&self) -> Option<TextSize> {
345            Some(self.start)
346        }
347
348        fn end(&self) -> Option<TextSize> {
349            Some(self.end)
350        }
351    }
352
353    impl SyntaxTextRange for ops::RangeFrom<TextSize> {
354        fn start(&self) -> Option<TextSize> {
355            Some(self.start)
356        }
357
358        fn end(&self) -> Option<TextSize> {
359            None
360        }
361    }
362
363    impl SyntaxTextRange for ops::RangeTo<TextSize> {
364        fn start(&self) -> Option<TextSize> {
365            None
366        }
367
368        fn end(&self) -> Option<TextSize> {
369            Some(self.end)
370        }
371    }
372
373    impl SyntaxTextRange for ops::RangeFull {
374        fn start(&self) -> Option<TextSize> {
375            None
376        }
377
378        fn end(&self) -> Option<TextSize> {
379            None
380        }
381    }
382}
383
#[cfg(test)]
mod tests {
    use crate::{RawSyntaxKind, build::GreenNodeBuilder, interning::TokenInterner};

    use super::*;

    /// Minimal syntax kind for these tests: kind `1` is `{`, kind `2` is `}`, all other kinds
    /// carry dynamic text.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    #[repr(transparent)]
    pub struct SyntaxKind(u32);

    impl Syntax for SyntaxKind {
        fn from_raw(raw: RawSyntaxKind) -> Self {
            let RawSyntaxKind(kind) = raw;
            Self(kind)
        }

        fn into_raw(self) -> RawSyntaxKind {
            let Self(kind) = self;
            RawSyntaxKind(kind)
        }

        fn static_text(self) -> Option<&'static str> {
            let Self(kind) = self;
            match kind {
                1 => Some("{"),
                2 => Some("}"),
                _ => None,
            }
        }
    }

    /// Builds a root node of kind `62` with one token per entry of `chunks`.
    fn build_tree(chunks: &[&str]) -> (SyntaxNode<SyntaxKind, ()>, impl Resolver<TokenKey> + use<>) {
        let mut builder: GreenNodeBuilder<SyntaxKind> = GreenNodeBuilder::new();
        builder.start_node(SyntaxKind(62));
        for &chunk in chunks {
            let kind = if chunk == "{" {
                1
            } else if chunk == "}" {
                2
            } else {
                3
            };
            builder.token(SyntaxKind(kind), chunk);
        }
        builder.finish_node();
        let (node, cache) = builder.finish();
        (SyntaxNode::new_root(node), cache.unwrap().into_interner().unwrap())
    }

    #[test]
    fn test_text_equality() {
        // Checks that `SyntaxText` equality (against another text and against a `&str`) agrees
        // with plain string equality of the concatenated chunks.
        fn do_check(t1: &[&str], t2: &[&str]) {
            let (node1, resolver1) = build_tree(t1);
            let (node2, resolver2) = build_tree(t2);
            let t1 = node1.resolve_text(&resolver1);
            let t2 = node2.resolve_text(&resolver2);
            let expected = t1.to_string() == t2.to_string();
            assert_eq!(expected, t1 == t2, "`{t1}` (SyntaxText) `{t2}` (SyntaxText)");
            assert_eq!(
                expected,
                t1 == t2.to_string().as_str(),
                "`{t1}` (SyntaxText) `{t2}` (&str)"
            );
        }

        let cases: &[(&[&str], &[&str])] = &[
            (&[""], &[""]),
            (&["a"], &[""]),
            (&["a"], &["a"]),
            (&["abc"], &["def"]),
            (&["hello", "world"], &["hello", "world"]),
            (&["hellowo", "rld"], &["hell", "oworld"]),
            (&["hel", "lowo", "rld"], &["helloworld"]),
            (&["{", "abc", "}"], &["{", "123", "}"]),
            (&["{", "abc", "}", "{"], &["{", "123", "}"]),
            (&["{", "abc", "}"], &["{", "123", "}", "{"]),
            (&["{", "abc", "}ab"], &["{", "abc", "}", "ab"]),
        ];
        // Every pair is checked in both directions to exercise symmetry.
        for &(lhs, rhs) in cases {
            do_check(lhs, rhs);
            do_check(rhs, lhs);
        }
    }

    #[allow(dead_code)]
    mod impl_asserts {
        use super::*;

        struct NotClone;

        fn assert_copy<C: Copy>() {}

        fn test_impls_copy() {
            assert_copy::<SyntaxText<TokenInterner, SyntaxKind, NotClone>>();
        }
    }
}