ra_ap_syntax/ast/
token_ext.rs

1//! There are many AstNodes, but only a few tokens, so we hand-write them here.
2
3use std::{borrow::Cow, num::ParseIntError};
4
5use rustc_literal_escaper::{
6    EscapeError, MixedUnit, Mode, unescape_byte, unescape_char, unescape_mixed, unescape_unicode,
7};
8use stdx::always;
9
10use crate::{
11    TextRange, TextSize,
12    ast::{self, AstToken},
13};
14
15impl ast::Comment {
16    pub fn kind(&self) -> CommentKind {
17        CommentKind::from_text(self.text())
18    }
19
20    pub fn is_doc(&self) -> bool {
21        self.kind().doc.is_some()
22    }
23
24    pub fn is_inner(&self) -> bool {
25        self.kind().doc == Some(CommentPlacement::Inner)
26    }
27
28    pub fn is_outer(&self) -> bool {
29        self.kind().doc == Some(CommentPlacement::Outer)
30    }
31
32    pub fn prefix(&self) -> &'static str {
33        let &(prefix, _kind) = CommentKind::BY_PREFIX
34            .iter()
35            .find(|&(prefix, kind)| self.kind() == *kind && self.text().starts_with(prefix))
36            .unwrap();
37        prefix
38    }
39
40    /// Returns the textual content of a doc comment node as a single string with prefix and suffix
41    /// removed.
42    pub fn doc_comment(&self) -> Option<&str> {
43        let kind = self.kind();
44        match kind {
45            CommentKind { shape, doc: Some(_) } => {
46                let prefix = kind.prefix();
47                let text = &self.text()[prefix.len()..];
48                let text = if shape == CommentShape::Block {
49                    text.strip_suffix("*/").unwrap_or(text)
50                } else {
51                    text
52                };
53                Some(text)
54            }
55            _ => None,
56        }
57    }
58}
59
/// Classification of a comment: its syntactic shape plus, for doc comments, the placement of
/// the documentation.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct CommentKind {
    pub shape: CommentShape,
    /// `None` for ordinary comments, `Some(_)` for doc comments.
    pub doc: Option<CommentPlacement>,
}

/// Whether a comment is a `//` line comment or a `/*` block comment.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CommentShape {
    Line,
    Block,
}

impl CommentShape {
    /// Returns `true` for [`CommentShape::Line`].
    pub fn is_line(self) -> bool {
        matches!(self, CommentShape::Line)
    }

    /// Returns `true` for [`CommentShape::Block`].
    pub fn is_block(self) -> bool {
        matches!(self, CommentShape::Block)
    }
}

/// Placement of a doc comment: `Inner` for `//!` and `/*!`, `Outer` for `///` and `/**`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CommentPlacement {
    Inner,
    Outer,
}

impl CommentKind {
    /// Known comment prefixes with the kind each one denotes.
    ///
    /// Order matters: `from_text` takes the first entry the text starts with, so the
    /// four-character non-doc spellings must precede the three-character doc prefixes they
    /// would otherwise be mistaken for, and the bare `//` and `/*` come last.
    const BY_PREFIX: [(&'static str, Self); 9] = [
        ("/**/", Self { shape: CommentShape::Block, doc: None }),
        ("/***", Self { shape: CommentShape::Block, doc: None }),
        ("////", Self { shape: CommentShape::Line, doc: None }),
        ("///", Self { shape: CommentShape::Line, doc: Some(CommentPlacement::Outer) }),
        ("//!", Self { shape: CommentShape::Line, doc: Some(CommentPlacement::Inner) }),
        ("/**", Self { shape: CommentShape::Block, doc: Some(CommentPlacement::Outer) }),
        ("/*!", Self { shape: CommentShape::Block, doc: Some(CommentPlacement::Inner) }),
        ("//", Self { shape: CommentShape::Line, doc: None }),
        ("/*", Self { shape: CommentShape::Block, doc: None }),
    ];

    /// Returns the kind of the first table entry that `text` starts with.
    ///
    /// # Panics
    ///
    /// Panics if `text` starts with none of the known prefixes.
    pub(crate) fn from_text(text: &str) -> CommentKind {
        Self::BY_PREFIX
            .iter()
            .find_map(|&(prefix, kind)| text.starts_with(prefix).then_some(kind))
            .unwrap()
    }

    /// Returns the prefix of the last table entry with this kind.
    ///
    /// For non-doc kinds this is the bare `//` or `/*`.
    ///
    /// # Panics
    ///
    /// Panics if no table entry has this kind.
    pub fn prefix(&self) -> &'static str {
        Self::BY_PREFIX
            .iter()
            .rev()
            .find_map(|&(prefix, kind)| (kind == *self).then_some(prefix))
            .unwrap()
    }
}
115
116impl ast::Whitespace {
117    pub fn spans_multiple_lines(&self) -> bool {
118        let text = self.text();
119        text.find('\n').is_some_and(|idx| text[idx + 1..].contains('\n'))
120    }
121}
122
123#[derive(Debug)]
124pub struct QuoteOffsets {
125    pub quotes: (TextRange, TextRange),
126    pub contents: TextRange,
127}
128
129impl QuoteOffsets {
130    fn new(literal: &str) -> Option<QuoteOffsets> {
131        let left_quote = literal.find('"')?;
132        let right_quote = literal.rfind('"')?;
133        if left_quote == right_quote {
134            // `literal` only contains one quote
135            return None;
136        }
137
138        let start = TextSize::from(0);
139        let left_quote = TextSize::try_from(left_quote).unwrap() + TextSize::of('"');
140        let right_quote = TextSize::try_from(right_quote).unwrap();
141        let end = TextSize::of(literal);
142
143        let res = QuoteOffsets {
144            quotes: (TextRange::new(start, left_quote), TextRange::new(right_quote, end)),
145            contents: TextRange::new(left_quote, right_quote),
146        };
147        Some(res)
148    }
149}
150
151pub trait IsString: AstToken {
152    const RAW_PREFIX: &'static str;
153    const MODE: Mode;
154    fn is_raw(&self) -> bool {
155        self.text().starts_with(Self::RAW_PREFIX)
156    }
157    fn quote_offsets(&self) -> Option<QuoteOffsets> {
158        let text = self.text();
159        let offsets = QuoteOffsets::new(text)?;
160        let o = self.syntax().text_range().start();
161        let offsets = QuoteOffsets {
162            quotes: (offsets.quotes.0 + o, offsets.quotes.1 + o),
163            contents: offsets.contents + o,
164        };
165        Some(offsets)
166    }
167    fn text_range_between_quotes(&self) -> Option<TextRange> {
168        self.quote_offsets().map(|it| it.contents)
169    }
170    fn text_without_quotes(&self) -> &str {
171        let text = self.text();
172        let Some(offsets) = self.text_range_between_quotes() else { return text };
173        &text[offsets - self.syntax().text_range().start()]
174    }
175    fn open_quote_text_range(&self) -> Option<TextRange> {
176        self.quote_offsets().map(|it| it.quotes.0)
177    }
178    fn close_quote_text_range(&self) -> Option<TextRange> {
179        self.quote_offsets().map(|it| it.quotes.1)
180    }
181    fn escaped_char_ranges(&self, cb: &mut dyn FnMut(TextRange, Result<char, EscapeError>)) {
182        let Some(text_range_no_quotes) = self.text_range_between_quotes() else { return };
183
184        let start = self.syntax().text_range().start();
185        let text = &self.text()[text_range_no_quotes - start];
186        let offset = text_range_no_quotes.start() - start;
187
188        unescape_unicode(text, Self::MODE, &mut |range, unescaped_char| {
189            if let Some((s, e)) = range.start.try_into().ok().zip(range.end.try_into().ok()) {
190                cb(TextRange::new(s, e) + offset, unescaped_char);
191            }
192        });
193    }
194    fn map_range_up(&self, range: TextRange) -> Option<TextRange> {
195        let contents_range = self.text_range_between_quotes()?;
196        if always!(TextRange::up_to(contents_range.len()).contains_range(range)) {
197            Some(range + contents_range.start())
198        } else {
199            None
200        }
201    }
202}
203
204impl IsString for ast::String {
205    const RAW_PREFIX: &'static str = "r";
206    const MODE: Mode = Mode::Str;
207}
208
209impl ast::String {
210    pub fn value(&self) -> Result<Cow<'_, str>, EscapeError> {
211        let text = self.text();
212        let text_range = self.text_range_between_quotes().ok_or(EscapeError::LoneSlash)?;
213        let text = &text[text_range - self.syntax().text_range().start()];
214        if self.is_raw() {
215            return Ok(Cow::Borrowed(text));
216        }
217
218        let mut buf = String::new();
219        let mut prev_end = 0;
220        let mut has_error = None;
221        unescape_unicode(text, Self::MODE, &mut |char_range, unescaped_char| match (
222            unescaped_char,
223            buf.capacity() == 0,
224        ) {
225            (Ok(c), false) => buf.push(c),
226            (Ok(_), true) if char_range.len() == 1 && char_range.start == prev_end => {
227                prev_end = char_range.end
228            }
229            (Ok(c), true) => {
230                buf.reserve_exact(text.len());
231                buf.push_str(&text[..prev_end]);
232                buf.push(c);
233            }
234            (Err(e), _) => has_error = Some(e),
235        });
236
237        match (has_error, buf.capacity() == 0) {
238            (Some(e), _) => Err(e),
239            (None, true) => Ok(Cow::Borrowed(text)),
240            (None, false) => Ok(Cow::Owned(buf)),
241        }
242    }
243}
244
245impl IsString for ast::ByteString {
246    const RAW_PREFIX: &'static str = "br";
247    const MODE: Mode = Mode::ByteStr;
248}
249
250impl ast::ByteString {
251    pub fn value(&self) -> Result<Cow<'_, [u8]>, EscapeError> {
252        let text = self.text();
253        let text_range = self.text_range_between_quotes().ok_or(EscapeError::LoneSlash)?;
254        let text = &text[text_range - self.syntax().text_range().start()];
255        if self.is_raw() {
256            return Ok(Cow::Borrowed(text.as_bytes()));
257        }
258
259        let mut buf: Vec<u8> = Vec::new();
260        let mut prev_end = 0;
261        let mut has_error = None;
262        unescape_unicode(text, Self::MODE, &mut |char_range, unescaped_char| match (
263            unescaped_char,
264            buf.capacity() == 0,
265        ) {
266            (Ok(c), false) => buf.push(c as u8),
267            (Ok(_), true) if char_range.len() == 1 && char_range.start == prev_end => {
268                prev_end = char_range.end
269            }
270            (Ok(c), true) => {
271                buf.reserve_exact(text.len());
272                buf.extend_from_slice(&text.as_bytes()[..prev_end]);
273                buf.push(c as u8);
274            }
275            (Err(e), _) => has_error = Some(e),
276        });
277
278        match (has_error, buf.capacity() == 0) {
279            (Some(e), _) => Err(e),
280            (None, true) => Ok(Cow::Borrowed(text.as_bytes())),
281            (None, false) => Ok(Cow::Owned(buf)),
282        }
283    }
284}
285
286impl IsString for ast::CString {
287    const RAW_PREFIX: &'static str = "cr";
288    const MODE: Mode = Mode::CStr;
289
290    fn escaped_char_ranges(&self, cb: &mut dyn FnMut(TextRange, Result<char, EscapeError>)) {
291        let text_range_no_quotes = match self.text_range_between_quotes() {
292            Some(it) => it,
293            None => return,
294        };
295
296        let start = self.syntax().text_range().start();
297        let text = &self.text()[text_range_no_quotes - start];
298        let offset = text_range_no_quotes.start() - start;
299
300        unescape_mixed(text, Self::MODE, &mut |range, unescaped_char| {
301            let text_range =
302                TextRange::new(range.start.try_into().unwrap(), range.end.try_into().unwrap());
303            // XXX: This method should only be used for highlighting ranges. The unescaped
304            // char/byte is not used. For simplicity, we return an arbitrary placeholder char.
305            cb(text_range + offset, unescaped_char.map(|_| ' '));
306        });
307    }
308}
309
310impl ast::CString {
311    pub fn value(&self) -> Result<Cow<'_, [u8]>, EscapeError> {
312        let text = self.text();
313        let text_range = self.text_range_between_quotes().ok_or(EscapeError::LoneSlash)?;
314        let text = &text[text_range - self.syntax().text_range().start()];
315        if self.is_raw() {
316            return Ok(Cow::Borrowed(text.as_bytes()));
317        }
318
319        let mut buf = Vec::new();
320        let mut prev_end = 0;
321        let mut has_error = None;
322        let extend_unit = |buf: &mut Vec<u8>, unit: MixedUnit| match unit {
323            MixedUnit::Char(c) => buf.extend(c.encode_utf8(&mut [0; 4]).as_bytes()),
324            MixedUnit::HighByte(b) => buf.push(b),
325        };
326        unescape_mixed(text, Self::MODE, &mut |char_range, unescaped| match (
327            unescaped,
328            buf.capacity() == 0,
329        ) {
330            (Ok(u), false) => extend_unit(&mut buf, u),
331            (Ok(_), true) if char_range.len() == 1 && char_range.start == prev_end => {
332                prev_end = char_range.end
333            }
334            (Ok(u), true) => {
335                buf.reserve_exact(text.len());
336                buf.extend(&text.as_bytes()[..prev_end]);
337                extend_unit(&mut buf, u);
338            }
339            (Err(e), _) => has_error = Some(e),
340        });
341
342        match (has_error, buf.capacity() == 0) {
343            (Some(e), _) => Err(e),
344            (None, true) => Ok(Cow::Borrowed(text.as_bytes())),
345            (None, false) => Ok(Cow::Owned(buf)),
346        }
347    }
348}
349
350impl ast::IntNumber {
351    pub fn radix(&self) -> Radix {
352        match self.text().get(..2).unwrap_or_default() {
353            "0b" => Radix::Binary,
354            "0o" => Radix::Octal,
355            "0x" => Radix::Hexadecimal,
356            _ => Radix::Decimal,
357        }
358    }
359
360    pub fn split_into_parts(&self) -> (&str, &str, &str) {
361        let radix = self.radix();
362        let (prefix, mut text) = self.text().split_at(radix.prefix_len());
363
364        let is_suffix_start: fn(&(usize, char)) -> bool = match radix {
365            Radix::Hexadecimal => |(_, c)| matches!(c, 'g'..='z' | 'G'..='Z'),
366            _ => |(_, c)| c.is_ascii_alphabetic(),
367        };
368
369        let mut suffix = "";
370        if let Some((suffix_start, _)) = text.char_indices().find(is_suffix_start) {
371            let (text2, suffix2) = text.split_at(suffix_start);
372            text = text2;
373            suffix = suffix2;
374        };
375
376        (prefix, text, suffix)
377    }
378
379    pub fn value(&self) -> Result<u128, ParseIntError> {
380        let (_, text, _) = self.split_into_parts();
381        u128::from_str_radix(&text.replace('_', ""), self.radix() as u32)
382    }
383
384    pub fn suffix(&self) -> Option<&str> {
385        let (_, _, suffix) = self.split_into_parts();
386        if suffix.is_empty() { None } else { Some(suffix) }
387    }
388
389    pub fn value_string(&self) -> String {
390        let (_, text, _) = self.split_into_parts();
391        text.replace('_', "")
392    }
393}
394
395impl ast::FloatNumber {
396    pub fn split_into_parts(&self) -> (&str, &str) {
397        let text = self.text();
398        let mut float_text = self.text();
399        let mut suffix = "";
400        let mut indices = text.char_indices();
401        if let Some((mut suffix_start, c)) = indices.by_ref().find(|(_, c)| c.is_ascii_alphabetic())
402        {
403            if c == 'e' || c == 'E' {
404                if let Some(suffix_start_tuple) = indices.find(|(_, c)| c.is_ascii_alphabetic()) {
405                    suffix_start = suffix_start_tuple.0;
406
407                    float_text = &text[..suffix_start];
408                    suffix = &text[suffix_start..];
409                }
410            } else {
411                float_text = &text[..suffix_start];
412                suffix = &text[suffix_start..];
413            }
414        }
415
416        (float_text, suffix)
417    }
418
419    pub fn suffix(&self) -> Option<&str> {
420        let (_, suffix) = self.split_into_parts();
421        if suffix.is_empty() { None } else { Some(suffix) }
422    }
423
424    pub fn value_string(&self) -> String {
425        let (text, _) = self.split_into_parts();
426        text.replace('_', "")
427    }
428}
429
/// Numeric base of an integer literal; each discriminant is the base itself.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum Radix {
    Binary = 2,
    Octal = 8,
    Decimal = 10,
    Hexadecimal = 16,
}

impl Radix {
    /// Every radix, in ascending order of base.
    pub const ALL: &'static [Radix] =
        &[Radix::Binary, Radix::Octal, Radix::Decimal, Radix::Hexadecimal];

    /// Length in bytes of the literal prefix for this radix: none for decimal, two otherwise.
    const fn prefix_len(self) -> usize {
        match self {
            Self::Decimal => 0,
            Self::Binary | Self::Octal | Self::Hexadecimal => 2,
        }
    }
}
449
450impl ast::Char {
451    pub fn value(&self) -> Result<char, EscapeError> {
452        let mut text = self.text();
453        if text.starts_with('\'') {
454            text = &text[1..];
455        } else {
456            return Err(EscapeError::ZeroChars);
457        }
458        if text.ends_with('\'') {
459            text = &text[0..text.len() - 1];
460        }
461
462        unescape_char(text)
463    }
464}
465
466impl ast::Byte {
467    pub fn value(&self) -> Result<u8, EscapeError> {
468        let mut text = self.text();
469        if text.starts_with("b\'") {
470            text = &text[2..];
471        } else {
472            return Err(EscapeError::ZeroChars);
473        }
474        if text.ends_with('\'') {
475            text = &text[0..text.len() - 1];
476        }
477
478        unescape_byte(text)
479    }
480}
481
#[cfg(test)]
mod tests {
    use rustc_apfloat::ieee::Quad as f128;

    use crate::ast::{self, FloatNumber, IntNumber, make};

    /// Asserts that the float literal `lit` has the suffix `expected` (`None` for no suffix).
    fn check_float_suffix<'a>(lit: &str, expected: impl Into<Option<&'a str>>) {
        let token = FloatNumber { syntax: make::tokens::literal(lit) };
        assert_eq!(token.suffix(), expected.into());
    }

    /// Asserts that the integer literal `lit` has the suffix `expected` (`None` for no suffix).
    fn check_int_suffix<'a>(lit: &str, expected: impl Into<Option<&'a str>>) {
        let token = IntNumber { syntax: make::tokens::literal(lit) };
        assert_eq!(token.suffix(), expected.into());
    }

    /// Asserts that `value_string` of `lit`, read both as a float and as an integer token,
    /// parses to the same number as `expected`.
    // FIXME(#17451) Use `expected: f128` once `f128` is stabilised.
    fn check_float_value(lit: &str, expected: &str) {
        let expected = Some(expected.parse::<f128>().unwrap());

        let as_float = FloatNumber { syntax: make::tokens::literal(lit) }.value_string();
        assert_eq!(as_float.parse::<f128>().ok(), expected);

        let as_int = IntNumber { syntax: make::tokens::literal(lit) }.value_string();
        assert_eq!(as_int.parse::<f128>().ok(), expected);
    }

    /// Asserts that the integer literal `lit` parses to `expected` (`None` for a parse failure).
    fn check_int_value(lit: &str, expected: impl Into<Option<u128>>) {
        let token = IntNumber { syntax: make::tokens::literal(lit) };
        assert_eq!(token.value().ok(), expected.into());
    }

    #[test]
    fn test_float_number_suffix() {
        check_float_suffix("123.0", None);
        check_float_suffix("123f32", "f32");
        check_float_suffix("123.0e", None);
        check_float_suffix("123.0e4", None);
        check_float_suffix("123.0ef16", "f16");
        check_float_suffix("123.0E4f32", "f32");
        check_float_suffix("1_2_3.0_f128", "f128");
    }

    #[test]
    fn test_int_number_suffix() {
        check_int_suffix("123", None);
        check_int_suffix("123i32", "i32");
        check_int_suffix("1_0_1_l_o_l", "l_o_l");
        check_int_suffix("0b11", None);
        check_int_suffix("0o11", None);
        check_int_suffix("0xff", None);
        check_int_suffix("0b11u32", "u32");
        check_int_suffix("0o11u32", "u32");
        check_int_suffix("0xffu32", "u32");
    }

    /// Wraps `lit` in double quotes and asserts that the resulting string literal unescapes to
    /// `expected` (`None` for an unescape error).
    fn check_string_value<'a>(lit: &str, expected: impl Into<Option<&'a str>>) {
        let token = ast::String { syntax: make::tokens::literal(&format!("\"{lit}\"")) };
        assert_eq!(token.value().as_deref().ok(), expected.into());
    }

    #[test]
    fn test_string_escape() {
        check_string_value(r"foobar", "foobar");
        check_string_value(r"\foobar", None);
        check_string_value(r"\nfoobar", "\nfoobar");
        check_string_value(r"C:\\Windows\\System32\\", "C:\\Windows\\System32\\");
        check_string_value(r"\x61bcde", "abcde");
        check_string_value(
            r"a\
bcde", "abcde",
        );
    }

    /// Wraps `lit` in `b"…"` and asserts that the resulting byte-string literal unescapes to
    /// `expected` (`None` for an unescape error).
    fn check_byte_string_value<'a, const N: usize>(
        lit: &str,
        expected: impl Into<Option<&'a [u8; N]>>,
    ) {
        let token = ast::ByteString { syntax: make::tokens::literal(&format!("b\"{lit}\"")) };
        assert_eq!(token.value().as_deref().ok(), expected.into().map(|value| &value[..]));
    }

    #[test]
    fn test_byte_string_escape() {
        check_byte_string_value(r"foobar", b"foobar");
        check_byte_string_value(r"\foobar", None::<&[u8; 0]>);
        check_byte_string_value(r"\nfoobar", b"\nfoobar");
        check_byte_string_value(r"C:\\Windows\\System32\\", b"C:\\Windows\\System32\\");
        check_byte_string_value(r"\x61bcde", b"abcde");
        check_byte_string_value(
            r"a\
bcde", b"abcde",
        );
    }

    #[test]
    fn test_value_underscores() {
        check_float_value("1.3_4665449586950493453___6_f128", "1.346654495869504934536");
        check_float_value("1.234567891011121_f64", "1.234567891011121");
        check_float_value("1__0.__0__f32", "10.0");
        check_float_value("3._0_f16", "3.0");
        check_int_value("0b__1_0_", 2);
        check_int_value("1_1_1_1_1_1", 111111);
    }
}