use std::collections::BTreeMap;
use std::fmt;
use std::sync::OnceLock;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::diagnostics::DiagnosticError;
/// Maps the symbol names used in a [`ShapeExpr`] to their concrete values.
///
/// Lookups during evaluation are by exact (case-sensitive) name.
pub type SymbolTable = BTreeMap<String, u64>;
/// A shape expression such as `n_heads * head_dim / 2`, kept as source text
/// and parsed lazily the first time it is evaluated.
#[derive(Debug)]
pub struct ShapeExpr {
    /// The expression exactly as written. Equality, `Display` and
    /// serialisation are all defined in terms of this text.
    source: String,
    /// Parse result, filled in on first evaluation. Parse errors are cached
    /// too, so a malformed expression is only parsed once.
    ast: OnceLock<Result<Node, DiagnosticError>>,
}
impl ShapeExpr {
pub fn from_str(s: impl Into<String>) -> Self {
Self {
source: s.into(),
ast: OnceLock::new(),
}
}
pub fn as_source(&self) -> &str {
&self.source
}
pub fn evaluate(&self, syms: &SymbolTable) -> Result<u64, DiagnosticError> {
match self.ast.get_or_init(|| parse(&self.source)) {
Ok(ast) => evaluate_ast(ast, &self.source, syms),
Err(e) => Err(e.clone()),
}
}
}
impl Clone for ShapeExpr {
    /// Clones the source text together with any already-computed parse result.
    ///
    /// The cache is a pure function of `source`, so carrying it over is always
    /// valid. A clone of an already-evaluated expression therefore does not
    /// re-parse. An uninitialised cache clones as empty, and the clone then
    /// parses lazily exactly as a fresh expression would.
    fn clone(&self) -> Self {
        Self {
            source: self.source.clone(),
            ast: self.ast.clone(),
        }
    }
}
/// Expressions compare equal when their source text is byte-for-byte identical.
/// Formatting therefore matters: `"a*b"` and `"a * b"` are not equal.
impl PartialEq for ShapeExpr {
    fn eq(&self, other: &Self) -> bool {
        str::eq(self.as_source(), other.as_source())
    }
}
/// String equality is a full equivalence relation, so `Eq` holds.
impl Eq for ShapeExpr {}
impl fmt::Display for ShapeExpr {
    /// Renders the original source text.
    ///
    /// Goes through [`fmt::Formatter::pad`] so that width, fill, alignment and
    /// precision flags behave exactly as for a plain `&str`. For example,
    /// `format!("{:>12}", expr)` right-aligns the expression instead of
    /// silently ignoring the width.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.source)
    }
}
/// Serialises as the bare source string, e.g. `"n_heads * head_dim"`.
impl Serialize for ShapeExpr {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // `str`'s own `Serialize` impl emits a single `serialize_str` call.
        self.as_source().serialize(serializer)
    }
}
/// Deserialises from a plain string.
///
/// The text is not parsed here; any syntax error is reported later by
/// [`ShapeExpr::evaluate`].
impl<'de> Deserialize<'de> for ShapeExpr {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        String::deserialize(deserializer).map(Self::from_str)
    }
}
/// Parsed shape-expression tree.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    /// Integer literal.
    Lit(u64),
    /// Symbol resolved through the [`SymbolTable`] at evaluation time.
    Sym(String),
    /// Product of two sub-expressions.
    Mul(Box<Node>, Box<Node>),
    /// Exact division of a sub-expression by a literal divisor. Evaluation
    /// fails if the dividend is not a multiple of the divisor, and the parser
    /// only builds this variant with a divisor greater than zero.
    Div(Box<Node>, u64),
}
/// Parses `src` into an expression tree.
///
/// # Errors
///
/// Returns [`DiagnosticError::ShapeExprParse`] if tokenisation fails, if the
/// token stream violates the grammar, or if tokens are left over after a
/// complete expression (e.g. `"a b"`).
fn parse(src: &str) -> Result<Node, DiagnosticError> {
    let mut parser = Parser {
        tokens: tokenize(src)?,
        pos: 0,
        src,
    };
    let root = parser.parse_expr()?;
    if parser.pos == parser.tokens.len() {
        Ok(root)
    } else {
        // `pos` is an index into the token list, not a character offset.
        Err(DiagnosticError::ShapeExprParse {
            expr: src.to_string(),
            message: format!("unexpected trailing token at position {}", parser.pos),
        })
    }
}
/// Cursor-based parser state over a pre-tokenised expression.
struct Parser<'a> {
    /// Token stream produced by `tokenize`.
    tokens: Vec<Token>,
    /// Index of the next unconsumed token in `tokens`.
    pos: usize,
    /// Original source text, copied into error diagnostics.
    src: &'a str,
}
/// Lexical tokens produced by `tokenize`.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Run of ASCII digits parsed as an unsigned integer.
    Integer(u64),
    /// Identifier of the form `[A-Za-z_][A-Za-z0-9_]*`.
    Symbol(String),
    /// `*`
    Star,
    /// `/`
    Slash,
}
impl Parser<'_> {
    /// Returns the next token without consuming it, or `None` at end of input.
    fn peek(&self) -> Option<&Token> {
        self.tokens.get(self.pos)
    }

    /// Consumes and returns the next token, or `None` at end of input.
    ///
    /// The cursor is only advanced when a token was actually consumed, so
    /// `pos` always stays within `0..=tokens.len()`.
    fn bump(&mut self) -> Option<Token> {
        let t = self.tokens.get(self.pos).cloned()?;
        self.pos += 1;
        Some(t)
    }

    /// Parses `atom ( '*' atom | '/' integer )*`.
    ///
    /// `*` and `/` share one precedence level and associate to the left, so
    /// `a * b / 2` means `(a * b) / 2`, as in ordinary arithmetic. This matters
    /// because division must be exact: `2 * 3 / 2` evaluates to `3` rather than
    /// failing on `3 / 2`.
    fn parse_expr(&mut self) -> Result<Node, DiagnosticError> {
        let mut left = self.parse_atom()?;
        loop {
            match self.peek() {
                Some(Token::Star) => {
                    self.bump();
                    let right = self.parse_atom()?;
                    left = Node::Mul(Box::new(left), Box::new(right));
                }
                Some(Token::Slash) => {
                    self.bump();
                    let divisor = self.parse_divisor()?;
                    left = Node::Div(Box::new(left), divisor);
                }
                _ => return Ok(left),
            }
        }
    }

    /// Parses the right-hand side of `/`, which must be a positive integer literal.
    ///
    /// Rejecting `0` here means every `Node::Div` carries a non-zero divisor.
    fn parse_divisor(&mut self) -> Result<u64, DiagnosticError> {
        match self.bump() {
            Some(Token::Integer(n)) if n > 0 => Ok(n),
            Some(Token::Integer(_)) => Err(self.err("division by zero literal")),
            _ => Err(self.err("`/` must be followed by a positive integer literal")),
        }
    }

    /// Parses a single integer literal or symbol.
    fn parse_atom(&mut self) -> Result<Node, DiagnosticError> {
        match self.bump() {
            Some(Token::Integer(n)) => Ok(Node::Lit(n)),
            Some(Token::Symbol(s)) => Ok(Node::Sym(s)),
            Some(other) => Err(self.err(&format!("expected atom, got {other:?}"))),
            None => Err(self.err("expected atom, got end of input")),
        }
    }

    /// Builds a parse error for the expression being parsed.
    fn err(&self, msg: &str) -> DiagnosticError {
        DiagnosticError::ShapeExprParse {
            expr: self.src.to_string(),
            message: msg.to_string(),
        }
    }
}
/// Splits `src` into a flat list of [`Token`]s.
///
/// Recognised input is ASCII whitespace (skipped), `*`, `/`, runs of ASCII
/// digits (parsed as `u64`), and identifiers of the form `[A-Za-z_][A-Za-z0-9_]*`.
///
/// # Errors
///
/// Returns [`DiagnosticError::ShapeExprParse`] when an integer literal does not
/// fit in `u64`, or when any other character is found. In the latter case the
/// message names the character and its byte offset within `src`.
fn tokenize(src: &str) -> Result<Vec<Token>, DiagnosticError> {
    let bytes = src.as_bytes();
    let mut out = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        let c = bytes[i];
        if c.is_ascii_whitespace() {
            i += 1;
        } else if c == b'*' {
            out.push(Token::Star);
            i += 1;
        } else if c == b'/' {
            out.push(Token::Slash);
            i += 1;
        } else if c.is_ascii_digit() {
            let start = i;
            while i < bytes.len() && bytes[i].is_ascii_digit() {
                i += 1;
            }
            let n: u64 = src[start..i]
                .parse()
                .map_err(|e| DiagnosticError::ShapeExprParse {
                    expr: src.to_string(),
                    message: format!("integer parse: {e}"),
                })?;
            out.push(Token::Integer(n));
        } else if c == b'_' || c.is_ascii_alphabetic() {
            let start = i;
            while i < bytes.len() && (bytes[i] == b'_' || bytes[i].is_ascii_alphanumeric()) {
                i += 1;
            }
            out.push(Token::Symbol(src[start..i].to_string()));
        } else {
            // Every branch above advances `i` only over ASCII bytes, so `i` sits
            // on a char boundary here. Decode the whole (possibly multi-byte)
            // character. `c as char` would reinterpret a lone UTF-8 lead byte
            // and report e.g. `Ã` for `é`.
            let ch = src
                .get(i..)
                .and_then(|rest| rest.chars().next())
                .unwrap_or(c as char);
            return Err(DiagnosticError::ShapeExprParse {
                expr: src.to_string(),
                message: format!("unexpected character `{ch}` at position {i}"),
            });
        }
    }
    Ok(out)
}
fn evaluate_ast(
node: &Node,
src: &str,
syms: &SymbolTable,
) -> Result<u64, DiagnosticError> {
match node {
Node::Lit(n) => Ok(*n),
Node::Sym(s) => syms.get(s).copied().ok_or_else(|| {
DiagnosticError::UnresolvedShapeSymbol {
expr: src.to_string(),
symbol: s.clone(),
}
}),
Node::Mul(a, b) => {
let av = evaluate_ast(a, src, syms)?;
let bv = evaluate_ast(b, src, syms)?;
av.checked_mul(bv).ok_or_else(|| {
DiagnosticError::ShapeExprParse {
expr: src.to_string(),
message: format!("overflow evaluating {av} * {bv}"),
}
})
}
Node::Div(a, d) => {
let av = evaluate_ast(a, src, syms)?;
if av % d != 0 {
return Err(DiagnosticError::ShapeExprParse {
expr: src.to_string(),
message: format!("{av} is not divisible by {d}"),
});
}
Ok(av / d)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`SymbolTable`] from `(name, value)` pairs.
    fn table<const N: usize>(entries: [(&str, u64); N]) -> SymbolTable {
        entries
            .iter()
            .map(|&(name, value)| (name.to_owned(), value))
            .collect()
    }

    /// Evaluates `src` against `syms`, panicking if evaluation fails.
    fn eval_ok(src: &str, syms: &SymbolTable) -> u64 {
        ShapeExpr::from_str(src).evaluate(syms).unwrap()
    }

    /// Evaluates `src` against `syms`, panicking if evaluation succeeds.
    fn eval_err(src: &str, syms: &SymbolTable) -> DiagnosticError {
        ShapeExpr::from_str(src).evaluate(syms).unwrap_err()
    }

    /// Unwraps a `ShapeExprParse` error into its message, panicking on any other variant.
    fn parse_message(err: DiagnosticError) -> String {
        match err {
            DiagnosticError::ShapeExprParse { message, .. } => message,
            other => panic!("expected ShapeExprParse, got {other:?}"),
        }
    }

    #[test]
    fn evaluates_literal() {
        assert_eq!(eval_ok("4096", &SymbolTable::new()), 4096);
    }

    #[test]
    fn evaluates_symbol() {
        assert_eq!(eval_ok("hidden", &table([("hidden", 4096)])), 4096);
    }

    #[test]
    fn evaluates_multiplication() {
        let syms = table([("n_heads", 32), ("head_dim", 128)]);
        assert_eq!(eval_ok("n_heads * head_dim", &syms), 32 * 128);
    }

    #[test]
    fn evaluates_division_by_literal() {
        assert_eq!(eval_ok("vocab / 4", &table([("vocab", 128)])), 32);
    }

    #[test]
    fn multiplication_is_left_associative_and_combines_with_division() {
        let syms = table([("n_heads", 32), ("head_dim", 128)]);
        assert_eq!(eval_ok("n_heads * head_dim / 2", &syms), (32 * 128) / 2);
    }

    #[test]
    fn unresolved_symbol_reports_the_missing_name() {
        match eval_err("missing * hidden", &table([("hidden", 4096)])) {
            DiagnosticError::UnresolvedShapeSymbol { expr, symbol } => {
                assert_eq!(symbol, "missing");
                assert_eq!(expr, "missing * hidden");
            }
            other => panic!("expected UnresolvedShapeSymbol, got {other:?}"),
        }
    }

    #[test]
    fn rejects_division_by_non_literal() {
        let syms = table([("hidden", 4096), ("n_heads", 32)]);
        let _ = parse_message(eval_err("hidden / n_heads", &syms));
    }

    #[test]
    fn rejects_non_integer_division() {
        let message = parse_message(eval_err("5 / 2", &SymbolTable::new()));
        assert!(message.contains("not divisible"), "msg: {message}");
    }

    #[test]
    fn rejects_division_by_zero_literal() {
        let _ = parse_message(eval_err("hidden / 0", &table([("hidden", 4096)])));
    }

    #[test]
    fn rejects_trailing_garbage() {
        let _ = parse_message(eval_err("hidden *", &table([("hidden", 4)])));
    }

    #[test]
    fn rejects_leading_operator() {
        let _ = parse_message(eval_err("/ 2", &SymbolTable::new()));
    }

    #[test]
    fn rejects_unknown_character() {
        let _ = parse_message(eval_err("hidden + 1", &table([("hidden", 4)])));
    }

    #[test]
    fn allows_underscores_in_symbols() {
        assert_eq!(eval_ok("n_kv_heads", &table([("n_kv_heads", 8)])), 8);
    }

    #[test]
    fn handles_whitespace_gracefully() {
        let syms = table([("n_heads", 2), ("head_dim", 4)]);
        assert_eq!(eval_ok(" n_heads * head_dim ", &syms), 8);
    }

    #[test]
    fn display_renders_source() {
        let e = ShapeExpr::from_str("a * b");
        assert_eq!(e.to_string(), "a * b");
    }
}