use std::borrow::Cow;
use std::fmt;
use crate::cursor::InputCursor;
use crate::span::{HasSpan, Position, Span};
use crate::symbol::Symbol;
use crate::token::{CommentKind, NumberLiteral, StringLiteralKind, Token, TokenValue};
/// Outcome of scanning one token: its value, or the error that stopped the scan.
type ScanResult<'a> = std::result::Result<TokenValue<'a>, LexerError>;
/// Returns whether `c` is whitespace in Lua source.
///
/// Lua treats the standard ASCII whitespace set as spaces: space, horizontal tab, newline,
/// carriage return, vertical tab (`\x0b`) and form feed (`\x0c`).
fn is_whitespace(c: char) -> bool {
    matches!(c, ' ' | '\t' | '\n' | '\r' | '\x0b' | '\x0c')
}
/// Returns whether `c` is an ASCII decimal digit (`0`-`9`).
fn is_digit(c: char) -> bool {
    c.is_ascii_digit()
}
/// Returns whether `c` is an ASCII hexadecimal digit (`0`-`9`, `a`-`f`, `A`-`F`).
fn is_hex_digit(c: char) -> bool {
    c.is_ascii_hexdigit()
}
/// Returns whether `c` may begin an identifier: an ASCII letter or an underscore.
fn is_ident_start(c: char) -> bool {
    c == '_' || c.is_ascii_alphabetic()
}
/// Returns whether `c` may continue an identifier: an ASCII letter, digit or underscore.
fn is_ident(c: char) -> bool {
    c == '_' || c.is_ascii_alphanumeric()
}
/// Returns whether `c` is a line-break character (`\n` or `\r`).
fn is_newline(c: char) -> bool {
    matches!(c, '\n' | '\r')
}
/// Encodes `c` into `buf` using the UTF-8 bit layout and returns the number of bytes written.
///
/// Unlike `char::encode_utf8`, this accepts every value up to `0x10FFFF`, including the
/// surrogate range `0xD800..=0xDFFF`. That range is not rejected by the `\u{...}` escape
/// parser before it calls this function.
///
/// # Errors
/// Returns `LexerErrorKind::InvalidUtf8Codepoint` for values above `0x10FFFF`.
///
/// # Panics
/// Panics if `buf` is shorter than the encoded length (at most 4 bytes).
fn lua_utf8ish_encode(c: u32, buf: &mut [u8]) -> Result<usize, LexerErrorKind> {
    // (total encoded length, marker bits of the leading byte)
    let (len, lead): (usize, u8) = match c {
        0..=0x7f => {
            // ASCII: a single byte holding the value itself.
            buf[0] = c as u8;
            return Ok(1);
        }
        0x80..=0x7ff => (2, 0b1100_0000),
        0x800..=0xffff => (3, 0b1110_0000),
        0x1_0000..=0x10_ffff => (4, 0b1111_0000),
        _ => return Err(LexerErrorKind::InvalidUtf8Codepoint),
    };
    // Fill the continuation bytes from the last one backwards, six payload bits at a time.
    let mut rest = c;
    for byte in buf[1..len].iter_mut().rev() {
        *byte = 0b1000_0000 | (rest & 0x3f) as u8;
        rest >>= 6;
    }
    // The remaining high-order bits go into the leading byte.
    buf[0] = lead | rest as u8;
    Ok(len)
}
/// Converts the digit groups of a scanned numeral into a [`NumberLiteral`], following Lua's rules.
///
/// * `hex`: whether the numeral had a `0x`/`0X` prefix. Digits are then base 16 and the
///   exponent is a power of two.
/// * `int_part` / `frac_part`: the digits before / after the radix point, if present.
/// * `exp`: the already-signed exponent value, if present.
///
/// A numeral with only an integer part becomes `NumberLiteral::Integer`. Hexadecimal integers
/// wrap around modulo 2^64, so only their low 64 bits are kept. Decimal integers that do not
/// fit into an `i64` fall back to a float. Every other numeral becomes `NumberLiteral::Float`.
///
/// # Panics
/// Panics if a digit group contains characters that are not digits of the radix.
fn parse_number_literal<'a>(
    hex: bool,
    int_part: Option<&'a str>,
    frac_part: Option<&'a str>,
    exp: Option<f64>,
) -> NumberLiteral {
    if let (Some(int), None, None) = (int_part, frac_part, exp) {
        if hex {
            // Wrapping modulo 2^64 keeps the low-order bits, i.e. the *last* 16 hex digits.
            // The digits are ASCII, so slicing by byte index is safe.
            let low_digits = &int[int.len().saturating_sub(16)..];
            if let Ok(bits) = u64::from_str_radix(low_digits, 16) {
                return NumberLiteral::Integer(bits as i64);
            }
        } else if let Ok(value) = i64::from_str_radix(int, 10) {
            return NumberLiteral::Integer(value);
        }
    }
    let value = if hex {
        hex_float_value(int_part, frac_part, exp)
    } else {
        decimal_float_value(int_part, frac_part, exp)
    };
    NumberLiteral::Float(value)
}

/// Computes a decimal float by handing its digits to Rust's correctly rounding `f64` parser.
///
/// This avoids the rounding error of accumulating digits and multiplying by `10^exp`. It also
/// avoids turning `0e400` into `0 * inf = NaN`.
fn decimal_float_value(int_part: Option<&str>, frac_part: Option<&str>, exp: Option<f64>) -> f64 {
    let int = int_part.filter(|i| !i.is_empty()).unwrap_or("0");
    let mut text = String::from(int);
    if let Some(frac) = frac_part.filter(|f| !f.is_empty()) {
        text.push('.');
        text.push_str(frac);
    }
    if let Some(exponent) = exp {
        // The scanner derives exponents from an `i64`, so printing them as integers is exact
        // and always yields parseable text.
        text.push('e');
        text.push_str(&(exponent as i64).to_string());
    }
    text.parse()
        .expect("a numeral made of decimal digits is always a valid f64 literal")
}

/// Computes a hexadecimal float (Rust's `f64` parser has no hex-float support).
fn hex_float_value(int_part: Option<&str>, frac_part: Option<&str>, exp: Option<f64>) -> f64 {
    let digit = |c: char| f64::from(c.to_digit(16).expect("the scanner passes only hex digits"));
    let int = int_part.map_or(0.0, |i| i.chars().fold(0.0, |acc, c| acc * 16.0 + digit(c)));
    let frac = frac_part.map_or(0.0, |f| {
        f.chars().rfold(0.0, |acc, c| (acc + digit(c)) / 16.0)
    });
    scale_by_power_of_two(int + frac, exp.unwrap_or(0.0))
}

/// Returns `value * 2^exponent`, applied in power-of-two steps that stay inside `f64`'s range.
///
/// Multiplying by `2f64.powf(exponent)` directly overflows to infinity for large exponents,
/// even when the final result is finite or zero. That produces NaN for a zero mantissa and
/// spurious infinities for small mantissas.
fn scale_by_power_of_two(mut value: f64, exponent: f64) -> f64 {
    const STEP: i32 = 1000;
    // Any non-zero finite f64 times 2^4000 overflows and times 2^-4000 underflows, so this clamp
    // does not change the result; it only bounds the loops below.
    let mut remaining = exponent.max(-4000.0).min(4000.0) as i32;
    while remaining > STEP {
        value *= 2f64.powi(STEP);
        remaining -= STEP;
    }
    while remaining < -STEP {
        value *= 2f64.powi(-STEP);
        remaining += STEP;
    }
    value * 2f64.powi(remaining)
}
/// The kinds of errors the lexer can report.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LexerErrorKind {
    /// An exponent marker was not followed by exponent digits.
    ExpectedExponent,
    /// The exponent digits do not fit into an `i64`.
    ExponentTooLarge,
    /// A numeral contained no digits.
    ExpectedDigit,
    /// A raw line break appeared inside a quoted string.
    UnescapedNewline,
    /// A quoted string was not closed before the end of the input.
    UnclosedString,
    /// A backslash was followed by a character that starts no known escape.
    UnknownEscape,
    /// The input ended right after a backslash.
    UnexpectedEof,
    /// A `\u{` escape was not followed by hexadecimal digits.
    ExpectedHexDigit,
    /// A `\x` escape was not followed by two hexadecimal digits.
    Expected2HexDigits,
    /// A decimal escape denotes a value above 255.
    DoesntFitIntoByte,
    /// A decimal escape was expected to have between one and three digits.
    ExpectedUpTo3Digits,
    /// A code point could not be encoded.
    InvalidUtf8Codepoint,
    /// A `\u{...}` escape denotes a value above `0x10FFFF`.
    Utf8CodepointTooLarge,
    /// A long bracket (`[[ ... ]]`, `[==[ ... ]==]`) was not closed.
    UnclosedLongString,
    /// The character does not start any token.
    UnrecognizedCharacter,
    /// The given character was required at this point.
    ExpectedChar(char),
}
impl fmt::Display for LexerErrorKind {
    /// Writes a short, lowercase, human-readable description of the error kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            // The only variant with a payload needs formatting; the rest are fixed text.
            Self::ExpectedChar(c) => return write!(f, "expected a `{}`", c),
            Self::ExpectedExponent => "expected an exponent",
            Self::ExponentTooLarge => "the exponent is too large",
            Self::ExpectedDigit => "expected a digit",
            Self::UnescapedNewline => "an unescaped newline in the string literal",
            Self::UnclosedString => "an unclosed string literal",
            Self::UnknownEscape => "an unknown escape sequence",
            Self::UnexpectedEof => "an unexpected end of file",
            Self::ExpectedHexDigit => "expected a hex digit",
            Self::Expected2HexDigits => "expected 2 hex digits",
            Self::DoesntFitIntoByte => "expected the number be less than 256",
            Self::ExpectedUpTo3Digits => "expected up to 3 decimal digits",
            Self::InvalidUtf8Codepoint => "the UTF-8 codepoint is invalid",
            Self::Utf8CodepointTooLarge => "the UTF-8 codepoint is too large",
            Self::UnclosedLongString => "an unclosed long string",
            Self::UnrecognizedCharacter => "an unrecognized character",
        };
        f.write_str(message)
    }
}
/// An error produced while tokenizing, together with where it occurred.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LexerError {
    /// Start of the token being scanned; `None` until the lexer records it.
    start: Option<Position>,
    /// Position at which the error was detected.
    end: Position,
    /// What went wrong.
    kind: LexerErrorKind,
}
impl LexerError {
    /// Returns what kind of error this is.
    pub fn kind(&self) -> LexerErrorKind {
        let Self { kind, .. } = self;
        *kind
    }
}
impl fmt::Display for LexerError {
    /// Formats the error as the message of its kind; the location is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.kind, f)
    }
}

// There is no underlying cause to expose, so the default `source()` is kept.
impl std::error::Error for LexerError {}
impl HasSpan for LexerError {
    /// Returns the source range the error covers: from the start of the offending token to the
    /// position where the error was detected.
    ///
    /// Errors yielded by the lexer always have their start recorded. If one ever lacks it, the
    /// span collapses to the error position instead of panicking.
    fn span(&self) -> Span {
        Span {
            start: self.start.unwrap_or(self.end),
            end: self.end,
        }
    }
}
/// A streaming tokenizer for Lua-style source text.
///
/// Tokens (or the first error) are produced through its `Iterator` implementation.
#[derive(Debug, Clone)]
pub struct Lexer<'a> {
    /// Reading position within the source text.
    cursor: InputCursor<'a>,
    /// Set once an error has been yielded; nothing more is produced afterwards.
    finished: bool,
    /// Set once the `Eof` token has been yielded.
    eof: bool,
}
impl<'a> Lexer<'a> {
pub fn new(cursor: InputCursor<'a>) -> Self {
Lexer {
cursor,
finished: false,
eof: false,
}
}
pub fn cursor(&self) -> &InputCursor<'a> {
&self.cursor
}
fn scan_number(&mut self) -> ScanResult<'a> {
let hex = self
.cursor
.consume_n_if(2, |s| s == "0x" || s == "0X")
.is_some();
let predicate = if hex { is_hex_digit } else { is_digit };
let int_part = self.cursor.consume_while(predicate);
let frac_part = if self.cursor.consume_n_if(1, |c| c == ".").is_some() {
self.cursor.consume_while(predicate).or(Some(""))
} else {
None
};
let is_exp_char = |c| hex && c == 'p' || !hex && c == 'e';
let exp = if self.cursor.consume_if(is_exp_char).is_some() {
let sign = self.cursor.consume_if(|c| c == '+' || c == '-');
let exponent = self
.cursor
.consume_while(is_digit)
.ok_or_else(|| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExpectedExponent,
})?;
let mut exponent = exponent.parse::<i64>().map_err(|_| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExponentTooLarge,
})?;
if let Some('-') = sign {
exponent = -exponent;
}
Some(exponent as f64)
} else {
None
};
match (int_part, frac_part) {
(None, None) => Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExpectedDigit,
}),
_ => Ok(TokenValue::Number(parse_number_literal(
hex, int_part, frac_part, exp,
))),
}
}
fn scan_ident_or_kword(&mut self) -> ScanResult<'a> {
let word = self.cursor.consume_while(is_ident).unwrap();
if let Some(&v) = Symbol::SYMBOLS.get(word) {
Ok(TokenValue::Symbol(v))
} else {
Ok(TokenValue::Ident(word))
}
}
fn scan_long_string(&mut self) -> ScanResult<'a> {
let (level, s) = self.scan_long_bracket(true)?;
let s = match s {
Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()),
Cow::Owned(v) => Cow::Owned(v.as_bytes().to_owned()),
};
Ok(TokenValue::String {
value: s,
kind: StringLiteralKind::Bracketed { level },
})
}
fn scan_short_string(&mut self) -> ScanResult<'a> {
let quote = self.cursor.next().unwrap();
let mut s = Cow::from(&[] as &[u8]);
let mut terminated = false;
let mut buf = [0; 4];
let start = self.cursor.pos();
while let Some(c) = self.cursor.next() {
if c == quote {
terminated = true;
break;
}
if is_newline(c) {
return Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::UnescapedNewline,
});
}
if c == '\\' {
self.parse_string_escape(s.to_mut())?;
} else if let Cow::Borrowed(ref mut string) = s {
*string = self.cursor.substr(start..self.cursor.pos()).as_bytes();
} else {
s.to_mut()
.extend_from_slice(c.encode_utf8(&mut buf).as_bytes());
}
}
if !terminated {
return Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::UnclosedString,
});
}
let kind = match quote {
'\'' => StringLiteralKind::SingleQuoted,
'"' => StringLiteralKind::DoubleQuoted,
_ => panic!("the string doesn't start with a quote character"),
};
Ok(TokenValue::String { value: s, kind })
}
fn parse_string_escape(&mut self, s: &mut Vec<u8>) -> Result<(), LexerError> {
if let Some(c) = self.cursor.peek() {
match c {
'z' => {
self.cursor.next();
self.cursor.consume_while(is_whitespace);
}
'x' => {
self.cursor.next();
self.parse_hex_escape(s)?
}
'u' => {
self.cursor.next();
self.parse_utf8_escape(s)?
}
c if c.is_ascii_digit() => self.parse_dec_escape(s)?,
_ => {
self.cursor.next();
s.push(match c {
'a' => b'\x07',
'b' => b'\x08',
'f' => b'\x0c',
'n' => b'\n',
'r' => b'\r',
't' => b'\t',
'v' => b'\x0b',
'\\' => b'\\',
'"' => b'"',
'\'' => b'\'',
'\n' => b'\n',
_ => {
return Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::UnknownEscape,
})
}
});
}
}
Ok(())
} else {
Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::UnexpectedEof,
})
}
}
fn parse_hex_escape(&mut self, s: &mut Vec<u8>) -> Result<(), LexerError> {
if let Some(hex) = self.cursor.consume_n_if(2, |s| s.chars().all(is_hex_digit)) {
let byte = u8::from_str_radix(hex, 16).unwrap();
s.push(byte);
Ok(())
} else {
Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::Expected2HexDigits,
})
}
}
fn parse_dec_escape(&mut self, s: &mut Vec<u8>) -> Result<(), LexerError> {
if let Some(dec) = self.cursor.consume_first_n_while(3, is_digit) {
let byte = dec.parse::<u8>().map_err(|_| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::DoesntFitIntoByte,
})?;
s.push(byte);
Ok(())
} else {
Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExpectedUpTo3Digits,
})
}
}
fn parse_utf8_escape(&mut self, s: &mut Vec<u8>) -> Result<(), LexerError> {
self.cursor
.consume_if(|c| c == '{')
.ok_or_else(|| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExpectedChar('{'),
})?;
let hex = self
.cursor
.consume_while(is_hex_digit)
.ok_or_else(|| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExpectedHexDigit,
})?;
#[allow(clippy::unreadable_literal)]
let value = u32::from_str_radix(hex, 16)
.map_err(|_| ())
.and_then(|v| if v > 0x10ffff { Err(()) } else { Ok(v) })
.map_err(|_| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::Utf8CodepointTooLarge,
})?;
self.cursor
.consume_if(|c| c == '}')
.ok_or_else(|| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExpectedChar('}'),
})?;
let mut buf = [0; 4];
let len = lua_utf8ish_encode(value, &mut buf).map_err(|kind| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind,
})?;
s.extend_from_slice(&buf[..len]);
Ok(())
}
fn scan_long_bracket(
&mut self,
coerce_newlines: bool,
) -> Result<(usize, Cow<'a, str>), LexerError> {
self.cursor.next();
let level = self.cursor.consume_while(|c| c == '=').unwrap_or("").len();
self.cursor
.consume_if(|c| c == '[')
.ok_or_else(|| LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::ExpectedChar('['),
})?;
let mut s = Cow::from("");
let mut handled_newline = false;
if coerce_newlines {
if self.cursor.starts_with("\r\n") {
self.cursor.skip_n(2);
} else if self.cursor.starts_with("\r") || self.cursor.starts_with("\n") {
self.cursor.next();
}
}
let start = self.cursor.pos();
while let Some(c) = self.cursor.next() {
match c {
']' if self
.cursor
.remaining()
.chars()
.take(level)
.all(|c| c == '=')
&& self.cursor.remaining().chars().nth(level) == Some(']') =>
{
self.cursor.skip_n(level + 1);
return Ok((level, s));
}
'\r' | '\n' if handled_newline => handled_newline = false,
'\r' | '\n'
if coerce_newlines
&& (self.cursor.starts_with("\r\n") || self.cursor.starts_with("\n\r")) =>
{
handled_newline = true;
s = Cow::from(s.into_owned());
s.to_mut().push('\n');
}
_ => {
if let Cow::Borrowed(ref mut s) = s {
*s = &self.cursor.substr(start..self.cursor.pos());
} else {
s.to_mut().push(c);
}
}
}
}
Err(LexerError {
start: None,
end: self.cursor.prev_pos(),
kind: LexerErrorKind::UnclosedLongString,
})
}
fn scan_comment(&mut self) -> ScanResult<'a> {
self.cursor.consume_n(2);
let mut long_comment = false;
if let Some('[') = self.cursor.peek() {
let mut i = 1;
while let Some('=') = self.cursor.peek_nth(i) {
i += 1;
}
if let Some('[') = self.cursor.peek_nth(i) {
long_comment = true;
}
}
if long_comment {
let (level, value) = self.scan_long_bracket(false)?;
Ok(TokenValue::Comment {
value: match value {
Cow::Borrowed(v) => v,
_ => panic!("long comments are supposed to stay immutable"),
},
kind: CommentKind::Bracketed { level },
})
} else {
let value = self.cursor.consume_while(|c| !is_newline(c)).unwrap_or("");
Ok(TokenValue::Comment {
value,
kind: CommentKind::Unbracketed,
})
}
}
}
impl<'a> Iterator for Lexer<'a> {
    type Item = Result<Token<'a>, LexerError>;

    /// Produces the next token (whitespace and comments included).
    ///
    /// At the end of the input a single `Eof` token is yielded, followed by `None`. After an
    /// error has been yielded the lexer stops and returns `None` from then on.
    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        let start = self.cursor.pos();
        let first = match self.cursor.peek() {
            Some(c) => c,
            None if self.eof => return None,
            None => {
                self.eof = true;
                let at = self.cursor.pos_no_newline();
                return Some(Ok(Token {
                    value: TokenValue::Eof,
                    span: Span { start: at, end: at },
                }));
            }
        };

        // Dispatch on the first character; the order of the checks matters.
        let starts_number = first.is_ascii_digit()
            || (first == '.'
                && self
                    .cursor
                    .peek_nth(1)
                    .map_or(false, |next| next.is_ascii_digit()));
        let scanned = if is_whitespace(first) {
            let run = self.cursor.consume_while(is_whitespace).unwrap();
            Ok(TokenValue::Whitespace(run))
        } else if starts_number {
            self.scan_number()
        } else if first == '\'' || first == '"' {
            self.scan_short_string()
        } else if first == '-' && self.cursor.starts_with("--") {
            self.scan_comment()
        } else if first == '['
            && (self.cursor.starts_with("[[") || self.cursor.starts_with("[="))
        {
            self.scan_long_string()
        } else if is_ident_start(first) {
            self.scan_ident_or_kword()
        } else if let Some((symbol, len)) = Symbol::parse(self.cursor.remaining()) {
            self.cursor.skip_n(len);
            Ok(TokenValue::Symbol(symbol))
        } else {
            Err(LexerError {
                start: None,
                end: self.cursor.pos(),
                kind: LexerErrorKind::UnrecognizedCharacter,
            })
        };

        let item = match scanned {
            Ok(value) => Ok(Token {
                value,
                span: Span {
                    start,
                    end: self.cursor.prev_pos(),
                },
            }),
            Err(mut error) => {
                // Errors are built without a start; record where the failing token began.
                error.start = Some(start);
                self.finished = true;
                Err(error)
            }
        };
        Some(item)
    }
}