// squawk_parser/lexed_str.rs

// based on https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/parser/src/lexed_str.rs
2
3use std::ops;
4
5use squawk_lexer::tokenize;
6
7use crate::SyntaxKind;
8
9pub struct LexedStr<'a> {
10    text: &'a str,
11    kind: Vec<SyntaxKind>,
12    start: Vec<u32>,
13    error: Vec<LexError>,
14}
15
/// A lexing error attached to a single token.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LexError {
    /// Human-readable description of the problem.
    msg: String,
    /// Index of the offending token in the token list.
    token: u32,
}
20
21impl<'a> LexedStr<'a> {
22    // TODO: rust-analyzer has an edition thing to specify things that are only
23    // available in certain version, we can do that later
24    pub fn new(text: &'a str) -> LexedStr<'a> {
25        let mut conv = Converter::new(text);
26
27        for token in tokenize(&text[conv.offset..]) {
28            let token_text = &text[conv.offset..][..token.len as usize];
29
30            conv.extend_token(&token.kind, token_text);
31        }
32
33        conv.finalize_with_eof()
34    }
35
36    // pub(crate) fn single_token(text: &'a str) -> Option<(SyntaxKind, Option<String>)> {
37    //     if text.is_empty() {
38    //         return None;
39    //     }
40
41    //     let token = tokenize(text).next()?;
42    //     if token.len as usize != text.len() {
43    //         return None;
44    //     }
45
46    //     let mut conv = Converter::new(text);
47    //     conv.extend_token(&token.kind, text);
48    //     match &*conv.res.kind {
49    //         [kind] => Some((*kind, conv.res.error.pop().map(|it| it.msg))),
50    //         _ => None,
51    //     }
52    // }
53
54    // pub(crate) fn as_str(&self) -> &str {
55    //     self.text
56    // }
57
58    pub(crate) fn len(&self) -> usize {
59        self.kind.len() - 1
60    }
61
62    // pub(crate) fn is_empty(&self) -> bool {
63    //     self.len() == 0
64    // }
65
66    pub(crate) fn kind(&self, i: usize) -> SyntaxKind {
67        assert!(i < self.len());
68        self.kind[i]
69    }
70
71    pub(crate) fn text(&self, i: usize) -> &str {
72        self.range_text(i..i + 1)
73    }
74
75    pub(crate) fn range_text(&self, r: ops::Range<usize>) -> &str {
76        assert!(r.start < r.end && r.end <= self.len());
77        let lo = self.start[r.start] as usize;
78        let hi = self.start[r.end] as usize;
79        &self.text[lo..hi]
80    }
81
82    // Naming is hard.
83    pub fn text_range(&self, i: usize) -> ops::Range<usize> {
84        assert!(i < self.len());
85        let lo = self.start[i] as usize;
86        let hi = self.start[i + 1] as usize;
87        lo..hi
88    }
89    pub fn text_start(&self, i: usize) -> usize {
90        assert!(i <= self.len());
91        self.start[i] as usize
92    }
93    // pub(crate) fn text_len(&self, i: usize) -> usize {
94    //     assert!(i < self.len());
95    //     let r = self.text_range(i);
96    //     r.end - r.start
97    // }
98
99    // pub(crate) fn error(&self, i: usize) -> Option<&str> {
100    //     assert!(i < self.len());
101    //     let err = self
102    //         .error
103    //         .binary_search_by_key(&(i as u32), |i| i.token)
104    //         .ok()?;
105    //     Some(self.error[err].msg.as_str())
106    // }
107
108    pub fn errors(&self) -> impl Iterator<Item = (usize, &str)> + '_ {
109        self.error
110            .iter()
111            .map(|it| (it.token as usize, it.msg.as_str()))
112    }
113
114    fn push(&mut self, kind: SyntaxKind, offset: usize) {
115        self.kind.push(kind);
116        self.start.push(offset as u32);
117    }
118}
119
120struct Converter<'a> {
121    res: LexedStr<'a>,
122    offset: usize,
123}
124
125impl<'a> Converter<'a> {
126    fn new(text: &'a str) -> Self {
127        Self {
128            res: LexedStr {
129                text,
130                kind: Vec::new(),
131                start: Vec::new(),
132                error: Vec::new(),
133            },
134            offset: 0,
135        }
136    }
137
138    fn finalize_with_eof(mut self) -> LexedStr<'a> {
139        self.res.push(SyntaxKind::EOF, self.offset);
140        self.res
141    }
142
143    fn push(&mut self, kind: SyntaxKind, len: usize, err: Option<&str>) {
144        self.res.push(kind, self.offset);
145        self.offset += len;
146
147        if let Some(err) = err {
148            let token = self.res.len() as u32;
149            let msg = err.to_owned();
150            self.res.error.push(LexError { msg, token });
151        }
152    }
153
154    fn extend_token(&mut self, kind: &squawk_lexer::TokenKind, token_text: &str) {
155        // A note on an intended tradeoff:
156        // We drop some useful information here (see patterns with double dots `..`)
157        // Storing that info in `SyntaxKind` is not possible due to its layout requirements of
158        // being `u16` that come from `rowan::SyntaxKind`.
159        let mut err = "";
160
161        let syntax_kind = {
162            match kind {
163                squawk_lexer::TokenKind::LineComment => SyntaxKind::COMMENT,
164                squawk_lexer::TokenKind::BlockComment { terminated } => {
165                    if !terminated {
166                        err = "Missing trailing `*/` symbols to terminate the block comment";
167                    }
168                    SyntaxKind::COMMENT
169                }
170
171                squawk_lexer::TokenKind::Whitespace => SyntaxKind::WHITESPACE,
172                squawk_lexer::TokenKind::Ident => {
173                    SyntaxKind::from_keyword(token_text).unwrap_or(SyntaxKind::IDENT)
174                }
175                squawk_lexer::TokenKind::Literal { kind, .. } => {
176                    self.extend_literal(token_text, kind);
177                    return;
178                }
179                squawk_lexer::TokenKind::Semi => SyntaxKind::SEMICOLON,
180                squawk_lexer::TokenKind::Comma => SyntaxKind::COMMA,
181                squawk_lexer::TokenKind::Dot => SyntaxKind::DOT,
182                squawk_lexer::TokenKind::OpenParen => SyntaxKind::L_PAREN,
183                squawk_lexer::TokenKind::CloseParen => SyntaxKind::R_PAREN,
184                squawk_lexer::TokenKind::OpenBracket => SyntaxKind::L_BRACK,
185                squawk_lexer::TokenKind::CloseBracket => SyntaxKind::R_BRACK,
186                squawk_lexer::TokenKind::OpenCurly => SyntaxKind::L_CURLY,
187                squawk_lexer::TokenKind::CloseCurly => SyntaxKind::R_CURLY,
188                squawk_lexer::TokenKind::At => SyntaxKind::AT,
189                squawk_lexer::TokenKind::Pound => SyntaxKind::POUND,
190                squawk_lexer::TokenKind::Tilde => SyntaxKind::TILDE,
191                squawk_lexer::TokenKind::Question => SyntaxKind::QUESTION,
192                squawk_lexer::TokenKind::Colon => SyntaxKind::COLON,
193                squawk_lexer::TokenKind::Eq => SyntaxKind::EQ,
194                squawk_lexer::TokenKind::Bang => SyntaxKind::BANG,
195                squawk_lexer::TokenKind::Lt => SyntaxKind::L_ANGLE,
196                squawk_lexer::TokenKind::Gt => SyntaxKind::R_ANGLE,
197                squawk_lexer::TokenKind::Minus => SyntaxKind::MINUS,
198                squawk_lexer::TokenKind::And => SyntaxKind::AMP,
199                squawk_lexer::TokenKind::Or => SyntaxKind::PIPE,
200                squawk_lexer::TokenKind::Plus => SyntaxKind::PLUS,
201                squawk_lexer::TokenKind::Star => SyntaxKind::STAR,
202                squawk_lexer::TokenKind::Slash => SyntaxKind::SLASH,
203                squawk_lexer::TokenKind::Caret => SyntaxKind::CARET,
204                squawk_lexer::TokenKind::Percent => SyntaxKind::PERCENT,
205                squawk_lexer::TokenKind::Unknown => SyntaxKind::ERROR,
206                squawk_lexer::TokenKind::UnknownPrefix => {
207                    err = "unknown literal prefix";
208                    SyntaxKind::IDENT
209                }
210                squawk_lexer::TokenKind::Eof => SyntaxKind::EOF,
211                squawk_lexer::TokenKind::Backtick => SyntaxKind::BACKTICK,
212                squawk_lexer::TokenKind::PositionalParam => SyntaxKind::POSITIONAL_PARAM,
213                squawk_lexer::TokenKind::QuotedIdent { terminated } => {
214                    if !terminated {
215                        err = "Missing trailing \" to terminate the quoted identifier"
216                    }
217                    SyntaxKind::IDENT
218                }
219            }
220        };
221
222        let err = if err.is_empty() { None } else { Some(err) };
223        self.push(syntax_kind, token_text.len(), err);
224    }
225
226    fn extend_literal(&mut self, token_text: &str, kind: &squawk_lexer::LiteralKind) {
227        let mut err: Option<String> = None;
228
229        let syntax_kind = match *kind {
230            squawk_lexer::LiteralKind::Int { empty_int, base: _ } => {
231                if empty_int {
232                    err = Some("Missing digits after the integer base prefix".into());
233                }
234                SyntaxKind::INT_NUMBER
235            }
236            squawk_lexer::LiteralKind::Float {
237                empty_exponent,
238                base: _,
239            } => {
240                if empty_exponent {
241                    err = Some("Missing digits after the exponent symbol".into());
242                }
243                SyntaxKind::FLOAT_NUMBER
244            }
245            squawk_lexer::LiteralKind::Str { terminated } => {
246                if !terminated {
247                    err =
248                        Some("Missing trailing `'` symbol to terminate the string literal".into());
249                }
250                SyntaxKind::STRING
251            }
252            squawk_lexer::LiteralKind::ByteStr { terminated } => {
253                if !terminated {
254                    err = Some(
255                        "Missing trailing `'` symbol to terminate the hex bit string literal"
256                            .into(),
257                    );
258                } else {
259                    let inside = &token_text[2..token_text.len() - 1];
260                    if let Some(c) = inside.chars().find(|c| !c.is_ascii_hexdigit()) {
261                        err = Some(format!("\"{c}\" is not a valid hexadecimal digit"));
262                    }
263                }
264                SyntaxKind::BYTE_STRING
265            }
266            squawk_lexer::LiteralKind::BitStr { terminated } => {
267                if !terminated {
268                    err = Some(
269                        "Missing trailing `'` symbol to terminate the bit string literal".into(),
270                    );
271                } else {
272                    let inside = &token_text[2..token_text.len() - 1];
273                    if let Some(c) = inside.chars().find(|&c| c != '0' && c != '1') {
274                        err = Some(format!("\"{c}\" is not a valid binary digit"));
275                    }
276                }
277                SyntaxKind::BIT_STRING
278            }
279            squawk_lexer::LiteralKind::DollarQuotedString { terminated } => {
280                if !terminated {
281                    // TODO: we could be fancier and say the ending string we're looking for
282                    err = Some("Unterminated dollar quoted string literal".into());
283                }
284                SyntaxKind::DOLLAR_QUOTED_STRING
285            }
286            squawk_lexer::LiteralKind::UnicodeEscStr { terminated } => {
287                if !terminated {
288                    err = Some(
289                        "Missing trailing `'` symbol to terminate the unicode escape string literal"
290                            .into(),
291                    );
292                }
293                // validated in squawk_syntax
294                SyntaxKind::UNICODE_ESC_STRING
295            }
296            squawk_lexer::LiteralKind::EscStr { terminated } => {
297                if !terminated {
298                    err = Some(
299                        "Missing trailing `'` symbol to terminate the escape string literal".into(),
300                    );
301                } else {
302                    err = validate_escape_string_unicode_escapes(token_text);
303                }
304                SyntaxKind::ESC_STRING
305            }
306        };
307
308        self.push(syntax_kind, token_text.len(), err.as_deref());
309    }
310}
311
/// Checks the `\uXXXX` and `\UXXXXXXXX` escapes inside an escape string literal.
///
/// `token_text` is the whole token: a two-byte prefix (such as `E'`), the string
/// contents, and a one-byte closing quote. Only the contents are inspected.
///
/// Returns `Some(message)` for the first `\u` not followed by four hex digits or
/// `\U` not followed by eight hex digits, and `None` when every such escape is
/// well-formed. Any other character after a backslash (including a second
/// backslash) is skipped without validation.
///
/// Input too short to contain both the prefix and the closing quote, or whose
/// prefix does not end on a character boundary, has no contents to validate and
/// yields `None` instead of panicking.
fn validate_escape_string_unicode_escapes(token_text: &str) -> Option<String> {
    let closing_quote = token_text.len().checked_sub(1)?;
    let mut chars = token_text.get(2..closing_quote)?.chars();

    while let Some(c) = chars.next() {
        if c != '\\' {
            continue;
        }

        // The character after the backslash is consumed here, so an escaped
        // backslash (`\\`) can never start a unicode escape.
        let (required, example) = match chars.next() {
            Some('u') => (4, r"\uXXXX"),
            Some('U') => (8, r"\UXXXXXXXX"),
            _ => continue,
        };

        let well_formed =
            (0..required).all(|_| chars.next().is_some_and(|c| c.is_ascii_hexdigit()));
        if !well_formed {
            return Some(format!(
                "Unicode escape requires {required} hex digits: {example}"
            ));
        }
    }

    None
}
337
#[cfg(test)]
mod tests {
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;

    use super::LexedStr;

    /// Lexes `text` and renders every reported error as an annotated snippet,
    /// one per error, each followed by a newline.
    fn lex(text: &str) -> String {
        let lexed = LexedStr::new(text);
        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        let mut rendered = String::new();

        for (token, msg) in lexed.errors() {
            // Underline the full byte range of the token the error belongs to.
            let span = lexed.text_range(token);
            let snippet = Snippet::source(text)
                .fold(true)
                .annotation(AnnotationKind::Primary.span(span));
            let group = Level::ERROR.primary_title(msg).element(snippet);

            rendered.push_str(&renderer.render(&[group]).to_string());
            rendered.push('\n');
        }

        rendered
    }

    // Numeric literals.

    #[test]
    fn empty_int_error() {
        assert_snapshot!(lex("select 0x;"), @"
        error: Missing digits after the integer base prefix
          ╭▸ 
        1 │ select 0x;
          ╰╴       ━━
        ");
    }

    #[test]
    fn empty_exponent_error() {
        assert_snapshot!(lex("select 1e;"), @"
        error: Missing digits after the exponent symbol
          ╭▸ 
        1 │ select 1e;
          ╰╴       ━━
        ");
    }

    // Plain, hex and bit strings.

    #[test]
    fn unterminated_string_error() {
        assert_snapshot!(lex("select 'hello;"), @"
        error: Missing trailing `'` symbol to terminate the string literal
          ╭▸ 
        1 │ select 'hello;
          ╰╴       ━━━━━━━
        ");
    }

    #[test]
    fn hex_invalid_digit() {
        assert_snapshot!(lex("select X'1FZ';"), @r#"
        error: "Z" is not a valid hexadecimal digit
          ╭▸ 
        1 │ select X'1FZ';
          ╰╴       ━━━━━━
        "#);
    }

    #[test]
    fn unterminated_hex_bit_string_error() {
        assert_snapshot!(lex("select X'1F;"), @"
        error: Missing trailing `'` symbol to terminate the hex bit string literal
          ╭▸ 
        1 │ select X'1F;
          ╰╴       ━━━━━
        ");
    }

    #[test]
    fn unterminated_bit_string_error() {
        assert_snapshot!(lex("select B'101;"), @"
        error: Missing trailing `'` symbol to terminate the bit string literal
          ╭▸ 
        1 │ select B'101;
          ╰╴       ━━━━━━
        ");
    }

    #[test]
    fn invalid_binary_digit_error() {
        assert_snapshot!(lex("select b'0 ';"), @r#"
        error: " " is not a valid binary digit
          ╭▸ 
        1 │ select b'0 ';
          ╰╴       ━━━━━
        "#);
    }

    // Dollar-quoted, unicode-escape and escape strings.

    #[test]
    fn unterminated_dollar_quoted_string_error() {
        assert_snapshot!(lex("select $tag$hello;"), @"
        error: Unterminated dollar quoted string literal
          ╭▸ 
        1 │ select $tag$hello;
          ╰╴       ━━━━━━━━━━━
        ");
    }

    #[test]
    fn unterminated_unicode_escape_string_error() {
        assert_snapshot!(lex("select U&'hello;"), @"
        error: Missing trailing `'` symbol to terminate the unicode escape string literal
          ╭▸ 
        1 │ select U&'hello;
          ╰╴       ━━━━━━━━━
        ");
    }

    #[test]
    fn unterminated_escape_string_error() {
        assert_snapshot!(lex("select E'hello;"), @"
        error: Missing trailing `'` symbol to terminate the escape string literal
          ╭▸ 
        1 │ select E'hello;
          ╰╴       ━━━━━━━━
        ");
    }

    // Unicode escapes inside escape strings.

    #[test]
    fn invalid_unicode_escape_4_digits_error() {
        assert_snapshot!(lex(r"select E'\u00';"), @r"
        error: Unicode escape requires 4 hex digits: \uXXXX
          ╭▸ 
        1 │ select E'\u00';
          ╰╴       ━━━━━━━
        ");
    }

    #[test]
    fn invalid_unicode_escape_8_digits_error() {
        assert_snapshot!(lex(r"select E'\UFFFF';"), @r"
        error: Unicode escape requires 8 hex digits: \UXXXXXXXX
          ╭▸ 
        1 │ select E'\UFFFF';
          ╰╴       ━━━━━━━━━
        ");
    }
}