llama-rs 0.17.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
//! Shape expression grammar and evaluator.
//!
//! Grammar (v1, deliberately tiny):
//!
//! ```text
//! expr    := term ('*' term)*
//! term    := atom ('/' integer)*
//! atom    := symbol | integer
//! symbol  := [A-Za-z_][A-Za-z0-9_]*
//! integer := [0-9]+
//! ```
//!
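//! `*` is left-associative; `/` only accepts an integer literal on its
//! right-hand side (matches how every real-world shape expression in
//! llama-rs factorizes). Because `/` lives in `term`, it binds tighter
//! than `*`: `a * b / 2` parses as `a * (b / 2)`, and every division must
//! be exact (here `b` must be even) or evaluation fails.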

use std::collections::BTreeMap;
use std::fmt;
use std::sync::OnceLock;

use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::diagnostics::DiagnosticError;

/// Symbol bindings used when evaluating a [`ShapeExpr`]: each symbol name
/// maps to its concrete `u64` value. Backed by a `BTreeMap`, so entries are
/// kept sorted by name.
pub type SymbolTable = BTreeMap<String, u64>;

/// A single shape expression, e.g. `n_heads * head_dim`. Wraps the raw
/// source string plus a lazily-cached parse of its AST so repeated
/// `evaluate()` calls in the comparison loop don't re-parse the same string.
///
/// Construction never parses; syntax errors only surface from
/// [`ShapeExpr::evaluate`].
#[derive(Debug)]
pub struct ShapeExpr {
    /// The expression text exactly as supplied. This is also what
    /// `Display`, `Serialize` and equality operate on.
    source: String,
    /// Parse result, filled on the first `evaluate()`. A parse *error* is
    /// cached as well, so a malformed expression is only parsed once.
    ast: OnceLock<Result<Node, DiagnosticError>>,
}

impl ShapeExpr {
    pub fn from_str(s: impl Into<String>) -> Self {
        Self {
            source: s.into(),
            ast: OnceLock::new(),
        }
    }

    pub fn as_source(&self) -> &str {
        &self.source
    }

    /// Evaluate against the provided symbol table. First call parses
    /// the AST and caches it; subsequent calls reuse the cached AST.
    pub fn evaluate(&self, syms: &SymbolTable) -> Result<u64, DiagnosticError> {
        match self.ast.get_or_init(|| parse(&self.source)) {
            Ok(ast) => evaluate_ast(ast, &self.source, syms),
            Err(e) => Err(e.clone()),
        }
    }
}

impl Clone for ShapeExpr {
    fn clone(&self) -> Self {
        // Don't clone the cache — a clone is rare (tests, profile
        // merge) and forcing a re-parse keeps the type's hash / eq
        // behavior agnostic of cache state.
        Self::from_str(self.source.clone())
    }
}

impl PartialEq for ShapeExpr {
    /// Expressions compare by their exact source text; the parse cache is
    /// ignored. Textual comparison means `a*b` and `a * b` are *not* equal.
    fn eq(&self, other: &Self) -> bool {
        self.as_source() == other.as_source()
    }
}

/// String equality is a full equivalence relation, so `Eq` holds.
impl Eq for ShapeExpr {}

impl fmt::Display for ShapeExpr {
    /// Renders the expression exactly as written.
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str` so that width,
    /// alignment and precision flags (e.g. `{:>20}`) are honoured, the same
    /// way they are for a plain `str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.source)
    }
}

impl Serialize for ShapeExpr {
    /// Serializes as the bare expression string; the cached AST is never
    /// written out.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let text: &str = self.as_source();
        text.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for ShapeExpr {
    /// Accepts any string. Parsing is deferred to the first `evaluate()`, so
    /// a malformed expression deserializes successfully and only fails once
    /// it is evaluated.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        String::deserialize(deserializer).map(Self::from_str)
    }
}

// =============================================================================
// Parser
// =============================================================================

/// Parsed form of a shape expression.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    /// Integer literal.
    Lit(u64),
    /// Named symbol, looked up in a [`SymbolTable`] at evaluation time.
    Sym(String),
    /// Product of two sub-expressions (left operand evaluated first).
    Mul(Box<Node>, Box<Node>),
    /// Exact division by an integer literal. The parser only constructs this
    /// with a non-zero divisor.
    Div(Box<Node>, u64),
}

/// Parse `src` into an AST according to the module grammar.
///
/// # Errors
///
/// Returns [`DiagnosticError::ShapeExprParse`] for a tokenizer error, a
/// malformed expression, or tokens left over after a complete `expr`
/// (e.g. `hidden 4`).
fn parse(src: &str) -> Result<Node, DiagnosticError> {
    let tokens = tokenize(src)?;
    let mut p = Parser { tokens, pos: 0, src };
    let node = p.parse_expr()?;
    if let Some(extra) = p.peek() {
        // `pos` counts tokens, not bytes. The tokenizer's errors report byte
        // offsets as "position", so label this one explicitly as a token
        // index and name the offending token to avoid confusing the two.
        return Err(p.err(&format!(
            "unexpected trailing token {extra:?} at token index {}",
            p.pos
        )));
    }
    Ok(node)
}

/// Hand-written parser with one method per grammar rule, operating on a
/// pre-tokenized expression.
struct Parser<'a> {
    /// Tokens produced by `tokenize`.
    tokens: Vec<Token>,
    /// Index of the next token to consume (a token index, not a byte offset).
    pos: usize,
    /// Original source text, copied into every error value.
    src: &'a str,
}

/// Lexical token produced by `tokenize`.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Run of ASCII digits, parsed as `u64`.
    Integer(u64),
    /// Identifier matching `[A-Za-z_][A-Za-z0-9_]*`.
    Symbol(String),
    /// `*`
    Star,
    /// `/`
    Slash,
}

impl<'a> Parser<'a> {
    /// Returns the next unconsumed token without advancing.
    fn peek(&self) -> Option<&Token> {
        self.tokens.get(self.pos)
    }

    /// Consumes the next token. The cursor advances unconditionally, even
    /// past the end; callers convert a `None` into an error immediately.
    fn bump(&mut self) -> Option<Token> {
        let next = self.peek().cloned();
        self.pos += 1;
        next
    }

    /// `expr := term ('*' term)*`, folded to the left so that `a * b * c`
    /// becomes `Mul(Mul(a, b), c)`.
    fn parse_expr(&mut self) -> Result<Node, DiagnosticError> {
        let mut acc = self.parse_term()?;
        while self.peek() == Some(&Token::Star) {
            self.bump();
            let rhs = self.parse_term()?;
            acc = Node::Mul(Box::new(acc), Box::new(rhs));
        }
        Ok(acc)
    }

    /// `term := atom ('/' integer)*`. The divisor must be a non-zero integer
    /// literal; a symbol or any other token after `/` is rejected.
    fn parse_term(&mut self) -> Result<Node, DiagnosticError> {
        let mut acc = self.parse_atom()?;
        while self.peek() == Some(&Token::Slash) {
            self.bump();
            let divisor = match self.bump() {
                Some(Token::Integer(0)) => return Err(self.err("division by zero literal")),
                Some(Token::Integer(n)) => n,
                _ => {
                    return Err(self.err("`/` must be followed by a positive integer literal"))
                }
            };
            acc = Node::Div(Box::new(acc), divisor);
        }
        Ok(acc)
    }

    /// `atom := symbol | integer`.
    fn parse_atom(&mut self) -> Result<Node, DiagnosticError> {
        let Some(tok) = self.bump() else {
            return Err(self.err("expected atom, got end of input"));
        };
        match tok {
            Token::Integer(n) => Ok(Node::Lit(n)),
            Token::Symbol(name) => Ok(Node::Sym(name)),
            unexpected => Err(self.err(&format!("expected atom, got {unexpected:?}"))),
        }
    }

    /// Builds a `ShapeExprParse` error carrying the full source text.
    fn err(&self, msg: &str) -> DiagnosticError {
        let expr = self.src.to_owned();
        let message = msg.to_owned();
        DiagnosticError::ShapeExprParse { expr, message }
    }
}

/// Split `src` into [`Token`]s, skipping ASCII whitespace.
///
/// Symbols are `[A-Za-z_][A-Za-z0-9_]*`; integers are runs of ASCII digits
/// parsed as `u64`.
///
/// # Errors
///
/// Returns [`DiagnosticError::ShapeExprParse`] if an integer literal does not
/// fit in a `u64`, or if a character outside the grammar is encountered. The
/// position reported for an unexpected character is a byte offset into `src`.
fn tokenize(src: &str) -> Result<Vec<Token>, DiagnosticError> {
    let bytes = src.as_bytes();
    let mut out = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            c if c.is_ascii_whitespace() => i += 1,
            b'*' => {
                out.push(Token::Star);
                i += 1;
            }
            b'/' => {
                out.push(Token::Slash);
                i += 1;
            }
            b'0'..=b'9' => {
                let start = i;
                while i < bytes.len() && bytes[i].is_ascii_digit() {
                    i += 1;
                }
                let n: u64 =
                    src[start..i].parse().map_err(|e| DiagnosticError::ShapeExprParse {
                        expr: src.to_string(),
                        message: format!("integer parse: {e}"),
                    })?;
                out.push(Token::Integer(n));
            }
            b'_' | b'A'..=b'Z' | b'a'..=b'z' => {
                let start = i;
                while i < bytes.len() && (bytes[i] == b'_' || bytes[i].is_ascii_alphanumeric()) {
                    i += 1;
                }
                out.push(Token::Symbol(src[start..i].to_string()));
            }
            c => {
                // `i` only ever advances over ASCII bytes, so it sits on a char
                // boundary here. Decode the whole char: `c as char` would turn
                // the lead byte of a multi-byte char (e.g. `×`) into mojibake.
                let ch = src
                    .get(i..)
                    .and_then(|rest| rest.chars().next())
                    .unwrap_or(char::from(c));
                return Err(DiagnosticError::ShapeExprParse {
                    expr: src.to_string(),
                    message: format!("unexpected character `{ch}` at position {i}"),
                });
            }
        }
    }
    Ok(out)
}

fn evaluate_ast(
    node: &Node,
    src: &str,
    syms: &SymbolTable,
) -> Result<u64, DiagnosticError> {
    match node {
        Node::Lit(n) => Ok(*n),
        Node::Sym(s) => syms.get(s).copied().ok_or_else(|| {
            DiagnosticError::UnresolvedShapeSymbol {
                expr: src.to_string(),
                symbol: s.clone(),
            }
        }),
        Node::Mul(a, b) => {
            let av = evaluate_ast(a, src, syms)?;
            let bv = evaluate_ast(b, src, syms)?;
            av.checked_mul(bv).ok_or_else(|| {
                DiagnosticError::ShapeExprParse {
                    expr: src.to_string(),
                    message: format!("overflow evaluating {av} * {bv}"),
                }
            })
        }
        Node::Div(a, d) => {
            let av = evaluate_ast(a, src, syms)?;
            if av % d != 0 {
                return Err(DiagnosticError::ShapeExprParse {
                    expr: src.to_string(),
                    message: format!("{av} is not divisible by {d}"),
                });
            }
            Ok(av / d)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a symbol table from `(name, value)` pairs.
    fn syms<const N: usize>(pairs: [(&str, u64); N]) -> SymbolTable {
        pairs.iter().map(|(k, v)| ((*k).to_string(), *v)).collect()
    }

    /// Assert that `err` is a `ShapeExprParse` whose message contains `needle`.
    fn assert_parse_error(err: DiagnosticError, needle: &str) {
        match err {
            DiagnosticError::ShapeExprParse { message, .. } => {
                assert!(message.contains(needle), "msg: {message}");
            }
            other => panic!("expected ShapeExprParse, got {other:?}"),
        }
    }

    #[test]
    fn evaluates_literal() {
        let e = ShapeExpr::from_str("4096");
        assert_eq!(e.evaluate(&SymbolTable::new()).unwrap(), 4096);
    }

    #[test]
    fn evaluates_symbol() {
        let e = ShapeExpr::from_str("hidden");
        assert_eq!(e.evaluate(&syms([("hidden", 4096)])).unwrap(), 4096);
    }

    #[test]
    fn evaluates_multiplication() {
        let e = ShapeExpr::from_str("n_heads * head_dim");
        assert_eq!(
            e.evaluate(&syms([("n_heads", 32), ("head_dim", 128)])).unwrap(),
            32 * 128
        );
    }

    #[test]
    fn evaluates_division_by_literal() {
        let e = ShapeExpr::from_str("vocab / 4");
        assert_eq!(e.evaluate(&syms([("vocab", 128)])).unwrap(), 32);
    }

    #[test]
    fn division_binds_tighter_than_multiplication() {
        // term := atom ('/' integer)*, so `a * b / 2` is `a * (b / 2)`.
        let e = ShapeExpr::from_str("n_heads * head_dim / 2");
        assert_eq!(
            e.evaluate(&syms([("n_heads", 32), ("head_dim", 128)])).unwrap(),
            32 * (128 / 2)
        );
        // The exactness check applies to `b / 2` alone: with b = 3 it fails
        // even though `a * b` (= 6) would be divisible by 2.
        let odd = ShapeExpr::from_str("a * b / 2");
        let err = odd.evaluate(&syms([("a", 2), ("b", 3)])).unwrap_err();
        assert_parse_error(err, "3 is not divisible by 2");
    }

    #[test]
    fn unresolved_symbol_reports_the_missing_name() {
        let e = ShapeExpr::from_str("missing * hidden");
        let err = e.evaluate(&syms([("hidden", 4096)])).unwrap_err();
        match err {
            DiagnosticError::UnresolvedShapeSymbol { expr, symbol } => {
                assert_eq!(expr, "missing * hidden");
                assert_eq!(symbol, "missing");
            }
            other => panic!("expected UnresolvedShapeSymbol, got {other:?}"),
        }
    }

    #[test]
    fn rejects_division_by_non_literal() {
        let e = ShapeExpr::from_str("hidden / n_heads");
        let err = e.evaluate(&syms([("hidden", 4096), ("n_heads", 32)])).unwrap_err();
        assert!(matches!(err, DiagnosticError::ShapeExprParse { .. }));
    }

    #[test]
    fn rejects_non_integer_division() {
        let e = ShapeExpr::from_str("5 / 2");
        let err = e.evaluate(&SymbolTable::new()).unwrap_err();
        assert_parse_error(err, "not divisible");
    }

    #[test]
    fn rejects_division_by_zero_literal() {
        let e = ShapeExpr::from_str("hidden / 0");
        let err = e.evaluate(&syms([("hidden", 4096)])).unwrap_err();
        assert_parse_error(err, "division by zero literal");
    }

    #[test]
    fn rejects_dangling_operator() {
        // Fails while looking for the right-hand atom, not as trailing input.
        let e = ShapeExpr::from_str("hidden *");
        let err = e.evaluate(&syms([("hidden", 4)])).unwrap_err();
        assert_parse_error(err, "end of input");
    }

    #[test]
    fn rejects_trailing_token() {
        // A complete expression followed by an extra atom.
        let e = ShapeExpr::from_str("hidden 4");
        let err = e.evaluate(&syms([("hidden", 4)])).unwrap_err();
        assert!(matches!(err, DiagnosticError::ShapeExprParse { .. }));
    }

    #[test]
    fn rejects_leading_operator() {
        let e = ShapeExpr::from_str("/ 2");
        let err = e.evaluate(&SymbolTable::new()).unwrap_err();
        assert!(matches!(err, DiagnosticError::ShapeExprParse { .. }));
    }

    #[test]
    fn rejects_unknown_character() {
        let e = ShapeExpr::from_str("hidden + 1");
        let err = e.evaluate(&syms([("hidden", 4)])).unwrap_err();
        assert_parse_error(err, "`+`");
    }

    #[test]
    fn rejects_empty_expression() {
        for src in ["", "   "] {
            let err = ShapeExpr::from_str(src).evaluate(&SymbolTable::new()).unwrap_err();
            assert_parse_error(err, "end of input");
        }
    }

    #[test]
    fn rejects_oversized_integer_literal() {
        let e = ShapeExpr::from_str("99999999999999999999");
        let err = e.evaluate(&SymbolTable::new()).unwrap_err();
        assert_parse_error(err, "integer parse");
    }

    #[test]
    fn reports_multiplication_overflow() {
        let e = ShapeExpr::from_str("18446744073709551615 * 2");
        let err = e.evaluate(&SymbolTable::new()).unwrap_err();
        assert_parse_error(err, "overflow");
    }

    #[test]
    fn allows_underscores_in_symbols() {
        let e = ShapeExpr::from_str("n_kv_heads");
        assert_eq!(e.evaluate(&syms([("n_kv_heads", 8)])).unwrap(), 8);
    }

    #[test]
    fn handles_whitespace_gracefully() {
        let e = ShapeExpr::from_str("  n_heads  *  head_dim  ");
        assert_eq!(
            e.evaluate(&syms([("n_heads", 2), ("head_dim", 4)])).unwrap(),
            8
        );
    }

    #[test]
    fn caches_ast_after_first_evaluation() {
        let e = ShapeExpr::from_str("a * 2");
        assert!(e.ast.get().is_none());
        assert_eq!(e.evaluate(&syms([("a", 3)])).unwrap(), 6);
        assert!(e.ast.get().is_some());
        // The cached AST is reused with a different symbol table.
        assert_eq!(e.evaluate(&syms([("a", 5)])).unwrap(), 10);
    }

    #[test]
    fn parse_errors_are_returned_on_every_call() {
        let e = ShapeExpr::from_str("a +");
        for _ in 0..2 {
            let err = e.evaluate(&syms([("a", 1)])).unwrap_err();
            assert!(matches!(err, DiagnosticError::ShapeExprParse { .. }));
        }
    }

    #[test]
    fn clone_resets_cache_and_equality_uses_source() {
        let a = ShapeExpr::from_str("a * b");
        a.evaluate(&syms([("a", 1), ("b", 2)])).unwrap();
        let b = a.clone();
        assert_eq!(a, b);
        assert!(b.ast.get().is_none());
        assert_ne!(ShapeExpr::from_str("a*b"), ShapeExpr::from_str("a * b"));
    }

    #[test]
    fn display_renders_source() {
        let e = ShapeExpr::from_str("a * b");
        assert_eq!(format!("{e}"), "a * b");
    }
}