use std::{num::IntErrorKind, ops};
use squawk_lexer::tokenize;
use crate::SyntaxKind;
/// The result of lexing a source string: a flat list of tokens plus any
/// lexing errors that were collected along the way.
///
/// `kind` and `start` are parallel vectors with one entry per token, followed
/// by a final `EOF` entry whose start is the offset just past the last token.
pub struct LexedStr<'a> {
    /// The source text that was lexed.
    text: &'a str,
    /// Syntax kind of each token; the last entry is always `EOF`.
    kind: Vec<SyntaxKind>,
    /// Byte offset into `text` at which each token starts.
    start: Vec<u32>,
    /// Errors recorded while converting lexer tokens.
    error: Vec<LexError>,
}
/// A single lexing error attached to a span of the source text.
struct LexError {
    /// Human-readable description of the problem.
    msg: String,
    /// Byte range of the offending text, measured from the start of the
    /// source text (token start plus the offset within the token).
    range: ops::Range<u32>,
}
impl<'a> LexedStr<'a> {
    /// Lexes `text` into a flat token list, collecting lexing errors.
    ///
    /// Every token produced by `tokenize` is converted in order, and a final
    /// `EOF` entry is appended after the last token.
    pub fn new(text: &'a str) -> LexedStr<'a> {
        let mut converter = Converter::new(text);
        let input = &text[converter.offset..];
        for token in tokenize(input) {
            // Each token is sliced starting at the converter's running
            // offset, which advances by the token's length once recorded.
            let begin = converter.offset;
            let end = begin + token.len as usize;
            converter.extend_token(&token.kind, &text[begin..end]);
        }
        converter.finalize_with_eof()
    }

    /// Number of tokens, not counting the trailing `EOF` entry.
    pub(crate) fn len(&self) -> usize {
        self.kind.len() - 1
    }

    /// Syntax kind of token `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not a valid token index.
    pub(crate) fn kind(&self, i: usize) -> SyntaxKind {
        assert!(i < self.len());
        self.kind[i]
    }

    /// Source text of token `i`.
    pub(crate) fn text(&self, i: usize) -> &str {
        let single = i..i + 1;
        self.range_text(single)
    }

    /// Source text covered by the tokens in the half-open index range `r`.
    ///
    /// # Panics
    ///
    /// Panics if `r` is empty or extends past the last token.
    pub(crate) fn range_text(&self, r: ops::Range<usize>) -> &str {
        assert!(r.start < r.end && r.end <= self.len());
        let ops::Range {
            start: first,
            end: past_last,
        } = r;
        // The start of the token after the range is the end of the range;
        // the `EOF` entry makes this lookup valid for the last token too.
        let from = self.start[first] as usize;
        let to = self.start[past_last] as usize;
        &self.text[from..to]
    }

    /// Byte range of token `i` within the source text.
    ///
    /// # Panics
    ///
    /// Panics if `i` is not a valid token index.
    pub fn text_range(&self, i: usize) -> ops::Range<usize> {
        assert!(i < self.len());
        let begin = self.start[i] as usize;
        let end = self.start[i + 1] as usize;
        begin..end
    }

    /// Byte offset at which token `i` starts; `i == len()` yields the offset
    /// stored for the `EOF` entry.
    ///
    /// # Panics
    ///
    /// Panics if `i > len()`.
    pub fn text_start(&self, i: usize) -> usize {
        assert!(i <= self.len());
        let offset = self.start[i];
        offset as usize
    }

    /// Iterates over the recorded errors as `(byte range, message)` pairs.
    pub fn errors(&self) -> impl Iterator<Item = (&ops::Range<u32>, &str)> + '_ {
        self.error.iter().map(|lex_err| {
            let message: &str = &lex_err.msg;
            (&lex_err.range, message)
        })
    }

    /// Appends one token entry with the given kind and start offset.
    fn push(&mut self, kind: SyntaxKind, offset: usize) {
        let start = offset as u32;
        self.kind.push(kind);
        self.start.push(start);
    }
}
/// Incrementally builds a [`LexedStr`] from the tokens produced by the lexer.
struct Converter<'a> {
    /// The result being accumulated.
    res: LexedStr<'a>,
    /// Byte offset into the source text at which the next token starts.
    offset: usize,
}
/// Returns `true` when `token_text` is a quoted identifier with nothing
/// between the quotes: either `""` itself or `""` preceded by a `U&` / `u&`
/// prefix.
fn is_empty_quoted_ident(token_text: &str) -> bool {
    // The prefix only counts when both the `u`/`U` and the `&` are present;
    // otherwise the text is compared as-is.
    let body = match token_text.strip_prefix(['u', 'U']) {
        Some(after_u) => after_u.strip_prefix('&').unwrap_or(token_text),
        None => token_text,
    };
    body == "\"\""
}
impl<'a> Converter<'a> {
fn new(text: &'a str) -> Self {
Self {
res: LexedStr {
text,
kind: Vec::new(),
start: Vec::new(),
error: Vec::new(),
},
offset: 0,
}
}
fn finalize_with_eof(mut self) -> LexedStr<'a> {
self.res.push(SyntaxKind::EOF, self.offset);
self.res
}
fn push(&mut self, kind: SyntaxKind, len: usize, err: Option<(&str, ops::Range<u32>)>) {
let token_start = self.offset as u32;
self.res.push(kind, self.offset);
self.offset += len;
if let Some((msg, err_range)) = err {
self.res.error.push(LexError {
msg: msg.to_owned(),
range: token_start + err_range.start..token_start + err_range.end,
});
}
}
fn extend_token(&mut self, kind: &squawk_lexer::TokenKind, token_text: &str) {
let mut err = "";
let mut err_range: Option<ops::Range<u32>> = None;
let syntax_kind = {
match kind {
squawk_lexer::TokenKind::LineComment => SyntaxKind::COMMENT,
squawk_lexer::TokenKind::BlockComment { terminated } => {
if !terminated {
err = "Missing trailing `*/` symbols to terminate the block comment";
}
SyntaxKind::COMMENT
}
squawk_lexer::TokenKind::Whitespace => SyntaxKind::WHITESPACE,
squawk_lexer::TokenKind::Ident => {
SyntaxKind::from_keyword(token_text).unwrap_or(SyntaxKind::IDENT)
}
squawk_lexer::TokenKind::Literal { kind, .. } => {
self.extend_literal(token_text, kind);
return;
}
squawk_lexer::TokenKind::Semi => SyntaxKind::SEMICOLON,
squawk_lexer::TokenKind::Comma => SyntaxKind::COMMA,
squawk_lexer::TokenKind::Dot => SyntaxKind::DOT,
squawk_lexer::TokenKind::OpenParen => SyntaxKind::L_PAREN,
squawk_lexer::TokenKind::CloseParen => SyntaxKind::R_PAREN,
squawk_lexer::TokenKind::OpenBracket => SyntaxKind::L_BRACK,
squawk_lexer::TokenKind::CloseBracket => SyntaxKind::R_BRACK,
squawk_lexer::TokenKind::OpenCurly => SyntaxKind::L_CURLY,
squawk_lexer::TokenKind::CloseCurly => SyntaxKind::R_CURLY,
squawk_lexer::TokenKind::At => SyntaxKind::AT,
squawk_lexer::TokenKind::Pound => SyntaxKind::POUND,
squawk_lexer::TokenKind::Tilde => SyntaxKind::TILDE,
squawk_lexer::TokenKind::Question => SyntaxKind::QUESTION,
squawk_lexer::TokenKind::Colon => SyntaxKind::COLON,
squawk_lexer::TokenKind::Eq => SyntaxKind::EQ,
squawk_lexer::TokenKind::Bang => SyntaxKind::BANG,
squawk_lexer::TokenKind::Lt => SyntaxKind::L_ANGLE,
squawk_lexer::TokenKind::Gt => SyntaxKind::R_ANGLE,
squawk_lexer::TokenKind::Minus => SyntaxKind::MINUS,
squawk_lexer::TokenKind::And => SyntaxKind::AMP,
squawk_lexer::TokenKind::Or => SyntaxKind::PIPE,
squawk_lexer::TokenKind::Plus => SyntaxKind::PLUS,
squawk_lexer::TokenKind::Star => SyntaxKind::STAR,
squawk_lexer::TokenKind::Slash => SyntaxKind::SLASH,
squawk_lexer::TokenKind::Caret => SyntaxKind::CARET,
squawk_lexer::TokenKind::Percent => SyntaxKind::PERCENT,
squawk_lexer::TokenKind::Unknown => SyntaxKind::ERROR,
squawk_lexer::TokenKind::UnknownPrefix => {
err = "unknown literal prefix";
SyntaxKind::IDENT
}
squawk_lexer::TokenKind::Eof => SyntaxKind::EOF,
squawk_lexer::TokenKind::Backtick => SyntaxKind::BACKTICK,
squawk_lexer::TokenKind::PositionalParam {
trailing_junk_start,
} => {
let digits = &token_text[1..*trailing_junk_start as usize];
if digits.is_empty() {
err = "missing parameter number";
err_range = Some(0..1);
} else if digits
.parse::<i32>()
.is_err_and(|err| matches!(err.kind(), IntErrorKind::PosOverflow))
{
err = "parameter number too large";
err_range = Some(0..*trailing_junk_start);
} else if (*trailing_junk_start as usize) < token_text.len() {
err = "trailing junk after positional parameter";
err_range = Some(*trailing_junk_start..token_text.len() as u32);
}
SyntaxKind::POSITIONAL_PARAM
}
squawk_lexer::TokenKind::QuotedIdent { terminated } => {
if !terminated {
err = "Missing trailing \" to terminate the quoted identifier"
} else if is_empty_quoted_ident(token_text) {
err = "empty delimited identifier";
}
SyntaxKind::IDENT
}
}
};
let err = if err.is_empty() { None } else { Some(err) };
let err = err.map(|msg| (msg, err_range.unwrap_or(0..token_text.len() as u32)));
self.push(syntax_kind, token_text.len(), err);
}
fn extend_literal(&mut self, token_text: &str, kind: &squawk_lexer::LiteralKind) {
let mut err: Option<String> = None;
let mut err_range: Option<ops::Range<u32>> = None;
let syntax_kind = match *kind {
squawk_lexer::LiteralKind::Int {
empty_int,
base,
trailing_junk_start,
} => {
if empty_int {
err = Some("Missing digits after the integer base prefix".into());
} else {
if matches!(base, squawk_lexer::Base::Binary | squawk_lexer::Base::Octal) {
let prefix_len = 2u32;
let digits = &token_text[prefix_len as usize..trailing_junk_start as usize];
let base = base as u32;
let token_start = self.offset as u32;
for (i, c) in digits.char_indices() {
if c != '_' && c.to_digit(base).is_none() {
let start = token_start + prefix_len + i as u32;
let end = start + c.len_utf8() as u32;
self.res.error.push(LexError {
msg: format!("invalid digit for a base {base} literal"),
range: start..end,
});
}
}
}
if (trailing_junk_start as usize) < token_text.len() {
err = Some("trailing junk after numeric literal".into());
err_range = Some(trailing_junk_start..token_text.len() as u32);
}
}
SyntaxKind::INT_NUMBER
}
squawk_lexer::LiteralKind::Numeric {
empty_exponent_start,
base: _,
trailing_junk_start,
} => {
if let Some(exponent_start) = empty_exponent_start {
err = Some("Missing digits after the exponent symbol".into());
err_range = Some(exponent_start..exponent_start + 1);
} else if (trailing_junk_start as usize) < token_text.len() {
err = Some("trailing junk after numeric literal".into());
err_range = Some(trailing_junk_start..token_text.len() as u32);
}
SyntaxKind::NUMERIC_NUMBER
}
squawk_lexer::LiteralKind::Str { terminated } => {
if !terminated {
err =
Some("Missing trailing `'` symbol to terminate the string literal".into());
}
SyntaxKind::STRING
}
squawk_lexer::LiteralKind::ByteStr { terminated } => {
if !terminated {
err = Some(
"Missing trailing `'` symbol to terminate the hex bit string literal"
.into(),
);
}
SyntaxKind::BYTE_STRING
}
squawk_lexer::LiteralKind::BitStr { terminated } => {
if !terminated {
err = Some(
"Missing trailing `'` symbol to terminate the bit string literal".into(),
);
}
SyntaxKind::BIT_STRING
}
squawk_lexer::LiteralKind::DollarQuotedString { terminated } => {
if !terminated {
err = Some("Unterminated dollar quoted string literal".into());
}
SyntaxKind::DOLLAR_QUOTED_STRING
}
squawk_lexer::LiteralKind::UnicodeEscStr { terminated } => {
if !terminated {
err = Some(
"Missing trailing `'` symbol to terminate the unicode escape string literal"
.into(),
);
}
SyntaxKind::UNICODE_ESC_STRING
}
squawk_lexer::LiteralKind::EscStr { terminated } => {
if !terminated {
err = Some(
"Missing trailing `'` symbol to terminate the escape string literal".into(),
);
}
SyntaxKind::ESC_STRING
}
};
let err = err
.as_deref()
.map(|msg| (msg, err_range.unwrap_or(0..token_text.len() as u32)));
self.push(syntax_kind, token_text.len(), err);
}
}
// Snapshot tests for the errors that `LexedStr::new` records for malformed
// tokens. Each test lexes a short SQL snippet and compares the rendered
// diagnostics against an inline snapshot.
#[cfg(test)]
mod tests {
use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
use insta::assert_snapshot;
use super::LexedStr;
// Lexes `text` and renders every recorded error as an annotated snippet of
// `text`, concatenating the renderings with a newline after each one.
fn lex(text: &str) -> String {
let lexed = LexedStr::new(text);
let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
let mut res = String::new();
for (range, msg) in lexed.errors() {
// Error ranges are stored as `u32`; the snippet annotation takes `usize`.
let span = range.start as usize..range.end as usize;
let group = Level::ERROR.primary_title(msg).element(
Snippet::source(text)
.fold(true)
.annotation(AnnotationKind::Primary.span(span)),
);
res.push_str(&renderer.render(&[group]).to_string());
res.push('\n');
}
res
}
// A base prefix with no digits after it (`0x`) is reported.
#[test]
fn empty_int_error() {
assert_snapshot!(lex("select 0x;"), @"
error: Missing digits after the integer base prefix
╭▸
1 │ select 0x;
╰╴ ━━
");
}
// `0xg` is expected to produce a trailing-junk error rather than the
// missing-digits error.
#[test]
fn empty_int_with_trailing_ident_error() {
assert_snapshot!(lex("select 0xg;"), @"
error: trailing junk after numeric literal
╭▸
1 │ select 0xg;
╰╴ ━
");
}
// Every invalid digit of an octal literal gets its own error.
#[test]
fn invalid_octal_digits_error() {
assert_snapshot!(lex("select 0o999;"), @"
error: invalid digit for a base 8 literal
╭▸
1 │ select 0o999;
╰╴ ━
error: invalid digit for a base 8 literal
╭▸
1 │ select 0o999;
╰╴ ━
error: invalid digit for a base 8 literal
╭▸
1 │ select 0o999;
╰╴ ━
");
}
// Every invalid digit of a binary literal gets its own error.
#[test]
fn invalid_binary_digits_error() {
assert_snapshot!(lex("select 0b234;"), @"
error: invalid digit for a base 2 literal
╭▸
1 │ select 0b234;
╰╴ ━
error: invalid digit for a base 2 literal
╭▸
1 │ select 0b234;
╰╴ ━
error: invalid digit for a base 2 literal
╭▸
1 │ select 0b234;
╰╴ ━
");
}
// A valid leading digit does not stop later invalid digits from being
// reported; the snapshot expects three errors for `0o7889`.
#[test]
fn invalid_octal_digits_after_valid_error() {
assert_snapshot!(lex("select 0o7889;"), @"
error: invalid digit for a base 8 literal
╭▸
1 │ select 0o7889;
╰╴ ━
error: invalid digit for a base 8 literal
╭▸
1 │ select 0o7889;
╰╴ ━
error: invalid digit for a base 8 literal
╭▸
1 │ select 0o7889;
╰╴ ━
");
}
// An exponent marker with no digits after it is reported.
#[test]
fn empty_exponent_error() {
assert_snapshot!(lex("select 1e;"), @"
error: Missing digits after the exponent symbol
╭▸
1 │ select 1e;
╰╴ ━
");
}
// An unterminated plain string literal is reported.
#[test]
fn unterminated_string_error() {
assert_snapshot!(lex("select 'hello;"), @"
error: Missing trailing `'` symbol to terminate the string literal
╭▸
1 │ select 'hello;
╰╴ ━━━━━━━
");
}
// An unterminated `X'...` literal is reported.
#[test]
fn unterminated_hex_bit_string_error() {
assert_snapshot!(lex("select X'1F;"), @"
error: Missing trailing `'` symbol to terminate the hex bit string literal
╭▸
1 │ select X'1F;
╰╴ ━━━━━
");
}
// An unterminated `B'...` literal is reported.
#[test]
fn unterminated_bit_string_error() {
assert_snapshot!(lex("select B'101;"), @"
error: Missing trailing `'` symbol to terminate the bit string literal
╭▸
1 │ select B'101;
╰╴ ━━━━━━
");
}
// An unterminated dollar-quoted string is reported.
#[test]
fn unterminated_dollar_quoted_string_error() {
assert_snapshot!(lex("select $tag$hello;"), @"
error: Unterminated dollar quoted string literal
╭▸
1 │ select $tag$hello;
╰╴ ━━━━━━━━━━━
");
}
// An unterminated `U&'...` literal is reported.
#[test]
fn unterminated_unicode_escape_string_error() {
assert_snapshot!(lex("select U&'hello;"), @"
error: Missing trailing `'` symbol to terminate the unicode escape string literal
╭▸
1 │ select U&'hello;
╰╴ ━━━━━━━━━
");
}
// An unterminated `E'...` literal is reported.
#[test]
fn unterminated_escape_string_error() {
assert_snapshot!(lex("select E'hello;"), @"
error: Missing trailing `'` symbol to terminate the escape string literal
╭▸
1 │ select E'hello;
╰╴ ━━━━━━━━
");
}
}