use std::iter::Peekable;
use crate::lexer::{
Lexer,
token::{Span, Token, TokenKind},
};
pub mod ast;
mod error;
use ast::*;
use error::ParseError;
/// Parser that pulls tokens from a [`Lexer`] one at a time with single-token lookahead.
pub struct Parser<'a> {
/// Token source, wrapped so the next token can be inspected without consuming it.
lexer: Peekable<Lexer<'a>>,
/// Span of the most recently consumed token; `None` until a token has been consumed.
last_span: Option<Span>,
/// Kind of the most recently consumed token; `None` until a token has been consumed.
last_token: Option<TokenKind<'a>>,
/// Value of `Lexer::source_len()` captured at construction. Used to synthesize a span
/// near the end of the input when an error occurs before any token was consumed.
source_len: usize,
}
impl<'a> Parser<'a> {
/// Creates a parser that reads its tokens from `lexer`.
///
/// The lexer's `source_len()` is recorded before the lexer is moved into its
/// peekable wrapper, so the value remains available for end-of-input diagnostics.
pub fn new(lexer: Lexer<'a>) -> Self {
    Self {
        // Evaluated first: this only borrows `lexer`, the next field moves it.
        source_len: lexer.source_len(),
        lexer: lexer.peekable(),
        last_token: None,
        last_span: None,
    }
}
/// Parses exactly one statement and requires the input to end right after it.
///
/// # Errors
///
/// Returns any error raised while parsing the statement. If a token is still
/// available afterwards, returns an unexpected-token error whose expectation is
/// `"end of input"` and which is built from that leftover token.
pub fn parse(&mut self) -> Result<Stmt, ParseError> {
    let stmt = self.parse_stmt()?;
    match self.peek()? {
        None => Ok(stmt),
        Some(trailing) => Err(ParseError::unexpected_token(
            "end of input",
            Some(trailing),
        )),
    }
}
/// Looks at the next token without consuming it.
///
/// Returns `Ok(None)` when the lexer is exhausted. A pending lexer error is cloned
/// and converted into a [`ParseError`]; the error itself stays queued in the lexer.
fn peek(&mut self) -> Result<Option<&Token<'a>>, ParseError> {
    self.lexer
        .peek()
        .map(|item| item.as_ref().map_err(|e| ParseError::from(e.clone())))
        .transpose()
}
/// Consumes and returns the next token.
///
/// On success the token's span and kind are remembered in `last_span` and
/// `last_token` so later end-of-input errors can refer to them. Returns `Ok(None)`
/// when the lexer is exhausted and converts lexer errors into [`ParseError`].
fn bump(&mut self) -> Result<Option<Token<'a>>, ParseError> {
    let Some(item) = self.lexer.next() else {
        return Ok(None);
    };
    let tok = item.map_err(ParseError::from)?;
    self.last_token = Some(tok.kind.clone());
    self.last_span = Some(tok.span);
    Ok(Some(tok))
}
/// Picks a span to attach to an error raised when no further token is available.
///
/// Prefers the span of the last consumed token. If nothing has been consumed yet,
/// falls back to the span `[source_len - 1, source_len)`, or `None` when
/// `source_len` is zero.
fn fallback_span(&self) -> Option<Span> {
    match self.last_span {
        Some(span) => Some(span),
        None if self.source_len == 0 => None,
        None => Some(Span::new(
            self.source_len.saturating_sub(1),
            self.source_len,
        )),
    }
}
/// Renders the kind of the last consumed token as text, or `None` if no token has
/// been consumed yet.
fn fallback_found(&self) -> Option<String> {
    let kind = self.last_token.as_ref()?;
    Some(kind.to_string())
}
/// Convenience pairing of [`Self::fallback_found`] and [`Self::fallback_span`],
/// returned as `(found, span)`.
fn fallback_found_and_span(&self) -> (Option<String>, Option<Span>) {
    let found = self.fallback_found();
    let span = self.fallback_span();
    (found, span)
}
/// Consumes the next token only if its kind equals `expected`.
///
/// Returns `Ok(true)` when a token was consumed and `Ok(false)` when the next token
/// differs or the input is exhausted (nothing is consumed in that case).
fn eat(&mut self, expected: TokenKind<'a>) -> Result<bool, ParseError> {
    let is_match = self.peek()?.map_or(false, |tok| tok.kind == expected);
    if is_match {
        self.bump()?;
    }
    Ok(is_match)
}
/// Consumes the next token and returns it if its kind equals `expected`.
///
/// # Errors
///
/// * If the next token has a different kind, returns an unexpected-token error
///   carrying that token's text and span. The mismatching token is still consumed.
/// * If the input is exhausted, returns an unexpected-token error built from the
///   fallback "found"/span information (the last consumed token, if any).
/// * Lexer errors are propagated.
fn expect_token(&mut self, expected: TokenKind<'a>) -> Result<Token<'a>, ParseError> {
    // The expected kind is only formatted on the error paths, so the common
    // success path does not allocate a `String`.
    match self.bump()? {
        Some(tok) if tok.kind == expected => Ok(tok),
        Some(tok) => Err(ParseError::unexpected(
            expected.to_string(),
            Some(tok.kind.to_string()),
            Some(tok.span),
        )),
        None => {
            let (found, span) = self.fallback_found_and_span();
            Err(ParseError::unexpected(expected.to_string(), found, span))
        }
    }
}
/// Like [`Self::expect_token`], but discards the consumed token.
fn expect(&mut self, expected: TokenKind<'a>) -> Result<(), ParseError> {
    self.expect_token(expected).map(|_| ())
}
fn parse_stmt(&mut self) -> Result<Stmt, ParseError> {
match self.peek()? {
Some(tok) if matches!(tok.kind, TokenKind::Identifier(_)) => {
self.parse_stmt_starting_with_ident()
}
_ => Ok(Stmt::Expression(self.parse_expr_bp(0)?)),
}
}
fn parse_stmt_starting_with_ident(&mut self) -> Result<Stmt, ParseError> {
let (name, name_span) = match self.bump()? {
Some(sp) => match sp.kind {
TokenKind::Identifier(s) => (s.to_string(), sp.span),
other => {
return Err(ParseError::unexpected(
"identifier",
Some(other.to_string()),
Some(sp.span),
));
}
},
None => {
let (found, span) = self.fallback_found_and_span();
return Err(ParseError::unexpected("identifier", found, span));
}
};
if matches!(self.peek()?, Some(tok) if tok.kind == TokenKind::OpenParen)
&& self.lookahead_func_def_after_params()?
{
self.expect(TokenKind::OpenParen)?;
let params = self.parse_pattern_list()?;
self.expect(TokenKind::Assign)?;
let body = self.parse_expr_bp(0)?;
return Ok(Stmt::FunctionDefinition {
name,
arms: vec![FuncArm { params, body }],
});
}
if self.eat(TokenKind::Assign)? {
let value = self.parse_expr_bp(0)?;
return Ok(Stmt::Assignment { name, value });
}
let lhs = Expr::Identifier(name, name_span);
Ok(Stmt::Expression(self.parse_expr_bp_with_lhs(lhs, 0)?))
}
/// Parses a comma-separated list of patterns up to and including the closing `)`.
///
/// The caller must already have consumed the opening `(`. An immediate `)` yields an
/// empty list. A trailing comma is not accepted: after each comma another pattern is
/// required.
fn parse_pattern_list(&mut self) -> Result<Vec<Pattern>, ParseError> {
    let mut patterns = Vec::new();
    if self.eat(TokenKind::CloseParen)? {
        return Ok(patterns);
    }
    patterns.push(self.parse_pattern()?);
    while self.eat(TokenKind::Comma)? {
        patterns.push(self.parse_pattern()?);
    }
    self.expect(TokenKind::CloseParen)?;
    Ok(patterns)
}
/// Parses a single function-parameter pattern.
///
/// Accepted forms are an identifier or an integer, float or boolean literal. The
/// token is consumed even when it turns out not to be a valid pattern.
///
/// # Errors
///
/// Returns an unexpected-token error when the next token is none of the accepted
/// forms or when the input is exhausted; lexer errors are propagated.
fn parse_pattern(&mut self) -> Result<Pattern, ParseError> {
    // Identifiers are valid patterns too, so the diagnostic must not claim that
    // only a literal would have been accepted.
    const EXPECTED: &str = "identifier or literal";

    let Some(Token { kind, span }) = self.bump()? else {
        let (found, span) = self.fallback_found_and_span();
        return Err(ParseError::unexpected(EXPECTED, found, span));
    };
    match kind {
        TokenKind::Identifier(s) => Ok(Pattern::Identifier(s.to_string())),
        // The numeric base only matters to the lexer; the pattern stores the value.
        TokenKind::Integer { base: _, val } => Ok(Pattern::Lit(Literal::Int(val))),
        TokenKind::Float(x) => Ok(Pattern::Lit(Literal::Float(x))),
        TokenKind::Bool(b) => Ok(Pattern::Lit(Literal::Bool(b))),
        other => Err(ParseError::unexpected(
            EXPECTED,
            Some(other.to_string()),
            Some(span),
        )),
    }
}
/// Parses an expression whose infix operators all have a left binding power of at
/// least `min_bp`.
///
/// The expression starts with either a prefix operator applied to an operand parsed
/// at `PREFIX_BP`, or a primary expression. The result is then extended with calls
/// and infix operators by [`Self::parse_expr_bp_with_lhs`].
///
/// # Errors
///
/// Returns an `"expression"` unexpected-token error when the input is exhausted, and
/// propagates errors from the sub-parsers.
fn parse_expr_bp(&mut self, min_bp: u8) -> Result<Expr, ParseError> {
    // `None`: no token left; `Some(None)`: a token that is not a prefix operator;
    // `Some(Some(op))`: a prefix operator.
    let prefix = self.peek()?.map(|tok| prefix_op(&tok.kind));
    let lhs = match prefix {
        None => {
            let (found, span) = self.fallback_found_and_span();
            return Err(ParseError::unexpected("expression", found, span));
        }
        Some(None) => self.parse_primary()?,
        Some(Some(op)) => {
            let Some(op_tok) = self.bump()? else {
                unreachable!("prefix operator vanished after peek")
            };
            let operand = self.parse_expr_bp(PREFIX_BP)?;
            let span = span_cover(op_tok.span, operand.span());
            Expr::Unary {
                op,
                span,
                rhs: Box::new(operand),
            }
        }
    };
    self.parse_expr_bp_with_lhs(lhs, min_bp)
}
/// Extends an already-parsed left-hand side with call suffixes and infix operators.
///
/// Each iteration first checks for `(`: a call is attached to the current `lhs`
/// without consulting `min_bp`. Otherwise the next token must be an infix operator
/// whose left binding power is at least `min_bp`; its right operand is parsed with
/// the operator's right binding power. The loop stops at the first token that is
/// neither, or when the input is exhausted.
///
/// # Errors
///
/// Propagates lexer errors and errors from parsing arguments or right operands; an
/// unterminated argument list yields the error from expecting `)`.
fn parse_expr_bp_with_lhs(&mut self, mut lhs: Expr, min_bp: u8) -> Result<Expr, ParseError> {
    loop {
        // --- call suffix: `lhs(arg, ...)` ---
        if matches!(self.peek()?, Some(tok) if tok.kind == TokenKind::OpenParen) {
            let Some(open_tok) = self.bump()? else {
                unreachable!("open paren vanished after peek")
            };
            let mut args = Vec::new();
            let no_args =
                matches!(self.peek()?, Some(tok) if tok.kind == TokenKind::CloseParen);
            let close_tok = if no_args {
                let Some(tok) = self.bump()? else {
                    unreachable!("close paren expected")
                };
                tok
            } else {
                // At least one argument; every comma must be followed by another one.
                args.push(self.parse_expr_bp(0)?);
                while self.eat(TokenKind::Comma)? {
                    args.push(self.parse_expr_bp(0)?);
                }
                self.expect_token(TokenKind::CloseParen)?
            };
            // The call spans from the start of the callee through the closing paren.
            let through_open = span_cover(lhs.span(), open_tok.span);
            let span = span_cover(through_open, close_tok.span);
            lhs = Expr::Call {
                callee: Box::new(lhs),
                args,
                span,
            };
            continue;
        }

        // --- infix operator ---
        let next_infix = self.peek()?.and_then(|tok| {
            infix_bp(&tok.kind).map(|(op, lbp, rbp)| (op, lbp, rbp, tok.span))
        });
        let Some((op, lbp, rbp, op_span)) = next_infix else {
            break;
        };
        if lbp < min_bp {
            // The operator binds too loosely; leave it for an enclosing call.
            break;
        }
        self.bump()?;
        let rhs = self.parse_expr_bp(rbp)?;
        let span = span_cover(span_cover(lhs.span(), op_span), rhs.span());
        lhs = Expr::Binary {
            lhs: Box::new(lhs),
            op,
            span,
            rhs: Box::new(rhs),
        };
    }
    Ok(lhs)
}
/// Parses a primary expression: an integer, float or boolean literal, an identifier,
/// or a parenthesized expression.
///
/// A parenthesized expression becomes `Expr::Group` whose span runs from the opening
/// to the closing parenthesis.
///
/// # Errors
///
/// Returns an `"expression"` unexpected-token error for any other token (which is
/// still consumed) or when the input is exhausted; errors from the nested expression
/// and from the missing `)` are propagated.
fn parse_primary(&mut self) -> Result<Expr, ParseError> {
    let Some(Token { kind, span }) = self.bump()? else {
        let (found, span) = self.fallback_found_and_span();
        return Err(ParseError::unexpected("expression", found, span));
    };
    match kind {
        TokenKind::Integer { val, .. } => Ok(Expr::Lit(Literal::Int(val), span)),
        TokenKind::Float(x) => Ok(Expr::Lit(Literal::Float(x), span)),
        TokenKind::Bool(b) => Ok(Expr::Lit(Literal::Bool(b), span)),
        TokenKind::Identifier(text) => Ok(Expr::Identifier(text.to_string(), span)),
        TokenKind::OpenParen => {
            let inner = self.parse_expr_bp(0)?;
            let close_tok = self.expect_token(TokenKind::CloseParen)?;
            let group_span = span_cover(span, close_tok.span);
            Ok(Expr::Group(Box::new(inner), group_span))
        }
        other => Err(ParseError::unexpected(
            "expression",
            Some(other.to_string()),
            Some(span),
        )),
    }
}
/// Checks, without consuming anything, whether the upcoming tokens look like a
/// function-definition head: a balanced parenthesized group immediately followed by
/// `=`.
///
/// The scan runs on a clone of the token stream, so the parser's own position is
/// untouched. Returns `Ok(false)` when the next token is not `(`, when the
/// parentheses never balance, or when the token after the group is not `=`.
///
/// # Errors
///
/// A lexer error encountered while scanning inside the parentheses is returned as a
/// [`ParseError`]. Lexer errors at the first token or right after the closing
/// parenthesis are not reported here; they simply produce `Ok(false)`.
fn lookahead_func_def_after_params(&mut self) -> Result<bool, ParseError> {
    let mut probe = self.lexer.clone();

    let starts_with_paren = matches!(
        probe.next(),
        Some(Ok(Token {
            kind: TokenKind::OpenParen,
            ..
        }))
    );
    if !starts_with_paren {
        return Ok(false);
    }

    // Skip ahead to the `)` that closes the group opened above.
    let mut open_parens = 1usize;
    while open_parens > 0 {
        let Some(item) = probe.next() else {
            // Ran out of tokens before the group was closed.
            return Ok(false);
        };
        let tok = item.map_err(ParseError::from)?;
        match tok.kind {
            TokenKind::OpenParen => open_parens += 1,
            TokenKind::CloseParen => open_parens -= 1,
            _ => {}
        }
    }

    Ok(matches!(
        probe.next(),
        Some(Ok(Token {
            kind: TokenKind::Assign,
            ..
        }))
    ))
}
}
/// Returns the smallest span that contains both `a` and `b`: it starts at the
/// earlier of the two starts and ends at the later of the two ends.
fn span_cover(a: Span, b: Span) -> Span {
    let start = std::cmp::min(a.start, b.start);
    let end = std::cmp::max(a.end, b.end);
    Span::new(start, end)
}
/// Minimum binding power used when parsing the operand of a prefix operator. It is
/// larger than every left binding power returned by `infix_bp`, so no infix operator
/// is absorbed into the operand of a prefix operator.
const PREFIX_BP: u8 = 100;
/// Maps a token kind to the prefix (unary) operator it denotes, if any:
/// `Minus` is negation and `Bang` is logical not.
fn prefix_op(tok: &TokenKind) -> Option<UnaryOp> {
    let op = match tok {
        TokenKind::Bang => UnaryOp::Not,
        TokenKind::Minus => UnaryOp::Neg,
        _ => return None,
    };
    Some(op)
}
fn infix_bp(tok: &TokenKind) -> Option<(BinOp, u8, u8)> {
use BinOp::*;
match tok {
TokenKind::Or => Some((Or, 1, 2)),
TokenKind::And => Some((And, 2, 3)),
TokenKind::BitOr => Some((BitOr, 3, 4)),
TokenKind::BitAnd => Some((BitAnd, 4, 5)),
TokenKind::Caret => Some((Xor, 5, 6)),
TokenKind::Eq => Some((Eq, 6, 7)),
TokenKind::Ne => Some((Ne, 6, 7)),
TokenKind::Lt => Some((Lt, 7, 8)),
TokenKind::LtEq => Some((LtEq, 7, 8)),
TokenKind::Gt => Some((Gt, 7, 8)),
TokenKind::GtEq => Some((GtEq, 7, 8)),
TokenKind::BitShl => Some((BitShl, 8, 9)),
TokenKind::BitShr => Some((BitShr, 8, 9)),
TokenKind::Plus => Some((Add, 9, 10)),
TokenKind::Minus => Some((Sub, 9, 10)),
TokenKind::Star => Some((Mul, 10, 11)),
TokenKind::Slash => Some((Div, 10, 11)),
TokenKind::Percent => Some((Mod, 10, 11)),
_ => None,
}
}
// Unit tests for the parser. They drive `Parser` through real `Lexer` input and pin
// both the shape of the resulting AST and the spans / messages attached to errors.
#[cfg(test)]
mod tests {
use super::*;
use crate::lexer::token::Span;
// Lexes and parses `input` as one complete statement.
fn parse(input: &str) -> Result<Stmt, ParseError> {
let lexer = Lexer::new(input);
let mut parser = Parser::new(lexer);
parser.parse()
}
// `*` binds tighter than `+`, so `1 + 2 * 3` must parse as `1 + (2 * 3)`; the spans
// of every node are checked as well.
#[test]
fn parses_binary_precedence() {
let stmt = parse("1 + 2 * 3").unwrap();
let Expr::Binary { lhs, op, span, rhs } = expect_expr(stmt) else {
panic!("expected binary expression");
};
assert_eq!(op, BinOp::Add);
assert_eq!(span, Span::new(0, 9));
let Expr::Lit(Literal::Int(1), lhs_span) = *lhs else {
panic!("left operand should be literal 1");
};
assert_eq!(lhs_span, Span::new(0, 1));
let Expr::Binary {
lhs: mul_lhs,
op: mul_op,
span: mul_span,
rhs: mul_rhs,
} = *rhs
else {
panic!("expected multiplication on the right");
};
assert_eq!(mul_op, BinOp::Mul);
assert_eq!(mul_span, Span::new(4, 9));
let Expr::Lit(Literal::Int(2), lhs_mul_span) = *mul_lhs else {
panic!("left operand of multiplication should be 2");
};
assert_eq!(lhs_mul_span, Span::new(4, 5));
let Expr::Lit(Literal::Int(3), rhs_mul_span) = *mul_rhs else {
panic!("right operand of multiplication should be 3");
};
assert_eq!(rhs_mul_span, Span::new(8, 9));
}
// `name = value` produces `Stmt::Assignment` with the value's own span.
#[test]
fn parses_assignment_statement() {
let stmt = parse("answer = 42").unwrap();
let Stmt::Assignment { name, value } = stmt else {
panic!("expected assignment statement");
};
assert_eq!(name, "answer");
let Expr::Lit(Literal::Int(42), span) = value else {
panic!("assignment value should be literal 42");
};
assert_eq!(span, Span::new(9, 11));
}
// `f(x, 1) = x` produces a one-arm function definition whose parameters mix an
// identifier pattern with a literal pattern.
#[test]
fn parses_function_definition_with_patterns() {
let stmt = parse("f(x, 1) = x").unwrap();
let Stmt::FunctionDefinition { name, arms } = stmt else {
panic!("expected function definition");
};
assert_eq!(name, "f");
assert_eq!(arms.len(), 1);
let FuncArm { params, body } = arms.into_iter().next().unwrap();
assert_eq!(
params,
vec![
Pattern::Identifier("x".into()),
Pattern::Lit(Literal::Int(1)),
]
);
let Expr::Identifier(name, span) = body else {
panic!("function body should be identifier `x`");
};
assert_eq!(name, "x");
assert_eq!(span, Span::new(10, 11));
}
// Each remaining comparison / arithmetic operator maps to its `BinOp` and the
// binary node spans the whole input.
#[test]
fn parses_additional_binary_ops() {
let Expr::Binary { op, span, .. } = parse_expr("1 != 2") else {
panic!("expected binary expression");
};
assert_eq!(op, BinOp::Ne);
assert_eq!(span, Span::new(0, 6));
let Expr::Binary { op, span, .. } = parse_expr("3 < 4") else {
panic!("expected comparison expression");
};
assert_eq!(op, BinOp::Lt);
assert_eq!(span, Span::new(0, 5));
let Expr::Binary { op, span, .. } = parse_expr("4 > 3") else {
panic!("expected comparison expression");
};
assert_eq!(op, BinOp::Gt);
assert_eq!(span, Span::new(0, 5));
let Expr::Binary { op, span, .. } = parse_expr("5 - 2") else {
panic!("expected subtraction expression");
};
assert_eq!(op, BinOp::Sub);
assert_eq!(span, Span::new(0, 5));
let Expr::Binary { op, span, .. } = parse_expr("2 * 3") else {
panic!("expected multiplication expression");
};
assert_eq!(op, BinOp::Mul);
assert_eq!(span, Span::new(0, 5));
let Expr::Binary { op, span, .. } = parse_expr("8 / 4") else {
panic!("expected division expression");
};
assert_eq!(op, BinOp::Div);
assert_eq!(span, Span::new(0, 5));
let Expr::Binary { op, span, .. } = parse_expr("9 % 5") else {
panic!("expected modulo expression");
};
assert_eq!(op, BinOp::Mod);
assert_eq!(span, Span::new(0, 5));
}
// Compile-time style check: names every `BinOp` variant the parser can emit.
#[test]
fn binop_variants_constructible() {
use BinOp::*;
let all = [
And, BitAnd, Or, BitOr, Eq, Ne, Lt, LtEq, Gt, GtEq, Xor, BitShl, BitShr, Add, Sub, Mul,
Div, Mod,
];
assert_eq!(all.len(), 18);
}
// Prefix operators nest right-to-left: `!-x` is `Not(Neg(x))`.
#[test]
fn prefix_operators_chain_correctly() {
let Expr::Unary { op, rhs, .. } = parse_expr("!-x") else {
panic!("expected unary expression");
};
assert_eq!(op, UnaryOp::Not);
let Expr::Unary { op: inner_op, .. } = *rhs else {
panic!("expected nested unary");
};
assert_eq!(inner_op, UnaryOp::Neg);
}
// A call result can itself be called: `f()(1)` is a call whose callee is a call.
#[test]
fn nested_calls_parse_properly() {
let Expr::Call { callee, args, .. } = parse_expr("f()(1)") else {
panic!("expected call expression");
};
assert!(matches!(*callee, Expr::Call { .. }));
assert_eq!(args.len(), 1);
}
// A missing `)` is reported as an unexpected-token error rather than a panic.
#[test]
fn reports_unclosed_grouping() {
let err = parse("(1 + 2").unwrap_err();
assert!(
matches!(err, ParseError::UnexpectedToken { .. }),
"expected unexpected token error, got {err:?}"
);
}
// Anything left after a complete statement is rejected with an "end of input"
// expectation that points at the leftover token.
// NOTE(review): `span.offset()` is presumably the span's start offset — confirm.
#[test]
fn rejects_trailing_tokens() {
let err = parse("1 2").unwrap_err();
if let ParseError::UnexpectedToken {
expected,
found,
span,
} = err
{
assert_eq!(expected, "end of input");
assert_eq!(found, "'2'");
let span = span.expect("span present for trailing literal");
let offset: usize = span.offset();
assert_eq!(offset, 2);
} else {
panic!("expected EOF error for trailing literal, got {err:?}");
}
let err = parse("f(x) y").unwrap_err();
if let ParseError::UnexpectedToken {
expected,
found,
span,
} = err
{
assert_eq!(expected, "end of input");
assert_eq!(found, "'y'");
let span = span.expect("span present for trailing identifier");
let offset: usize = span.offset();
assert_eq!(offset, 5);
} else {
panic!("expected EOF error for trailing identifier, got {err:?}");
}
}
// The function-definition lookahead gives up (returns false) when the parentheses
// never balance.
#[test]
fn lookahead_returns_false_for_unbalanced_params() {
let mut parser = Parser::new(Lexer::new("(x + 1"));
assert!(
!parser.lookahead_func_def_after_params().unwrap(),
"should not treat unbalanced parens as function definition"
);
}
// A balanced group that is not followed by `=` is not a function definition.
#[test]
fn lookahead_requires_assign_after_params() {
let mut parser = Parser::new(Lexer::new("(x) 1"));
assert!(
!parser.lookahead_func_def_after_params().unwrap(),
"missing '=' should not trigger function definition"
);
}
// Unwraps an expression statement, panicking on any other statement kind.
fn expect_expr(stmt: Stmt) -> Expr {
match stmt {
Stmt::Expression(expr) => expr,
other => panic!("expected expression statement, got {other:?}"),
}
}
// Parses `input` and returns its expression; panics on parse errors or on
// non-expression statements.
fn parse_expr(input: &str) -> Expr {
expect_expr(parse(input).unwrap())
}
// `f() = 1` is a function definition with an empty parameter list.
#[test]
fn parses_empty_function_parameters() {
let stmt = parse("f() = 1").unwrap();
let Stmt::FunctionDefinition { name, arms } = stmt else {
panic!("expected function definition");
};
assert_eq!(name, "f");
let arm = arms.first().expect("one arm");
assert!(arm.params.is_empty(), "expected empty params");
}
// Errors raised by the lexer reach the caller as `ParseError::LexerError`.
#[test]
fn lexer_errors_surface_as_parse_errors() {
let err = parse("$").unwrap_err();
assert!(matches!(err, ParseError::LexerError { .. }));
}
// Empty input has no token and no source to point at: the error expects an
// expression, reports "end of input" as what was found, and carries no span.
#[test]
fn empty_input_reports_expression_at_eof() {
let err = parse("").unwrap_err();
match err {
ParseError::UnexpectedToken {
expected,
found,
span,
} => {
assert_eq!(expected, "expression");
assert_eq!(found, "end of input");
assert!(span.is_none(), "span should be None for empty input");
}
other @ ParseError::LexerError { .. } => panic!("unexpected error: {other:?}"),
}
}
// Boolean and float literals are accepted as parameter patterns.
#[test]
fn function_patterns_accept_bool_and_float_literals() {
let stmt = parse("f(true, 1.5) = 0").unwrap();
let Stmt::FunctionDefinition { arms, .. } = stmt else {
panic!("expected function definition");
};
let arm = arms.first().expect("one arm");
assert_eq!(
arm.params,
vec![
Pattern::Lit(Literal::Bool(true)),
Pattern::Lit(Literal::Float(1.5)),
]
);
}
// When input ends inside a call, the error falls back to the last consumed token
// (the `(`) for both its "found" text and its span.
#[test]
fn unterminated_call_reports_last_span() {
let err = parse("foo(").unwrap_err();
match err {
ParseError::UnexpectedToken {
found,
span: Some(span),
..
} => {
assert_eq!(found, "'('");
let offset: usize = span.offset();
assert_eq!(offset, 3, "span should point to '('");
}
other => panic!("unexpected error: {other:?}"),
}
}
// `^`, `|` and `&` map to `Xor`, `BitOr` and `BitAnd`.
#[test]
fn parses_bitwise_operators() {
let Expr::Binary { op, span, .. } = parse_expr("1 ^ 2") else {
panic!("expected xor expression");
};
assert_eq!(op, BinOp::Xor);
assert_eq!(span, Span::new(0, 5));
let Expr::Binary { op, span, .. } = parse_expr("1 | 2") else {
panic!("expected bitwise or expression");
};
assert_eq!(op, BinOp::BitOr);
assert_eq!(span, Span::new(0, 5));
let Expr::Binary { op, span, .. } = parse_expr("1 & 2") else {
panic!("expected bitwise and expression");
};
assert_eq!(op, BinOp::BitAnd);
assert_eq!(span, Span::new(0, 5));
}
// `<<` and `>>` map to `BitShl` and `BitShr`.
#[test]
fn parses_shift_operators() {
let Expr::Binary { op, span, .. } = parse_expr("1 << 2") else {
panic!("expected shift left expression");
};
assert_eq!(op, BinOp::BitShl);
assert_eq!(span, Span::new(0, 6));
let Expr::Binary { op, span, .. } = parse_expr("1 >> 2") else {
panic!("expected shift right expression");
};
assert_eq!(op, BinOp::BitShr);
assert_eq!(span, Span::new(0, 6));
}
// Shift sits between addition (tighter) and comparison (looser) in precedence.
#[test]
fn shift_binds_tighter_than_comparison_but_looser_than_addition() {
let Expr::Binary { op, lhs, .. } = parse_expr("1 + 2 << 3") else {
panic!("expected shift expression");
};
assert_eq!(op, BinOp::BitShl);
assert!(matches!(*lhs, Expr::Binary { op: BinOp::Add, .. }));
let Expr::Binary { op, lhs, .. } = parse_expr("1 << 2 < 3") else {
panic!("expected comparison expression");
};
assert_eq!(op, BinOp::Lt);
assert!(matches!(
*lhs,
Expr::Binary {
op: BinOp::BitShl,
..
}
));
}
// Comma-separated call arguments are all collected.
#[test]
fn parses_call_with_multiple_args() {
let Expr::Call { args, .. } = parse_expr("f(1, 2, 3)") else {
panic!("expected call expression");
};
assert_eq!(args.len(), 3);
}
// A lone identifier is an expression statement, not an assignment or definition.
#[test]
fn identifier_expression_parses_without_assignment() {
let stmt = parse("foo").unwrap();
let Expr::Identifier(name, span) = expect_expr(stmt) else {
panic!("expected identifier expression");
};
assert_eq!(name, "foo");
assert_eq!(span, Span::new(0, 3));
}
}