1use std::{num::IntErrorKind, ops};
4
5use squawk_lexer::tokenize;
6
7use crate::SyntaxKind;
8
9pub struct LexedStr<'a> {
10 text: &'a str,
11 kind: Vec<SyntaxKind>,
12 start: Vec<u32>,
13 error: Vec<LexError>,
14}
15
/// A single lexing diagnostic.
struct LexError {
    /// Human-readable description of the problem.
    msg: String,
    /// Byte range the diagnostic applies to.
    range: ops::Range<u32>,
}
20
21impl<'a> LexedStr<'a> {
22 pub fn new(text: &'a str) -> LexedStr<'a> {
25 let mut conv = Converter::new(text);
26
27 for token in tokenize(&text[conv.offset..]) {
28 let token_text = &text[conv.offset..][..token.len as usize];
29
30 conv.extend_token(&token.kind, token_text);
31 }
32
33 conv.finalize_with_eof()
34 }
35
36 pub(crate) fn len(&self) -> usize {
59 self.kind.len() - 1
60 }
61
62 pub(crate) fn kind(&self, i: usize) -> SyntaxKind {
67 assert!(i < self.len());
68 self.kind[i]
69 }
70
71 pub(crate) fn text(&self, i: usize) -> &str {
72 self.range_text(i..i + 1)
73 }
74
75 pub(crate) fn range_text(&self, r: ops::Range<usize>) -> &str {
76 assert!(r.start < r.end && r.end <= self.len());
77 let lo = self.start[r.start] as usize;
78 let hi = self.start[r.end] as usize;
79 &self.text[lo..hi]
80 }
81
82 pub fn text_range(&self, i: usize) -> ops::Range<usize> {
84 assert!(i < self.len());
85 let lo = self.start[i] as usize;
86 let hi = self.start[i + 1] as usize;
87 lo..hi
88 }
89 pub fn text_start(&self, i: usize) -> usize {
90 assert!(i <= self.len());
91 self.start[i] as usize
92 }
93 pub fn errors(&self) -> impl Iterator<Item = (&ops::Range<u32>, &str)> + '_ {
109 self.error.iter().map(|it| (&it.range, it.msg.as_str()))
110 }
111
112 fn push(&mut self, kind: SyntaxKind, offset: usize) {
113 self.kind.push(kind);
114 self.start.push(offset as u32);
115 }
116}
117
118struct Converter<'a> {
119 res: LexedStr<'a>,
120 offset: usize,
121}
122
/// Returns `true` when `token_text` is a quoted identifier with nothing
/// between its quotes: `""`, optionally preceded by a `u&` / `U&` prefix.
fn is_empty_quoted_ident(token_text: &str) -> bool {
    // Drop the prefix only when both the `u`/`U` and the `&` are present;
    // otherwise compare the token text unchanged.
    let body = token_text
        .strip_prefix(['u', 'U'])
        .and_then(|rest| rest.strip_prefix('&'))
        .unwrap_or(token_text);
    body == "\"\""
}
134
135impl<'a> Converter<'a> {
136 fn new(text: &'a str) -> Self {
137 Self {
138 res: LexedStr {
139 text,
140 kind: Vec::new(),
141 start: Vec::new(),
142 error: Vec::new(),
143 },
144 offset: 0,
145 }
146 }
147
148 fn finalize_with_eof(mut self) -> LexedStr<'a> {
149 self.res.push(SyntaxKind::EOF, self.offset);
150 self.res
151 }
152
153 fn push(&mut self, kind: SyntaxKind, len: usize, err: Option<(&str, ops::Range<u32>)>) {
154 let token_start = self.offset as u32;
155 self.res.push(kind, self.offset);
156 self.offset += len;
157
158 if let Some((msg, err_range)) = err {
159 self.res.error.push(LexError {
160 msg: msg.to_owned(),
161 range: token_start + err_range.start..token_start + err_range.end,
162 });
163 }
164 }
165
166 fn extend_token(&mut self, kind: &squawk_lexer::TokenKind, token_text: &str) {
167 let mut err = "";
172 let mut err_range: Option<ops::Range<u32>> = None;
173
174 let syntax_kind = {
175 match kind {
176 squawk_lexer::TokenKind::LineComment => SyntaxKind::COMMENT,
177 squawk_lexer::TokenKind::BlockComment { terminated } => {
178 if !terminated {
179 err = "Missing trailing `*/` symbols to terminate the block comment";
180 }
181 SyntaxKind::COMMENT
182 }
183
184 squawk_lexer::TokenKind::Whitespace => SyntaxKind::WHITESPACE,
185 squawk_lexer::TokenKind::Ident => {
186 SyntaxKind::from_keyword(token_text).unwrap_or(SyntaxKind::IDENT)
187 }
188 squawk_lexer::TokenKind::Literal { kind, .. } => {
189 self.extend_literal(token_text, kind);
190 return;
191 }
192 squawk_lexer::TokenKind::Semi => SyntaxKind::SEMICOLON,
193 squawk_lexer::TokenKind::Comma => SyntaxKind::COMMA,
194 squawk_lexer::TokenKind::Dot => SyntaxKind::DOT,
195 squawk_lexer::TokenKind::OpenParen => SyntaxKind::L_PAREN,
196 squawk_lexer::TokenKind::CloseParen => SyntaxKind::R_PAREN,
197 squawk_lexer::TokenKind::OpenBracket => SyntaxKind::L_BRACK,
198 squawk_lexer::TokenKind::CloseBracket => SyntaxKind::R_BRACK,
199 squawk_lexer::TokenKind::OpenCurly => SyntaxKind::L_CURLY,
200 squawk_lexer::TokenKind::CloseCurly => SyntaxKind::R_CURLY,
201 squawk_lexer::TokenKind::At => SyntaxKind::AT,
202 squawk_lexer::TokenKind::Pound => SyntaxKind::POUND,
203 squawk_lexer::TokenKind::Tilde => SyntaxKind::TILDE,
204 squawk_lexer::TokenKind::Question => SyntaxKind::QUESTION,
205 squawk_lexer::TokenKind::Colon => SyntaxKind::COLON,
206 squawk_lexer::TokenKind::Eq => SyntaxKind::EQ,
207 squawk_lexer::TokenKind::Bang => SyntaxKind::BANG,
208 squawk_lexer::TokenKind::Lt => SyntaxKind::L_ANGLE,
209 squawk_lexer::TokenKind::Gt => SyntaxKind::R_ANGLE,
210 squawk_lexer::TokenKind::Minus => SyntaxKind::MINUS,
211 squawk_lexer::TokenKind::And => SyntaxKind::AMP,
212 squawk_lexer::TokenKind::Or => SyntaxKind::PIPE,
213 squawk_lexer::TokenKind::Plus => SyntaxKind::PLUS,
214 squawk_lexer::TokenKind::Star => SyntaxKind::STAR,
215 squawk_lexer::TokenKind::Slash => SyntaxKind::SLASH,
216 squawk_lexer::TokenKind::Caret => SyntaxKind::CARET,
217 squawk_lexer::TokenKind::Percent => SyntaxKind::PERCENT,
218 squawk_lexer::TokenKind::Unknown => SyntaxKind::ERROR,
219 squawk_lexer::TokenKind::UnknownPrefix => {
220 err = "unknown literal prefix";
221 SyntaxKind::IDENT
222 }
223 squawk_lexer::TokenKind::Eof => SyntaxKind::EOF,
224 squawk_lexer::TokenKind::Backtick => SyntaxKind::BACKTICK,
225 squawk_lexer::TokenKind::PositionalParam {
226 trailing_junk_start,
227 } => {
228 let digits = &token_text[1..*trailing_junk_start as usize];
229 if digits.is_empty() {
230 err = "missing parameter number";
231 err_range = Some(0..1);
232 } else if digits
233 .parse::<i32>()
234 .is_err_and(|err| matches!(err.kind(), IntErrorKind::PosOverflow))
235 {
236 err = "parameter number too large";
237 err_range = Some(0..*trailing_junk_start);
238 } else if (*trailing_junk_start as usize) < token_text.len() {
239 err = "trailing junk after positional parameter";
240 err_range = Some(*trailing_junk_start..token_text.len() as u32);
241 }
242 SyntaxKind::POSITIONAL_PARAM
243 }
244 squawk_lexer::TokenKind::QuotedIdent { terminated } => {
245 if !terminated {
246 err = "Missing trailing \" to terminate the quoted identifier"
247 } else if is_empty_quoted_ident(token_text) {
248 err = "empty delimited identifier";
249 }
250 SyntaxKind::IDENT
251 }
252 }
253 };
254
255 let err = if err.is_empty() { None } else { Some(err) };
256 let err = err.map(|msg| (msg, err_range.unwrap_or(0..token_text.len() as u32)));
257 self.push(syntax_kind, token_text.len(), err);
258 }
259
260 fn extend_literal(&mut self, token_text: &str, kind: &squawk_lexer::LiteralKind) {
261 let mut err: Option<String> = None;
262 let mut err_range: Option<ops::Range<u32>> = None;
263
264 let syntax_kind = match *kind {
265 squawk_lexer::LiteralKind::Int {
266 empty_int,
267 base,
268 trailing_junk_start,
269 } => {
270 if empty_int {
271 err = Some("Missing digits after the integer base prefix".into());
272 } else {
273 if matches!(base, squawk_lexer::Base::Binary | squawk_lexer::Base::Octal) {
274 let prefix_len = 2u32;
275 let digits = &token_text[prefix_len as usize..trailing_junk_start as usize];
276 let base = base as u32;
277 let token_start = self.offset as u32;
278 for (i, c) in digits.char_indices() {
279 if c != '_' && c.to_digit(base).is_none() {
280 let start = token_start + prefix_len + i as u32;
281 let end = start + c.len_utf8() as u32;
282 self.res.error.push(LexError {
283 msg: format!("invalid digit for a base {base} literal"),
284 range: start..end,
285 });
286 }
287 }
288 }
289 if (trailing_junk_start as usize) < token_text.len() {
290 err = Some("trailing junk after numeric literal".into());
291 err_range = Some(trailing_junk_start..token_text.len() as u32);
292 }
293 }
294 SyntaxKind::INT_NUMBER
295 }
296 squawk_lexer::LiteralKind::Numeric {
297 empty_exponent_start,
298 base: _,
299 trailing_junk_start,
300 } => {
301 if let Some(exponent_start) = empty_exponent_start {
302 err = Some("Missing digits after the exponent symbol".into());
303 err_range = Some(exponent_start..exponent_start + 1);
304 } else if (trailing_junk_start as usize) < token_text.len() {
305 err = Some("trailing junk after numeric literal".into());
306 err_range = Some(trailing_junk_start..token_text.len() as u32);
307 }
308 SyntaxKind::NUMERIC_NUMBER
309 }
310 squawk_lexer::LiteralKind::Str { terminated } => {
311 if !terminated {
312 err =
313 Some("Missing trailing `'` symbol to terminate the string literal".into());
314 }
315 SyntaxKind::STRING
316 }
317 squawk_lexer::LiteralKind::ByteStr { terminated } => {
318 if !terminated {
319 err = Some(
320 "Missing trailing `'` symbol to terminate the hex bit string literal"
321 .into(),
322 );
323 }
324 SyntaxKind::BYTE_STRING
326 }
327 squawk_lexer::LiteralKind::BitStr { terminated } => {
328 if !terminated {
329 err = Some(
330 "Missing trailing `'` symbol to terminate the bit string literal".into(),
331 );
332 }
333 SyntaxKind::BIT_STRING
335 }
336 squawk_lexer::LiteralKind::DollarQuotedString { terminated } => {
337 if !terminated {
338 err = Some("Unterminated dollar quoted string literal".into());
340 }
341 SyntaxKind::DOLLAR_QUOTED_STRING
342 }
343 squawk_lexer::LiteralKind::UnicodeEscStr { terminated } => {
344 if !terminated {
345 err = Some(
346 "Missing trailing `'` symbol to terminate the unicode escape string literal"
347 .into(),
348 );
349 }
350 SyntaxKind::UNICODE_ESC_STRING
352 }
353 squawk_lexer::LiteralKind::EscStr { terminated } => {
354 if !terminated {
355 err = Some(
356 "Missing trailing `'` symbol to terminate the escape string literal".into(),
357 );
358 }
359 SyntaxKind::ESC_STRING
361 }
362 };
363
364 let err = err
365 .as_deref()
366 .map(|msg| (msg, err_range.unwrap_or(0..token_text.len() as u32)));
367 self.push(syntax_kind, token_text.len(), err);
368 }
369}
370
// Snapshot tests for the diagnostics reported by `LexedStr::new`.
// Each test lexes a short statement and compares the rendered errors
// against an inline snapshot.
371#[cfg(test)]
372mod tests {
373 use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
374 use insta::assert_snapshot;
375
376 use super::LexedStr;
377
// Lexes `text` and renders every reported error as an annotated snippet of
// `text`, appending a newline after each rendered error.
378 fn lex(text: &str) -> String {
379 let lexed = LexedStr::new(text);
380 let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
381 let mut res = String::new();
382
383 for (range, msg) in lexed.errors() {
// The error range is stored as `u32`; the snippet span needs `usize`.
384 let span = range.start as usize..range.end as usize;
385 let group = Level::ERROR.primary_title(msg).element(
386 Snippet::source(text)
387 .fold(true)
388 .annotation(AnnotationKind::Primary.span(span)),
389 );
390 res.push_str(&renderer.render(&[group]).to_string());
391 res.push('\n');
392 }
393
394 res
395 }
396
// A base prefix with no digits after it is reported.
397 #[test]
398 fn empty_int_error() {
399 assert_snapshot!(lex("select 0x;"), @"
400 error: Missing digits after the integer base prefix
401 ╭▸
402 1 │ select 0x;
403 ╰╴ ━━
404 ");
405 }
406
// `0xg` is expected to produce the trailing-junk diagnostic.
407 #[test]
408 fn empty_int_with_trailing_ident_error() {
409 assert_snapshot!(lex("select 0xg;"), @"
410 error: trailing junk after numeric literal
411 ╭▸
412 1 │ select 0xg;
413 ╰╴ ━
414 ");
415 }
416
// Every invalid digit of an octal literal gets its own error.
417 #[test]
418 fn invalid_octal_digits_error() {
419 assert_snapshot!(lex("select 0o999;"), @"
420 error: invalid digit for a base 8 literal
421 ╭▸
422 1 │ select 0o999;
423 ╰╴ ━
424 error: invalid digit for a base 8 literal
425 ╭▸
426 1 │ select 0o999;
427 ╰╴ ━
428 error: invalid digit for a base 8 literal
429 ╭▸
430 1 │ select 0o999;
431 ╰╴ ━
432 ");
433 }
434
// Every invalid digit of a binary literal gets its own error.
435 #[test]
436 fn invalid_binary_digits_error() {
437 assert_snapshot!(lex("select 0b234;"), @"
438 error: invalid digit for a base 2 literal
439 ╭▸
440 1 │ select 0b234;
441 ╰╴ ━
442 error: invalid digit for a base 2 literal
443 ╭▸
444 1 │ select 0b234;
445 ╰╴ ━
446 error: invalid digit for a base 2 literal
447 ╭▸
448 1 │ select 0b234;
449 ╰╴ ━
450 ");
451 }
452
// Four digits follow the prefix but only three errors are expected:
// the leading `7` is a valid octal digit.
453 #[test]
454 fn invalid_octal_digits_after_valid_error() {
455 assert_snapshot!(lex("select 0o7889;"), @"
456 error: invalid digit for a base 8 literal
457 ╭▸
458 1 │ select 0o7889;
459 ╰╴ ━
460 error: invalid digit for a base 8 literal
461 ╭▸
462 1 │ select 0o7889;
463 ╰╴ ━
464 error: invalid digit for a base 8 literal
465 ╭▸
466 1 │ select 0o7889;
467 ╰╴ ━
468 ");
469 }
470
// An exponent marker with no digits after it is reported.
471 #[test]
472 fn empty_exponent_error() {
473 assert_snapshot!(lex("select 1e;"), @"
474 error: Missing digits after the exponent symbol
475 ╭▸
476 1 │ select 1e;
477 ╰╴ ━
478 ");
479 }
480
// The remaining tests cover each string-like literal that is missing its
// closing delimiter.
481 #[test]
482 fn unterminated_string_error() {
483 assert_snapshot!(lex("select 'hello;"), @"
484 error: Missing trailing `'` symbol to terminate the string literal
485 ╭▸
486 1 │ select 'hello;
487 ╰╴ ━━━━━━━
488 ");
489 }
490
491 #[test]
492 fn unterminated_hex_bit_string_error() {
493 assert_snapshot!(lex("select X'1F;"), @"
494 error: Missing trailing `'` symbol to terminate the hex bit string literal
495 ╭▸
496 1 │ select X'1F;
497 ╰╴ ━━━━━
498 ");
499 }
500
501 #[test]
502 fn unterminated_bit_string_error() {
503 assert_snapshot!(lex("select B'101;"), @"
504 error: Missing trailing `'` symbol to terminate the bit string literal
505 ╭▸
506 1 │ select B'101;
507 ╰╴ ━━━━━━
508 ");
509 }
510
511 #[test]
512 fn unterminated_dollar_quoted_string_error() {
513 assert_snapshot!(lex("select $tag$hello;"), @"
514 error: Unterminated dollar quoted string literal
515 ╭▸
516 1 │ select $tag$hello;
517 ╰╴ ━━━━━━━━━━━
518 ");
519 }
520
521 #[test]
522 fn unterminated_unicode_escape_string_error() {
523 assert_snapshot!(lex("select U&'hello;"), @"
524 error: Missing trailing `'` symbol to terminate the unicode escape string literal
525 ╭▸
526 1 │ select U&'hello;
527 ╰╴ ━━━━━━━━━
528 ");
529 }
530
531 #[test]
532 fn unterminated_escape_string_error() {
533 assert_snapshot!(lex("select E'hello;"), @"
534 error: Missing trailing `'` symbol to terminate the escape string literal
535 ╭▸
536 1 │ select E'hello;
537 ╰╴ ━━━━━━━━
538 ");
539 }
540}