use super::ast::{BinOp, Expr};
pub struct Parser<'a> {
src: &'a str,
pos: usize,
}
impl<'a> Parser<'a> {
#[must_use]
pub const fn new(src: &'a str) -> Self {
Self { src, pos: 0 }
}
pub fn parse(mut self) -> Result<Expr, String> {
let e = self.expr()?;
self.skip_ws();
if self.pos != self.src.len() {
return Err(format!(
"trailing input at byte {}: {:?}",
self.pos,
&self.src[self.pos..]
));
}
Ok(e)
}
fn expr(&mut self) -> Result<Expr, String> {
self.or()
}
fn or(&mut self) -> Result<Expr, String> {
let mut lhs = self.and()?;
loop {
self.skip_ws();
if self.eat("||") {
let rhs = self.and()?;
lhs = Expr::Bin(BinOp::Or, Box::new(lhs), Box::new(rhs));
} else {
break;
}
}
Ok(lhs)
}
fn and(&mut self) -> Result<Expr, String> {
let mut lhs = self.cmp()?;
loop {
self.skip_ws();
if self.eat("&&") {
let rhs = self.cmp()?;
lhs = Expr::Bin(BinOp::And, Box::new(lhs), Box::new(rhs));
} else {
break;
}
}
Ok(lhs)
}
fn cmp(&mut self) -> Result<Expr, String> {
let lhs = self.sum()?;
self.skip_ws();
let op = if self.eat("<=") {
BinOp::Le
} else if self.eat(">=") {
BinOp::Ge
} else if self.eat("==") {
BinOp::Eq
} else if self.eat("!=") {
BinOp::Ne
} else if self.peek('<') && !self.peek2('=') {
self.pos += 1;
BinOp::Lt
} else if self.peek('>') && !self.peek2('=') {
self.pos += 1;
BinOp::Gt
} else {
return Ok(lhs);
};
let rhs = self.sum()?;
Ok(Expr::Bin(op, Box::new(lhs), Box::new(rhs)))
}
fn sum(&mut self) -> Result<Expr, String> {
let mut lhs = self.prod()?;
loop {
self.skip_ws();
if self.eat("+") {
let rhs = self.prod()?;
lhs = Expr::Bin(BinOp::Add, Box::new(lhs), Box::new(rhs));
} else if self.eat("-") {
let rhs = self.prod()?;
lhs = Expr::Bin(BinOp::Sub, Box::new(lhs), Box::new(rhs));
} else {
break;
}
}
Ok(lhs)
}
fn prod(&mut self) -> Result<Expr, String> {
let mut lhs = self.unary()?;
loop {
self.skip_ws();
if self.eat("*") {
let rhs = self.unary()?;
lhs = Expr::Bin(BinOp::Mul, Box::new(lhs), Box::new(rhs));
} else if self.eat("/") {
let rhs = self.unary()?;
lhs = Expr::Bin(BinOp::Div, Box::new(lhs), Box::new(rhs));
} else {
break;
}
}
Ok(lhs)
}
fn unary(&mut self) -> Result<Expr, String> {
self.skip_ws();
if self.eat("-") {
return Ok(Expr::Neg(Box::new(self.unary()?)));
}
if self.eat("!") {
return Ok(Expr::Not(Box::new(self.unary()?)));
}
self.call_or_atom()
}
fn call_or_atom(&mut self) -> Result<Expr, String> {
self.skip_ws();
if self.peek('$') {
return self.slot_ref();
}
if self.peek('(') {
self.pos += 1;
let e = self.expr()?;
self.skip_ws();
if !self.eat(")") {
return Err("expected )".into());
}
return Ok(e);
}
if self.peek_is(|c| c.is_ascii_alphabetic() || c == '_') {
let name = self.ident();
self.skip_ws();
if self.eat("(") {
let mut args = Vec::new();
if !self.peek(')') {
args.push(self.expr()?);
loop {
self.skip_ws();
if !self.eat(",") {
break;
}
args.push(self.expr()?);
}
}
self.skip_ws();
if !self.eat(")") {
return Err("expected )".into());
}
return Ok(Expr::Call(name, args));
}
return Err(format!(
"bare identifier {name:?} is not a function call; slot refs start with $"
));
}
self.number()
}
fn slot_ref(&mut self) -> Result<Expr, String> {
if !self.eat("$") {
return Err("expected $".into());
}
let kind = self.ident();
if !self.eat(".") {
return Err("expected '.' after kind".into());
}
let attr = self.ident();
Ok(Expr::Slot { kind, attr })
}
fn ident(&mut self) -> String {
let start = self.pos;
while let Some(c) = self.src[self.pos..].chars().next() {
if c.is_ascii_alphanumeric() || c == '_' {
self.pos += c.len_utf8();
} else {
break;
}
}
self.src[start..self.pos].to_owned()
}
fn number(&mut self) -> Result<Expr, String> {
self.skip_ws();
let start = self.pos;
while let Some(c) = self.src[self.pos..].chars().next() {
if c.is_ascii_digit() || c == '.' || c == '-' || c == 'e' || c == 'E' {
self.pos += c.len_utf8();
} else {
break;
}
}
self.src[start..self.pos]
.parse::<f64>()
.map(Expr::Num)
.map_err(|e| format!("number parse error: {e}"))
}
fn eat(&mut self, s: &str) -> bool {
self.skip_ws();
if self.src[self.pos..].starts_with(s) {
self.pos += s.len();
true
} else {
false
}
}
fn peek(&self, c: char) -> bool {
self.src[self.pos..].starts_with(c)
}
fn peek2(&self, c: char) -> bool {
let mut it = self.src[self.pos..].chars();
it.next();
it.next() == Some(c)
}
fn peek_is<F: Fn(char) -> bool>(&self, f: F) -> bool {
self.src[self.pos..].chars().next().is_some_and(f)
}
fn skip_ws(&mut self) {
while let Some(c) = self.src[self.pos..].chars().next() {
if c.is_whitespace() {
self.pos += c.len_utf8();
} else {
break;
}
}
}
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::super::ast::BinOp;
    use super::*;

    /// Boxes a numeric literal so expected trees stay readable.
    fn num(x: f64) -> Box<Expr> {
        Box::new(Expr::Num(x))
    }

    /// Builds a boxed slot reference `$kind.attr`.
    fn slot(kind: &str, attr: &str) -> Box<Expr> {
        Box::new(Expr::Slot {
            kind: kind.into(),
            attr: attr.into(),
        })
    }

    #[test]
    fn literal_number() {
        assert_eq!(Parser::new("1.5").parse().unwrap(), Expr::Num(1.5));
    }

    #[test]
    fn slot_ref() {
        let e = Parser::new("$audience.male_frac").parse().unwrap();
        assert_eq!(
            e,
            Expr::Slot {
                kind: "audience".into(),
                attr: "male_frac".into(),
            }
        );
    }

    #[test]
    fn precedence() {
        let e = Parser::new("1 + 2 * 3").parse().unwrap();
        let want = Expr::Bin(
            BinOp::Add,
            num(1.0),
            Box::new(Expr::Bin(BinOp::Mul, num(2.0), num(3.0))),
        );
        assert_eq!(e, want);
    }

    #[test]
    fn parentheses_override_precedence() {
        let e = Parser::new("(1 + 2) * 3").parse().unwrap();
        let want = Expr::Bin(
            BinOp::Mul,
            Box::new(Expr::Bin(BinOp::Add, num(1.0), num(2.0))),
            num(3.0),
        );
        assert_eq!(e, want);
    }

    #[test]
    fn subtraction_is_left_associative() {
        let e = Parser::new("10 - 4 - 3").parse().unwrap();
        let want = Expr::Bin(
            BinOp::Sub,
            Box::new(Expr::Bin(BinOp::Sub, num(10.0), num(4.0))),
            num(3.0),
        );
        assert_eq!(e, want);
    }

    #[test]
    fn unary_binds_tighter_than_binary() {
        let e = Parser::new("-2 * 3").parse().unwrap();
        let want = Expr::Bin(BinOp::Mul, Box::new(Expr::Neg(num(2.0))), num(3.0));
        assert_eq!(e, want);

        let e = Parser::new("!$a.b").parse().unwrap();
        assert_eq!(e, Expr::Not(slot("a", "b")));
    }

    #[test]
    fn comparison_and_logical() {
        let e = Parser::new("$a.b > 0.5 && $a.c < 10").parse().unwrap();
        let want = Expr::Bin(
            BinOp::And,
            Box::new(Expr::Bin(BinOp::Gt, slot("a", "b"), num(0.5))),
            Box::new(Expr::Bin(BinOp::Lt, slot("a", "c"), num(10.0))),
        );
        assert_eq!(e, want);
    }

    #[test]
    fn or_binds_looser_than_and() {
        let e = Parser::new("1 || 2 && 3").parse().unwrap();
        let want = Expr::Bin(
            BinOp::Or,
            num(1.0),
            Box::new(Expr::Bin(BinOp::And, num(2.0), num(3.0))),
        );
        assert_eq!(e, want);
    }

    #[test]
    fn function_call() {
        let e = Parser::new("max(1, 2)").parse().unwrap();
        assert_eq!(
            e,
            Expr::Call("max".into(), vec![Expr::Num(1.0), Expr::Num(2.0)])
        );
    }

    #[test]
    fn zero_argument_call() {
        let e = Parser::new("f()").parse().unwrap();
        assert_eq!(e, Expr::Call("f".into(), vec![]));
    }

    #[test]
    fn errors_on_trailing_input() {
        assert!(Parser::new("1 2").parse().is_err());
    }

    #[test]
    fn errors_on_malformed_input() {
        // Bare identifiers are not values; only calls and `$kind.attr` refs are.
        assert!(Parser::new("foo").parse().is_err());
        // A slot reference needs the `.attr` part.
        assert!(Parser::new("$a").parse().is_err());
        // Unbalanced parenthesis.
        assert!(Parser::new("(1").parse().is_err());
        // Comparisons do not chain.
        assert!(Parser::new("1 < 2 < 3").parse().is_err());
    }
}