icydb-core 0.76.5

IcyDB — A type-safe, embedded ORM and schema system for the Internet Computer
Documentation
use crate::db::reduced_sql::{Keyword, SqlParseError, Token, TokenKind};

/// Splits a reduced-SQL statement into a flat list of tokens.
///
/// This is the crate-facing entry point; all lexing is done by `Lexer`.
///
/// # Errors
/// Propagates the first lexical error encountered in `sql`.
pub(crate) fn tokenize_sql(sql: &str) -> Result<Vec<Token>, SqlParseError> {
    let tokens: Vec<Token> = Lexer::tokenize(sql)?;
    Ok(tokens)
}

/// Byte-oriented cursor over a reduced-SQL statement.
///
/// The lexer walks the UTF-8 bytes of the input directly. Every token it
/// recognises outside of string literals consists only of ASCII bytes, so
/// between tokens `pos` sits on a UTF-8 character boundary.
struct Lexer<'a> {
    /// Raw UTF-8 bytes of the statement being tokenized.
    bytes: &'a [u8],
    /// Index of the next unconsumed byte in `bytes`.
    pos: usize,
}

impl<'a> Lexer<'a> {
    /// Tokenizes `sql` from start to end.
    ///
    /// # Errors
    /// Returns the first error reported by [`Self::next_token`]; no partial
    /// token list is returned in that case.
    fn tokenize(sql: &'a str) -> Result<Vec<Token>, SqlParseError> {
        let mut lexer = Self {
            bytes: sql.as_bytes(),
            pos: 0,
        };
        let mut tokens = Vec::new();

        while let Some(token) = lexer.next_token()? {
            tokens.push(token);
        }

        Ok(tokens)
    }

    /// Skips leading ASCII whitespace and lexes one token.
    ///
    /// Returns `Ok(None)` once the input is exhausted.
    ///
    /// # Errors
    /// Fails on `!` not followed by `=`, `-` not followed by a digit, quoted
    /// identifiers (`"` or `` ` ``), unterminated string literals, and any
    /// other character outside the reduced-SQL alphabet.
    fn next_token(&mut self) -> Result<Option<Token>, SqlParseError> {
        self.skip_whitespace();
        let Some(byte) = self.peek_byte() else {
            return Ok(None);
        };

        let kind = match byte {
            b',' => {
                self.pos += 1;
                TokenKind::Comma
            }
            b'.' => {
                self.pos += 1;
                TokenKind::Dot
            }
            b'(' => {
                self.pos += 1;
                TokenKind::LParen
            }
            b')' => {
                self.pos += 1;
                TokenKind::RParen
            }
            b';' => {
                self.pos += 1;
                TokenKind::Semicolon
            }
            b'*' => {
                self.pos += 1;
                TokenKind::Star
            }
            b'=' => {
                self.pos += 1;
                TokenKind::Eq
            }
            b'!' => {
                self.pos += 1;
                if self.consume_if(b'=') {
                    TokenKind::Ne
                } else {
                    return Err(SqlParseError::invalid_syntax("unexpected '!'"));
                }
            }
            b'<' => {
                self.pos += 1;
                if self.consume_if(b'=') {
                    TokenKind::Lte
                } else if self.consume_if(b'>') {
                    // `<>` is accepted as an alias for `!=`.
                    TokenKind::Ne
                } else {
                    TokenKind::Lt
                }
            }
            b'>' => {
                self.pos += 1;
                if self.consume_if(b'=') {
                    TokenKind::Gte
                } else {
                    TokenKind::Gt
                }
            }
            b'\'' => TokenKind::StringLiteral(self.lex_string_literal()?),
            b'"' | b'`' => return Err(SqlParseError::unsupported_feature("quoted identifiers")),
            b'-' => {
                // A `-` is only valid as the sign of a numeric literal.
                if self
                    .peek_second_byte()
                    .is_some_and(|next| next.is_ascii_digit())
                {
                    self.pos += 1;
                    TokenKind::Number(self.lex_number(true))
                } else {
                    return Err(SqlParseError::invalid_syntax("unexpected '-'"));
                }
            }
            next if next.is_ascii_digit() => TokenKind::Number(self.lex_number(false)),
            next if is_identifier_start(next) => self.lex_identifier_or_keyword(),
            other => {
                // Report the whole (possibly multi-byte) character rather than
                // its first byte, which would render as mojibake for non-ASCII.
                let found = self.current_char().unwrap_or(other as char);
                return Err(SqlParseError::invalid_syntax(format!(
                    "unexpected character '{}'; reduced SQL supports bare identifiers, strings, numbers, and simple operators",
                    found
                )));
            }
        };

        Ok(Some(Token { kind }))
    }

    /// Advances `pos` past any run of ASCII whitespace.
    fn skip_whitespace(&mut self) {
        while self
            .peek_byte()
            .is_some_and(|byte| byte.is_ascii_whitespace())
        {
            self.pos += 1;
        }
    }

    /// Returns the byte at `pos` without consuming it, or `None` at end of input.
    fn peek_byte(&self) -> Option<u8> {
        self.bytes.get(self.pos).copied()
    }

    /// Returns the byte after `pos` without consuming anything.
    fn peek_second_byte(&self) -> Option<u8> {
        self.bytes.get(self.pos + 1).copied()
    }

    /// Decodes the (possibly multi-byte) character starting at `pos`.
    ///
    /// Returns `None` at end of input or when `pos` is not on a UTF-8 boundary.
    /// Only used to build error messages.
    fn current_char(&self) -> Option<char> {
        // A UTF-8 character is at most 4 bytes long.
        let end = self.bytes.len().min(self.pos + 4);
        let window = self.bytes.get(self.pos..end)?;
        let text = match std::str::from_utf8(window) {
            Ok(text) => text,
            // The 4-byte window may cut into the *following* character; keep
            // only the valid prefix.
            Err(err) => std::str::from_utf8(&window[..err.valid_up_to()]).ok()?,
        };
        text.chars().next()
    }

    /// Lexes a single-quoted string literal, starting at the opening quote.
    ///
    /// A doubled quote (`''`) inside the literal is an escaped `'`. The
    /// contents are collected as raw bytes and decoded once at the end, so
    /// multi-byte UTF-8 characters are preserved intact.
    ///
    /// # Errors
    /// Fails if the input ends before the closing quote.
    fn lex_string_literal(&mut self) -> Result<String, SqlParseError> {
        self.expect_byte(b'\'')?;
        let mut out = Vec::new();
        while let Some(byte) = self.peek_byte() {
            self.pos += 1;
            if byte == b'\'' {
                if self.consume_if(b'\'') {
                    out.push(b'\'');
                    continue;
                }

                // `bytes` came from a `&str` and only ASCII quote bytes were
                // dropped. Quotes never appear inside a multi-byte UTF-8
                // sequence, so the collected bytes are still valid UTF-8.
                return Ok(String::from_utf8(out).expect("string literal bytes must remain utf-8"));
            }
            out.push(byte);
        }

        Err(SqlParseError::invalid_syntax("unterminated string literal"))
    }

    /// Lexes a decimal number with an optional fractional part.
    ///
    /// When `negative` is true, the caller has already consumed a leading `-`,
    /// and that sign is included in the returned text. A `.` that is not
    /// followed by a digit is left unconsumed.
    fn lex_number(&mut self, negative: bool) -> String {
        let start = if negative { self.pos - 1 } else { self.pos };

        while self.peek_byte().is_some_and(|byte| byte.is_ascii_digit()) {
            self.pos += 1;
        }
        if self.peek_byte() == Some(b'.')
            && self
                .peek_second_byte()
                .is_some_and(|byte| byte.is_ascii_digit())
        {
            self.pos += 1;
            while self.peek_byte().is_some_and(|byte| byte.is_ascii_digit()) {
                self.pos += 1;
            }
        }

        std::str::from_utf8(&self.bytes[start..self.pos])
            .expect("numeric token bytes must remain utf-8")
            .to_owned()
    }

    /// Lexes a bare identifier and classifies it as a keyword when it matches one.
    fn lex_identifier_or_keyword(&mut self) -> TokenKind {
        let start = self.pos;
        self.pos += 1;
        while self.peek_byte().is_some_and(is_identifier_continue) {
            self.pos += 1;
        }
        let text = std::str::from_utf8(&self.bytes[start..self.pos])
            .expect("identifier token bytes must remain utf-8");
        // Classify on the borrowed slice; only identifiers need an owned copy.
        match keyword_from_ident(text) {
            Some(keyword) => TokenKind::Keyword(keyword),
            None => TokenKind::Identifier(text.to_owned()),
        }
    }

    /// Consumes the current byte if it equals `expected`; returns whether it did.
    fn consume_if(&mut self, expected: u8) -> bool {
        if self.peek_byte() != Some(expected) {
            return false;
        }

        self.pos += 1;
        true
    }

    /// Consumes `expected`.
    ///
    /// # Errors
    /// Returns an invalid-syntax error naming what was found instead.
    fn expect_byte(&mut self, expected: u8) -> Result<(), SqlParseError> {
        match self.peek_byte() {
            Some(found) if found == expected => {
                self.pos += 1;
                Ok(())
            }
            Some(found) => Err(SqlParseError::invalid_syntax(format!(
                "expected '{}', found '{}'",
                expected as char, found as char
            ))),
            None => Err(SqlParseError::invalid_syntax(format!(
                "expected '{}', found end of input",
                expected as char
            ))),
        }
    }
}

/// Returns `true` for bytes that may begin a bare identifier: an ASCII
/// letter or an underscore.
const fn is_identifier_start(byte: u8) -> bool {
    matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'_')
}

/// Returns `true` for bytes allowed after the first character of a bare
/// identifier: ASCII letters, ASCII digits, or an underscore.
const fn is_identifier_continue(byte: u8) -> bool {
    matches!(byte, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}

/// Maps a bare identifier to its reduced-SQL [`Keyword`], ignoring ASCII case.
///
/// The input is first bucketed by byte length, so it is only compared against
/// keywords of that exact length. Anything that is not a recognised keyword
/// yields `None`, and the caller treats it as a plain identifier.
const fn keyword_from_ident(value: &str) -> Option<Keyword> {
    let keyword = match value.len() {
        2 => {
            if value.eq_ignore_ascii_case("AS") {
                Keyword::As
            } else if value.eq_ignore_ascii_case("BY") {
                Keyword::By
            } else if value.eq_ignore_ascii_case("IN") {
                Keyword::In
            } else if value.eq_ignore_ascii_case("IS") {
                Keyword::Is
            } else if value.eq_ignore_ascii_case("OR") {
                Keyword::Or
            } else {
                return None;
            }
        }
        3 => {
            if value.eq_ignore_ascii_case("AND") {
                Keyword::And
            } else if value.eq_ignore_ascii_case("ASC") {
                Keyword::Asc
            } else if value.eq_ignore_ascii_case("AVG") {
                Keyword::Avg
            } else if value.eq_ignore_ascii_case("MAX") {
                Keyword::Max
            } else if value.eq_ignore_ascii_case("MIN") {
                Keyword::Min
            } else if value.eq_ignore_ascii_case("NOT") {
                Keyword::Not
            } else if value.eq_ignore_ascii_case("SUM") {
                Keyword::Sum
            } else {
                return None;
            }
        }
        4 => {
            if value.eq_ignore_ascii_case("DESC") {
                Keyword::Desc
            } else if value.eq_ignore_ascii_case("FROM") {
                Keyword::From
            } else if value.eq_ignore_ascii_case("JOIN") {
                Keyword::Join
            } else if value.eq_ignore_ascii_case("JSON") {
                Keyword::Json
            } else if value.eq_ignore_ascii_case("NULL") {
                Keyword::Null
            } else if value.eq_ignore_ascii_case("SHOW") {
                Keyword::Show
            } else if value.eq_ignore_ascii_case("TRUE") {
                Keyword::True
            } else if value.eq_ignore_ascii_case("WITH") {
                Keyword::With
            } else {
                return None;
            }
        }
        5 => {
            if value.eq_ignore_ascii_case("COUNT") {
                Keyword::Count
            } else if value.eq_ignore_ascii_case("FALSE") {
                Keyword::False
            } else if value.eq_ignore_ascii_case("GROUP") {
                Keyword::Group
            } else if value.eq_ignore_ascii_case("LIMIT") {
                Keyword::Limit
            } else if value.eq_ignore_ascii_case("ORDER") {
                Keyword::Order
            } else if value.eq_ignore_ascii_case("UNION") {
                Keyword::Union
            } else if value.eq_ignore_ascii_case("WHERE") {
                Keyword::Where
            } else {
                return None;
            }
        }
        6 => {
            if value.eq_ignore_ascii_case("DELETE") {
                Keyword::Delete
            } else if value.eq_ignore_ascii_case("EXCEPT") {
                Keyword::Except
            } else if value.eq_ignore_ascii_case("HAVING") {
                Keyword::Having
            } else if value.eq_ignore_ascii_case("INSERT") {
                Keyword::Insert
            } else if value.eq_ignore_ascii_case("OFFSET") {
                Keyword::Offset
            } else if value.eq_ignore_ascii_case("SELECT") {
                Keyword::Select
            } else if value.eq_ignore_ascii_case("UPDATE") {
                Keyword::Update
            } else {
                return None;
            }
        }
        7 => {
            if value.eq_ignore_ascii_case("BETWEEN") {
                Keyword::Between
            } else if value.eq_ignore_ascii_case("COLUMNS") {
                Keyword::Columns
            } else if value.eq_ignore_ascii_case("EXPLAIN") {
                Keyword::Explain
            } else if value.eq_ignore_ascii_case("INDEXES") {
                Keyword::Indexes
            } else {
                return None;
            }
        }
        8 => {
            if value.eq_ignore_ascii_case("DESCRIBE") {
                Keyword::Describe
            } else if value.eq_ignore_ascii_case("DISTINCT") {
                Keyword::Distinct
            } else if value.eq_ignore_ascii_case("ENTITIES") {
                Keyword::Entities
            } else {
                return None;
            }
        }
        9 => {
            if value.eq_ignore_ascii_case("EXECUTION") {
                Keyword::Execution
            } else if value.eq_ignore_ascii_case("INTERSECT") {
                Keyword::Intersect
            } else {
                return None;
            }
        }
        // No keyword is shorter than 2 or longer than 9 bytes.
        _ => return None,
    };

    Some(keyword)
}