use crate::error::{ParseError, ParseResult, Position, Span};
/// A lexical token. `Display` renders each token as it appears in source
/// text (see `impl Display for Token` below); `Eof` renders as `<EOF>`.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Literals and names.
    Integer(i64),
    Float(f64),
    Identifier(String),
    // Arithmetic operators: + - * ** / ^ % !
    Plus,
    Minus,
    Star,
    DoubleStar,
    Slash,
    Caret,
    Percent,
    Bang,
    // Delimiters: ( ) [ ] { } , ;
    LParen,
    RParen,
    LBracket,
    RBracket,
    LBrace,
    RBrace,
    Comma,
    Semicolon,
    // Relational operators: = != < <= > >=
    // (the Unicode forms ≠ ≤ ≥ lex to NotEquals / LessEq / GreaterEq).
    Equals,
    NotEquals,
    Less,
    LessEq,
    Greater,
    GreaterEq,
    // `_` is its own token, so `x_1` lexes as three tokens.
    Underscore,
    // End of input. Never emitted by `tokenize` in this file — presumably
    // appended by the consuming parser. NOTE(review): confirm at call site.
    Eof,
    // Mathematical symbols π ∞ √.
    Pi,
    Infinity,
    Sqrt,
    // Vector / vector-calculus keywords.
    Dot,
    Cross,
    Grad,
    Div,
    Curl,
    Laplacian,
    // Quantifier keywords.
    ForAll,
    Exists,
    // Set keywords.
    Union,
    Intersect,
    In,
    NotIn,
    // Logical keywords.
    And,
    Or,
    Not,
    Implies,
    Iff,
}
impl std::fmt::Display for Token {
    /// Render the token as it appears in source text. Payload-carrying
    /// variants format their payload; every other variant maps to a fixed
    /// string (`Eof` renders as the placeholder `<EOF>`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Token::Integer(n) => return write!(f, "{}", n),
            Token::Float(x) => return write!(f, "{}", x),
            Token::Identifier(name) => return f.write_str(name),
            Token::Plus => "+",
            Token::Minus => "-",
            Token::Star => "*",
            Token::DoubleStar => "**",
            Token::Slash => "/",
            Token::Caret => "^",
            Token::Percent => "%",
            Token::Bang => "!",
            Token::LParen => "(",
            Token::RParen => ")",
            Token::LBracket => "[",
            Token::RBracket => "]",
            Token::LBrace => "{",
            Token::RBrace => "}",
            Token::Comma => ",",
            Token::Semicolon => ";",
            Token::Equals => "=",
            Token::NotEquals => "!=",
            Token::Less => "<",
            Token::LessEq => "<=",
            Token::Greater => ">",
            Token::GreaterEq => ">=",
            Token::Underscore => "_",
            Token::Eof => "<EOF>",
            Token::Pi => "π",
            Token::Infinity => "∞",
            Token::Sqrt => "√",
            Token::Dot => "dot",
            Token::Cross => "cross",
            Token::Grad => "grad",
            Token::Div => "div",
            Token::Curl => "curl",
            Token::Laplacian => "laplacian",
            Token::ForAll => "forall",
            Token::Exists => "exists",
            Token::Union => "union",
            Token::Intersect => "intersect",
            Token::In => "in",
            Token::NotIn => "notin",
            Token::And => "and",
            Token::Or => "or",
            Token::Not => "not",
            Token::Implies => "implies",
            Token::Iff => "iff",
        };
        f.write_str(text)
    }
}
/// A value paired with the source span it was read from.
#[derive(Debug, Clone, PartialEq)]
pub struct Spanned<T> {
    pub value: T,
    pub span: Span,
}
impl<T> Spanned<T> {
    /// Pair `value` with the `span` it was read from.
    pub fn new(value: T, span: Span) -> Self {
        Self { value, span }
    }
}
/// A token together with its location in the input.
pub type SpannedToken = Spanned<Token>;
/// Internal cursor over the input string, tracking the byte offset plus
/// 1-based line/column for span reporting.
struct Tokenizer<'a> {
    input: &'a str,
    // Byte offset into `input` of the next unread character.
    offset: usize,
    // 1-based line number; incremented when a '\n' is consumed.
    line: usize,
    // 1-based column number; reset to 1 after a '\n'.
    column: usize,
}
impl<'a> Tokenizer<'a> {
    /// Create a tokenizer positioned at the start of `input`.
    fn new(input: &'a str) -> Self {
        Self {
            input,
            offset: 0,
            line: 1,
            column: 1,
        }
    }

    /// Current position (1-based line/column; `offset` is a byte index).
    fn position(&self) -> Position {
        Position::new(self.line, self.column, self.offset)
    }

    /// Look at the next character without consuming it.
    fn peek(&self) -> Option<char> {
        self.input[self.offset..].chars().next()
    }

    /// Look `n` characters past the current one without consuming anything.
    fn peek_ahead(&self, n: usize) -> Option<char> {
        self.input[self.offset..].chars().nth(n)
    }

    /// Consume one character, keeping line/column bookkeeping in sync.
    fn consume(&mut self) -> Option<char> {
        let ch = self.peek()?;
        self.offset += ch.len_utf8();
        if ch == '\n' {
            self.line += 1;
            self.column = 1;
        } else {
            self.column += 1;
        }
        Some(ch)
    }

    /// Skip over any run of whitespace (including newlines).
    fn skip_whitespace(&mut self) {
        while let Some(ch) = self.peek() {
            if !ch.is_whitespace() {
                break;
            }
            self.consume();
        }
    }

    /// Consume a run of ASCII digits, appending them to `buf`.
    fn consume_digits(&mut self, buf: &mut String) {
        while let Some(ch) = self.peek() {
            if !ch.is_ascii_digit() {
                break;
            }
            buf.push(ch);
            self.consume();
        }
    }

    /// True when the characters after the current `e`/`E` form a complete
    /// exponent tail: a digit, or a sign followed by a digit.
    fn exponent_follows(&self) -> bool {
        match self.peek_ahead(1) {
            Some(c) if c.is_ascii_digit() => true,
            Some('+') | Some('-') => {
                self.peek_ahead(2).map_or(false, |c| c.is_ascii_digit())
            }
            _ => false,
        }
    }

    /// Scan an integer or float literal.
    ///
    /// A fractional part is only taken when `.` is directly followed by a
    /// digit, and an exponent is only taken when `e`/`E` begins a complete
    /// exponent (`e5`, `e+5`, `e-5`).
    ///
    /// Bug fix: previously a bare trailing marker (as in `2e` or `2e+x`)
    /// was consumed greedily, producing `"2e"`/`"2e+"` which fails
    /// `f64::parse` and raised a spurious invalid-number error. Now the
    /// marker is left for the identifier scanner, so `2e` lexes as the
    /// integer `2` followed by the identifier `e`.
    fn scan_number(&mut self) -> ParseResult<(Token, Span)> {
        let start = self.position();
        let mut number_str = String::new();
        let mut is_float = false;

        self.consume_digits(&mut number_str);

        // Fractional part: require a digit after the '.'.
        if self.peek() == Some('.')
            && self.peek_ahead(1).map_or(false, |c| c.is_ascii_digit())
        {
            is_float = true;
            number_str.push('.');
            self.consume();
            self.consume_digits(&mut number_str);
        }

        // Exponent part: require lookahead to a complete exponent tail.
        if let Some(marker @ ('e' | 'E')) = self.peek() {
            if self.exponent_follows() {
                is_float = true;
                number_str.push(marker);
                self.consume();
                if let Some(sign @ ('+' | '-')) = self.peek() {
                    number_str.push(sign);
                    self.consume();
                }
                self.consume_digits(&mut number_str);
            }
        }

        let span = Span::new(start, self.position());
        if is_float {
            number_str
                .parse::<f64>()
                .map(|n| (Token::Float(n), span))
                .map_err(|_| {
                    ParseError::invalid_number(&number_str, "invalid float", Some(span))
                })
        } else {
            number_str
                .parse::<i64>()
                .map(|n| (Token::Integer(n), span))
                .map_err(|_| {
                    ParseError::invalid_number(&number_str, "invalid integer", Some(span))
                })
        }
    }

    /// Scan an identifier or keyword: a run of ASCII letters and digits.
    /// Underscores are a separate token, so `x_1` lexes as three tokens.
    fn scan_identifier(&mut self) -> (Token, Span) {
        let start = self.position();
        let mut ident = String::new();
        while let Some(ch) = self.peek() {
            if !ch.is_ascii_alphanumeric() {
                break;
            }
            ident.push(ch);
            self.consume();
        }
        let span = Span::new(start, self.position());
        let token = match ident.as_str() {
            "dot" => Token::Dot,
            "cross" => Token::Cross,
            "grad" => Token::Grad,
            "div" => Token::Div,
            "curl" => Token::Curl,
            "laplacian" => Token::Laplacian,
            "forall" => Token::ForAll,
            "exists" => Token::Exists,
            "union" => Token::Union,
            "intersect" => Token::Intersect,
            "in" => Token::In,
            "notin" => Token::NotIn,
            "and" => Token::And,
            "or" => Token::Or,
            "not" => Token::Not,
            "implies" => Token::Implies,
            "iff" => Token::Iff,
            _ => Token::Identifier(ident),
        };
        (token, span)
    }

    /// If the next character is `next`, consume it and return `two`;
    /// otherwise return `one`. Used for two-character operators.
    fn maybe(&mut self, next: char, two: Token, one: Token) -> Token {
        if self.peek() == Some(next) {
            self.consume();
            two
        } else {
            one
        }
    }

    /// Scan the next token, or return `Ok(None)` at end of input.
    ///
    /// # Errors
    /// Returns a `ParseError` for malformed numbers or characters that do
    /// not start any token.
    fn scan_token(&mut self) -> ParseResult<Option<SpannedToken>> {
        self.skip_whitespace();
        let Some(ch) = self.peek() else {
            return Ok(None);
        };
        let start = self.position();

        if ch.is_ascii_digit() {
            let (token, span) = self.scan_number()?;
            return Ok(Some(SpannedToken::new(token, span)));
        }
        if ch.is_ascii_alphabetic() {
            let (token, span) = self.scan_identifier();
            return Ok(Some(SpannedToken::new(token, span)));
        }

        self.consume();
        let token = match ch {
            // Operators with an optional second character.
            '!' => self.maybe('=', Token::NotEquals, Token::Bang),
            '<' => self.maybe('=', Token::LessEq, Token::Less),
            '>' => self.maybe('=', Token::GreaterEq, Token::Greater),
            '*' => self.maybe('*', Token::DoubleStar, Token::Star),
            // Unicode relations and symbols.
            '≤' => Token::LessEq,
            '≥' => Token::GreaterEq,
            '≠' => Token::NotEquals,
            'π' => Token::Pi,
            '∞' => Token::Infinity,
            '√' => Token::Sqrt,
            // Single-character tokens.
            '+' => Token::Plus,
            '-' => Token::Minus,
            '/' => Token::Slash,
            '^' => Token::Caret,
            '%' => Token::Percent,
            '(' => Token::LParen,
            ')' => Token::RParen,
            '[' => Token::LBracket,
            ']' => Token::RBracket,
            '{' => Token::LBrace,
            '}' => Token::RBrace,
            ',' => Token::Comma,
            ';' => Token::Semicolon,
            '=' => Token::Equals,
            '_' => Token::Underscore,
            other => {
                let span = Span::new(start, self.position());
                return Err(ParseError::unexpected_token(
                    vec!["valid token".to_string()],
                    other.to_string(),
                    Some(span),
                ));
            }
        };
        let span = Span::new(start, self.position());
        Ok(Some(SpannedToken::new(token, span)))
    }

    /// Scan tokens until end of input. No `Eof` token is appended.
    fn tokenize_all(&mut self) -> ParseResult<Vec<SpannedToken>> {
        let mut tokens = Vec::new();
        while let Some(token) = self.scan_token()? {
            tokens.push(token);
        }
        Ok(tokens)
    }
}
/// Tokenize `input` into a vector of spanned tokens.
///
/// # Errors
/// Returns a `ParseError` on malformed number literals or on characters
/// that do not start any token.
pub fn tokenize(input: &str) -> ParseResult<Vec<SpannedToken>> {
    Tokenizer::new(input).tokenize_all()
}
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;

    /// Tokenize `input`, panicking on error, and return the token values
    /// with their spans stripped.
    fn toks(input: &str) -> Vec<Token> {
        tokenize(input)
            .unwrap()
            .into_iter()
            .map(|spanned| spanned.value)
            .collect()
    }

    #[test]
    fn test_tokenize_integer() {
        assert_eq!(toks("42"), vec![Token::Integer(42)]);
    }

    #[test]
    fn test_tokenize_float() {
        let tokens = toks("3.14");
        assert_eq!(tokens.len(), 1);
        assert!(matches!(tokens[0], Token::Float(f) if (f - 3.14).abs() < 0.001));
    }

    #[test]
    fn test_tokenize_scientific_notation() {
        let tokens = toks("1.5e-3");
        assert_eq!(tokens.len(), 1);
        assert!(matches!(tokens[0], Token::Float(f) if (f - 0.0015).abs() < 0.0001));
    }

    #[test]
    fn test_tokenize_identifier() {
        assert_eq!(toks("x"), vec![Token::Identifier("x".to_string())]);
    }

    #[test]
    fn test_tokenize_multi_char_identifier() {
        assert_eq!(toks("theta"), vec![Token::Identifier("theta".to_string())]);
    }

    #[test]
    fn test_tokenize_operators() {
        assert_eq!(
            toks("+ - * / ^ %"),
            vec![
                Token::Plus,
                Token::Minus,
                Token::Star,
                Token::Slash,
                Token::Caret,
                Token::Percent,
            ]
        );
    }

    #[test]
    fn test_tokenize_delimiters() {
        assert_eq!(
            toks("( ) [ ] { }"),
            vec![
                Token::LParen,
                Token::RParen,
                Token::LBracket,
                Token::RBracket,
                Token::LBrace,
                Token::RBrace,
            ]
        );
    }

    #[test]
    fn test_tokenize_relations() {
        assert_eq!(
            toks("= != < <= > >="),
            vec![
                Token::Equals,
                Token::NotEquals,
                Token::Less,
                Token::LessEq,
                Token::Greater,
                Token::GreaterEq,
            ]
        );
    }

    #[test]
    fn test_tokenize_unicode_relations() {
        assert_eq!(
            toks("≤ ≥ ≠"),
            vec![Token::LessEq, Token::GreaterEq, Token::NotEquals]
        );
    }

    #[test]
    fn test_tokenize_expression() {
        let tokens = toks("2 + x * 3.14");
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0], Token::Integer(2));
        assert_eq!(tokens[1], Token::Plus);
        assert_eq!(tokens[2], Token::Identifier("x".to_string()));
        assert_eq!(tokens[3], Token::Star);
        assert!(matches!(tokens[4], Token::Float(_)));
    }

    #[test]
    fn test_tokenize_function_call() {
        assert_eq!(
            toks("sin(x)"),
            vec![
                Token::Identifier("sin".to_string()),
                Token::LParen,
                Token::Identifier("x".to_string()),
                Token::RParen,
            ]
        );
    }

    #[test]
    fn test_tokenize_factorial() {
        assert_eq!(toks("5!"), vec![Token::Integer(5), Token::Bang]);
    }

    #[test]
    fn test_tokenize_underscore() {
        assert_eq!(
            toks("x_1"),
            vec![
                Token::Identifier("x".to_string()),
                Token::Underscore,
                Token::Integer(1),
            ]
        );
    }

    #[test]
    fn test_tokenize_empty() {
        assert!(toks("").is_empty());
    }

    #[test]
    fn test_tokenize_whitespace_only() {
        assert!(toks("   ").is_empty());
    }

    #[test]
    fn test_invalid_character() {
        assert!(tokenize("@").is_err());
    }

    #[test]
    fn test_span_tracking() {
        // Spans matter here, so use `tokenize` directly.
        let tokens = tokenize("x + y").unwrap();
        assert_eq!(tokens.len(), 3);
        assert_eq!(tokens[0].span.start.column, 1);
        assert_eq!(tokens[0].span.end.column, 2);
        assert_eq!(tokens[2].span.start.column, 5);
        assert_eq!(tokens[2].span.end.column, 6);
    }

    #[test]
    fn test_tokenize_unicode_pi() {
        assert_eq!(
            toks("2*π"),
            vec![Token::Integer(2), Token::Star, Token::Pi]
        );
    }

    #[test]
    fn test_tokenize_unicode_infinity() {
        assert_eq!(toks("∞"), vec![Token::Infinity]);
    }

    #[test]
    fn test_tokenize_unicode_sqrt() {
        assert_eq!(toks("√4"), vec![Token::Sqrt, Token::Integer(4)]);
    }

    #[test]
    fn test_tokenize_unicode_sqrt_with_parens() {
        assert_eq!(
            toks("√(x+1)"),
            vec![
                Token::Sqrt,
                Token::LParen,
                Token::Identifier("x".to_string()),
                Token::Plus,
                Token::Integer(1),
                Token::RParen,
            ]
        );
    }

    #[test]
    fn test_tokenize_double_star() {
        assert_eq!(
            toks("2**3"),
            vec![Token::Integer(2), Token::DoubleStar, Token::Integer(3)]
        );
    }

    #[test]
    fn test_tokenize_star_vs_double_star() {
        assert_eq!(
            toks("2*3**4"),
            vec![
                Token::Integer(2),
                Token::Star,
                Token::Integer(3),
                Token::DoubleStar,
                Token::Integer(4),
            ]
        );
    }

    #[test]
    fn test_tokenize_vector_keywords() {
        assert_eq!(toks("dot cross"), vec![Token::Dot, Token::Cross]);
    }

    #[test]
    fn test_tokenize_vector_calculus_keywords() {
        assert_eq!(
            toks("grad div curl laplacian"),
            vec![Token::Grad, Token::Div, Token::Curl, Token::Laplacian]
        );
    }

    #[test]
    fn test_tokenize_quantifier_keywords() {
        assert_eq!(toks("forall exists"), vec![Token::ForAll, Token::Exists]);
    }

    #[test]
    fn test_tokenize_set_keywords() {
        assert_eq!(
            toks("union intersect in notin"),
            vec![Token::Union, Token::Intersect, Token::In, Token::NotIn]
        );
    }

    #[test]
    fn test_tokenize_logical_keywords() {
        assert_eq!(
            toks("and or not implies iff"),
            vec![Token::And, Token::Or, Token::Not, Token::Implies, Token::Iff]
        );
    }
}