kitt_score 0.1.0

Decision engine at the core of Project KITT — in-memory stateful matching with pluggable scoring backends.
Documentation
//! Hand-written recursive-descent parser for the predicate DSL.
//!
//! The grammar has one literal kind (number), one variable kind (slot ref),
//! fixed-arity operators, and named function calls — small enough that a
//! hand-written parser stays short and readable. Pulling in `pest`, `chumsky`,
//! or `nom` would hide what the parser actually does; recursive descent is the
//! right tool at this size.

use super::ast::{BinOp, Expr};

/// Recursive-descent parser for the predicate DSL.
///
/// Precedence, loosest to tightest:
///
/// 1. `||`
/// 2. `&&`
/// 3. comparisons `<= >= == != < >` — at most one per operand pair, so
///    `a < b < c` is rejected as trailing input
/// 4. `+ -`
/// 5. `* /`
/// 6. unary `-` and `!`
/// 7. atoms: numbers, `$kind.attr` slot refs, `( … )` groups, `name(args…)` calls
///
/// Binary operators within one level associate to the left.
#[derive(Debug)]
pub struct Parser<'a> {
    /// Full source text being parsed.
    src: &'a str,
    /// Byte offset of the next unconsumed character in `src`.
    pos: usize,
}

impl<'a> Parser<'a> {
    /// Create a parser over the given source string.
    #[must_use]
    pub const fn new(src: &'a str) -> Self {
        Self { src, pos: 0 }
    }

    /// Parse `src` as a complete expression. Returns an error on syntax
    /// failure or trailing input.
    ///
    /// # Errors
    ///
    /// Any parse failure (unexpected token, unterminated paren, trailing
    /// input, malformed number, empty slot-ref component) returns a
    /// descriptive error string.
    pub fn parse(mut self) -> Result<Expr, String> {
        let e = self.expr()?;
        self.skip_ws();
        if self.pos != self.src.len() {
            return Err(format!(
                "trailing input at byte {}: {:?}",
                self.pos,
                &self.src[self.pos..]
            ));
        }
        Ok(e)
    }

    /// A full expression at the loosest precedence level.
    fn expr(&mut self) -> Result<Expr, String> {
        self.or()
    }

    /// `and ('||' and)*`, left-associative.
    fn or(&mut self) -> Result<Expr, String> {
        let mut lhs = self.and()?;
        while self.eat("||") {
            let rhs = self.and()?;
            lhs = Expr::Bin(BinOp::Or, Box::new(lhs), Box::new(rhs));
        }
        Ok(lhs)
    }

    /// `cmp ('&&' cmp)*`, left-associative.
    fn and(&mut self) -> Result<Expr, String> {
        let mut lhs = self.cmp()?;
        while self.eat("&&") {
            let rhs = self.cmp()?;
            lhs = Expr::Bin(BinOp::And, Box::new(lhs), Box::new(rhs));
        }
        Ok(lhs)
    }

    /// `sum (cmp_op sum)?` — a single, non-chaining comparison.
    fn cmp(&mut self) -> Result<Expr, String> {
        let lhs = self.sum()?;
        // Two-character operators are tried first so `<=` is never read as
        // `<` followed by a stray `=`.
        let op = if self.eat("<=") {
            BinOp::Le
        } else if self.eat(">=") {
            BinOp::Ge
        } else if self.eat("==") {
            BinOp::Eq
        } else if self.eat("!=") {
            BinOp::Ne
        } else if self.eat("<") {
            BinOp::Lt
        } else if self.eat(">") {
            BinOp::Gt
        } else {
            return Ok(lhs);
        };
        let rhs = self.sum()?;
        Ok(Expr::Bin(op, Box::new(lhs), Box::new(rhs)))
    }

    /// `prod (('+' | '-') prod)*`, left-associative.
    fn sum(&mut self) -> Result<Expr, String> {
        let mut lhs = self.prod()?;
        loop {
            let op = if self.eat("+") {
                BinOp::Add
            } else if self.eat("-") {
                BinOp::Sub
            } else {
                break;
            };
            let rhs = self.prod()?;
            lhs = Expr::Bin(op, Box::new(lhs), Box::new(rhs));
        }
        Ok(lhs)
    }

    /// `unary (('*' | '/') unary)*`, left-associative.
    fn prod(&mut self) -> Result<Expr, String> {
        let mut lhs = self.unary()?;
        loop {
            let op = if self.eat("*") {
                BinOp::Mul
            } else if self.eat("/") {
                BinOp::Div
            } else {
                break;
            };
            let rhs = self.unary()?;
            lhs = Expr::Bin(op, Box::new(lhs), Box::new(rhs));
        }
        Ok(lhs)
    }

    /// `('-' | '!') unary | call_or_atom`.
    fn unary(&mut self) -> Result<Expr, String> {
        if self.eat("-") {
            return Ok(Expr::Neg(Box::new(self.unary()?)));
        }
        if self.eat("!") {
            return Ok(Expr::Not(Box::new(self.unary()?)));
        }
        self.call_or_atom()
    }

    /// Slot ref, parenthesised group, function call, or numeric literal.
    fn call_or_atom(&mut self) -> Result<Expr, String> {
        self.skip_ws();
        if self.peek('$') {
            return self.slot_ref();
        }
        if self.eat("(") {
            let inner = self.expr()?;
            if !self.eat(")") {
                return Err("expected )".into());
            }
            return Ok(inner);
        }
        if self.peek_is(|c| c.is_ascii_alphabetic() || c == '_') {
            let name = self.ident();
            if self.eat("(") {
                let args = self.call_args()?;
                return Ok(Expr::Call(name, args));
            }
            return Err(format!(
                "bare identifier {name:?} is not a function call; slot refs start with $"
            ));
        }
        self.number()
    }

    /// Parse a call's argument list, starting just after `(` and consuming the
    /// closing `)`. Whitespace is allowed anywhere, including in `f( )`.
    fn call_args(&mut self) -> Result<Vec<Expr>, String> {
        let mut args = Vec::new();
        if self.eat(")") {
            return Ok(args);
        }
        args.push(self.expr()?);
        while self.eat(",") {
            args.push(self.expr()?);
        }
        if !self.eat(")") {
            return Err("expected )".into());
        }
        Ok(args)
    }

    /// `'$' ident '.' ident` → [`Expr::Slot`]. Both identifiers must be non-empty.
    fn slot_ref(&mut self) -> Result<Expr, String> {
        if !self.eat("$") {
            return Err("expected $".into());
        }
        let kind = self.ident();
        if kind.is_empty() {
            return Err(format!("expected slot kind after '$' at byte {}", self.pos));
        }
        if !self.eat(".") {
            return Err("expected '.' after kind".into());
        }
        let attr = self.ident();
        if attr.is_empty() {
            return Err(format!(
                "expected slot attribute after '.' at byte {}",
                self.pos
            ));
        }
        Ok(Expr::Slot { kind, attr })
    }

    /// Consume a run of ASCII alphanumerics / `_` at the cursor (possibly empty)
    /// and return it as an owned string. Does not skip leading whitespace.
    fn ident(&mut self) -> String {
        let rest = &self.src[self.pos..];
        let len = rest
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(rest.len());
        self.pos += len;
        rest[..len].to_owned()
    }

    /// Scan and parse a numeric literal: digits with an optional `.`, then an
    /// optional exponent `e`/`E`, optional sign, and at least one digit.
    ///
    /// A sign is only consumed inside an exponent, so `1-2` lexes as `1`, `-`,
    /// `2` (subtraction) while `2e-3` stays a single literal. An `e` that is not
    /// followed by exponent digits is left unconsumed.
    fn number(&mut self) -> Result<Expr, String> {
        self.skip_ws();
        let src = self.src;
        let bytes = src.as_bytes();
        let start = self.pos;
        let is_digit_at = |i: usize| bytes.get(i).is_some_and(u8::is_ascii_digit);

        // Mantissa: `1`, `1.5`, `.5`, `5.` — validity is left to `f64::from_str`.
        let mut end = start;
        while bytes
            .get(end)
            .is_some_and(|&b| b.is_ascii_digit() || b == b'.')
        {
            end += 1;
        }

        // Exponent: only taken when a digit follows `e`/`E` (after an optional sign).
        if matches!(bytes.get(end).copied(), Some(b'e' | b'E')) {
            let mut exp = end + 1;
            if matches!(bytes.get(exp).copied(), Some(b'+' | b'-')) {
                exp += 1;
            }
            if is_digit_at(exp) {
                end = exp;
                while is_digit_at(end) {
                    end += 1;
                }
            }
        }

        if end == start {
            return Err(format!("expected expression at byte {start}"));
        }
        self.pos = end;
        src[start..end]
            .parse::<f64>()
            .map(Expr::Num)
            .map_err(|e| format!("number parse error: {e}"))
    }

    /// Skip whitespace, then consume `s` if the input continues with it.
    /// Returns whether `s` was consumed; on `false` only whitespace is skipped.
    fn eat(&mut self, s: &str) -> bool {
        self.skip_ws();
        if self.src[self.pos..].starts_with(s) {
            self.pos += s.len();
            true
        } else {
            false
        }
    }

    /// Whether the next character (no whitespace skipping) is `c`.
    fn peek(&self, c: char) -> bool {
        self.src[self.pos..].starts_with(c)
    }

    /// Whether the next character (no whitespace skipping) satisfies `f`.
    fn peek_is<F: Fn(char) -> bool>(&self, f: F) -> bool {
        self.src[self.pos..].chars().next().is_some_and(f)
    }

    /// Advance the cursor past any Unicode whitespace.
    fn skip_ws(&mut self) {
        let rest = &self.src[self.pos..];
        self.pos += rest.len() - rest.trim_start().len();
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::super::ast::BinOp;
    use super::*;

    #[test]
    fn literal_number() {
        assert_eq!(Parser::new("1.5").parse().unwrap(), Expr::Num(1.5));
    }

    #[test]
    fn slot_ref() {
        let e = Parser::new("$audience.male_frac").parse().unwrap();
        assert_eq!(
            e,
            Expr::Slot {
                kind: "audience".into(),
                attr: "male_frac".into(),
            }
        );
    }

    #[test]
    fn precedence() {
        let e = Parser::new("1 + 2 * 3").parse().unwrap();
        let want = Expr::Bin(
            BinOp::Add,
            Box::new(Expr::Num(1.0)),
            Box::new(Expr::Bin(
                BinOp::Mul,
                Box::new(Expr::Num(2.0)),
                Box::new(Expr::Num(3.0)),
            )),
        );
        assert_eq!(e, want);
    }

    #[test]
    fn comparison_and_logical() {
        let _ = Parser::new("$a.b > 0.5 && $a.c < 10").parse().unwrap();
    }

    #[test]
    fn function_call() {
        let e = Parser::new("max(1, 2)").parse().unwrap();
        assert!(matches!(e, Expr::Call(ref name, _) if name == "max"));
    }

    /// Unary operators bind tighter than binary ones and apply to a whole
    /// parenthesised group.
    #[test]
    fn unary_and_grouping() {
        let e = Parser::new("-(1 + 2)").parse().unwrap();
        let want = Expr::Neg(Box::new(Expr::Bin(
            BinOp::Add,
            Box::new(Expr::Num(1.0)),
            Box::new(Expr::Num(2.0)),
        )));
        assert_eq!(e, want);

        let e = Parser::new("!$a.b").parse().unwrap();
        let want = Expr::Not(Box::new(Expr::Slot {
            kind: "a".into(),
            attr: "b".into(),
        }));
        assert_eq!(e, want);
    }

    #[test]
    fn errors_on_trailing_input() {
        assert!(Parser::new("1 2").parse().is_err());
    }

    /// Each input is malformed in a different way and must be rejected:
    /// empty source, bare identifier, unclosed group, unclosed call,
    /// slot ref without `.attr`, and a chained (non-associative) comparison.
    #[test]
    fn rejects_malformed_input() {
        for src in ["", "foo", "(1", "max(1", "$a", "1 < 2 < 3"] {
            assert!(
                Parser::new(src).parse().is_err(),
                "expected parse error for {src:?}"
            );
        }
    }
}