use crate::{
FrameType,
basis_grammar::{Alias, expand_compound_alias},
builder::BasisElem,
clifford::{
ExteriorProduct, InnerProduct, Involution, LeftContraction, RegressiveProduct as _,
RightContraction, SynAlgebra, SynMultivector,
},
spec::{AlgebraSpec, ShapeDef},
symbol_field::CoefficientField,
};
use abstalg::{AbelianGroup, CommuntativeMonoid, Monoid, Semigroup};
use itertools::{EitherOrBoth, Itertools};
use proc_macro2::Span;
use std::collections::{BTreeMap, HashMap, HashSet};
use syn::{
BinOp, Expr, ExprClosure, FieldValue, Member, Stmt, Token,
punctuated::Punctuated,
spanned::Spanned,
visit::{self, Visit},
};
/// Checks the number of arguments of a method call.
///
/// Expands to `$args` when `$args.len() == $arity`. Otherwise it makes the
/// enclosing function return `Err(syn::Error)` with the message
/// "unexpected arity", spanned at `$args`. Because the failure path is a
/// `return Err(..)`, the macro is only usable inside functions that return
/// `syn::Result<_>`.
macro_rules! arity_args {
($arity:expr, $args:expr) => {
if $args.len() == $arity {
$args
} else {
return Err(syn::Error::new($args.span(), "unexpected arity"));
}
};
}
/// Visitor state used to evaluate the body of an `expr!` closure into a
/// multivector and to turn that multivector into a shape literal.
pub struct Optimizer {
/// Algebra specification; supplies the algebra, the shape definitions and
/// the module name used when emitting shape literals.
pub spec: AlgebraSpec,
/// Algebra obtained from `spec.algebra()` in [`Optimizer::new`].
pub alg: SynAlgebra<FrameType>,
/// Cache of evaluated sub-expressions, keyed by the address of the
/// `syn::Expr` node (see `expr_key`).
values: HashMap<usize, SynMultivector<FrameType>>,
/// Stack of name -> multivector bindings. Index 0 holds the alias values;
/// further scopes are pushed and popped while visiting closures and blocks.
scopes: Vec<HashMap<String, SynMultivector<FrameType>>>,
/// Names of the materialized aliases, handed to `expand_compound_alias`.
basis_aliases: HashSet<String>,
/// Value of the visited closure body; set by `visit_expr_closure` and
/// taken by `shape_literal`.
pub mv: Option<SynMultivector<FrameType>>,
/// Error recorded by the `Visit` methods; once set they return early.
pub err: Option<syn::Error>,
}
impl Optimizer {
pub fn new(spec: AlgebraSpec, aliases: BTreeMap<Alias, BasisElem>) -> syn::Result<Self> {
let alg = spec.algebra()?;
let alias_values = Self::materialize_aliases(&alg, aliases);
Ok(Self {
spec,
alg,
values: HashMap::new(),
basis_aliases: alias_values.keys().cloned().collect(),
scopes: vec![alias_values, HashMap::new()],
mv: None,
err: None,
})
}
/// First pass: evaluates the closure and replaces its body with the shape
/// literal built from the resulting multivector.
///
/// # Errors
/// Returns the error recorded while visiting the closure, or the error
/// produced by `shape_literal`.
pub fn run_pass1(mut self, fun: &mut syn::ExprClosure) -> syn::Result<Self> {
    self.visit_expr_closure(fun);
    match self.err.take() {
        Some(err) => Err(err),
        None => {
            let literal = self.shape_literal()?;
            fun.body = Box::new(literal);
            Ok(self)
        }
    }
}
/// Second pass; currently a no-op that returns `self` unchanged and does not
/// touch the closure.
pub fn run_pass2(self, _fun: &mut syn::ExprClosure) -> syn::Result<Self> {
Ok(self)
}
/// Converts every alias into a multivector of `alg`, keyed by the alias name.
///
/// Each `(mask, coeff)` term of the alias becomes the basis blade for `mask`
/// with the coefficient value wrapped as a numeric expression.
fn materialize_aliases(
    alg: &SynAlgebra<FrameType>,
    aliases: BTreeMap<Alias, BasisElem>,
) -> HashMap<String, SynMultivector<FrameType>> {
    let mut materialized = HashMap::with_capacity(aliases.len());
    for (alias, elem) in aliases {
        let terms = elem.vectors.iter().map(|(mask, coeff)| {
            let value = coeff.value();
            let lit_expr = syn::parse_quote! { #value };
            (alg.basis_blade(*mask), alg.syn.wrap_numeric_expr(lit_expr))
        });
        let multivector = alg.from_terms(terms);
        materialized.insert(alias.name, multivector);
    }
    materialized
}
fn expr_key(expr: &Expr) -> usize {
expr as *const _ as usize
}
/// Stores `value` as the evaluated result of `expr`, replacing any earlier entry.
fn set_value(&mut self, expr: &Expr, value: SynMultivector<FrameType>) {
    let key = Self::expr_key(expr);
    self.values.insert(key, value);
}
/// Returns the cached value of `expr`, if one has been stored.
fn value_of(&self, expr: &Expr) -> Option<&SynMultivector<FrameType>> {
    let key = Self::expr_key(expr);
    self.values.get(&key)
}
fn require_value(&self, expr: &Expr) -> syn::Result<SynMultivector<FrameType>> {
self.value_of(expr)
.cloned()
.ok_or_else(|| syn::Error::new(expr.span(), "missing cached value during optimisation"))
}
/// Wraps an arbitrary expression as a scalar multivector.
fn scalar_mv(&self, expr: syn::Expr) -> SynMultivector<FrameType> {
    let coefficient = self.alg.syn.wrap_expr(expr);
    self.alg.scalar(coefficient)
}
fn bind_ident(
&mut self,
name: syn::Ident,
value: SynMultivector<FrameType>,
) -> syn::Result<()> {
match self.scopes.last_mut() {
Some(scope) => match scope.insert(name.to_string(), value) {
None => Ok(()),
Some(_) => Err(syn::Error::new(
name.span(),
"name shadowing not currently supported",
)),
},
None => Err(syn::Error::new(
name.span(),
"no scope available to bind identifier during optimisation",
)),
}
}
/// Binds the identifier found in a `let` pattern to `value`.
///
/// Type ascriptions, references and parentheses are peeled off; `_` binds
/// nothing. Any other pattern, or an identifier with a sub-pattern, is
/// rejected.
fn bind_pattern(
    &mut self,
    pat: &syn::Pat,
    value: SynMultivector<FrameType>,
) -> syn::Result<()> {
    let mut current = pat;
    loop {
        // Each wrapper arm yields the inner pattern; every other arm returns.
        current = match current {
            syn::Pat::Type(typed) => &*typed.pat,
            syn::Pat::Reference(reference) => &*reference.pat,
            syn::Pat::Paren(parenthesized) => &*parenthesized.pat,
            syn::Pat::Wild(_) => return Ok(()),
            syn::Pat::Ident(binding) => {
                if binding.subpat.is_some() {
                    return Err(syn::Error::new(
                        binding.ident.span(),
                        "`let` patterns with sub-bindings are not supported",
                    ));
                }
                return self.bind_ident(binding.ident.clone(), value);
            }
            unsupported => {
                return Err(syn::Error::new(
                    unsupported.span(),
                    "unsupported `let` pattern in build_expr optimisation",
                ));
            }
        };
    }
}
/// Resolves `name`, searching the scopes from innermost to outermost and
/// falling back to `alias_multivector` when no scope binds it.
fn lookup_binding(&self, name: &str) -> Option<SynMultivector<FrameType>> {
    for scope in self.scopes.iter().rev() {
        if let Some(bound) = scope.get(name) {
            return Some(bound.clone());
        }
    }
    self.alias_multivector(name)
}
/// Binds every closure parameter to the multivector described by its shape type.
///
/// # Errors
/// Fails for parameters without a type annotation, for any other
/// unsupported parameter pattern, and for whatever `extract_binding`,
/// `build_shape_parameter_mv` or `bind_ident` report.
fn bind_closure_inputs(&mut self, closure: &ExprClosure) -> syn::Result<()> {
    for input in &closure.inputs {
        let typed = match input {
            syn::Pat::Type(typed) => typed,
            syn::Pat::Ident(untyped) => {
                return Err(syn::Error::new(
                    untyped.ident.span(),
                    "closure parameters in `expr!` must have an explicit shape type",
                ));
            }
            unsupported => {
                return Err(syn::Error::new(
                    unsupported.span(),
                    "unsupported closure parameter pattern in `expr!`",
                ));
            }
        };
        let (ident, shape) = self.extract_binding(typed)?;
        let value = self.build_shape_parameter_mv(ident, shape)?;
        // Clone the identifier first so the shared borrow of `self` held by
        // `ident`/`shape` has ended before `bind_ident` borrows mutably.
        let owned_ident = ident.clone();
        self.bind_ident(owned_ident, value)?;
    }
    Ok(())
}
/// Builds the multivector for a closure parameter `ident` of shape `shape`.
///
/// The result is the sum, over the shape's fields, of the scalar
/// `ident.<field>` multiplied by the multivector of the field's alias.
///
/// # Errors
/// Fails when a field refers to an alias that cannot be resolved.
fn build_shape_parameter_mv(
    &self,
    ident: &syn::Ident,
    shape: &ShapeDef,
) -> syn::Result<SynMultivector<FrameType>> {
    shape
        .fields
        .iter()
        .try_fold(self.alg.zero(), |accum, field| {
            let field_ident = field.ident();
            let (alias_name, alias_span) = field.alias_name();
            let Some(basis_mv) = self.alias_multivector(alias_name) else {
                let name = &shape.name;
                let msg = format!("unknown alias `{alias_name}` referenced in shape `{name}`");
                return Err(syn::Error::new(alias_span, msg));
            };
            let field_access: syn::Expr = syn::parse_quote! { #ident.#field_ident };
            let scalar = self.alg.scalar(self.alg.syn.wrap_expr(field_access));
            let term = self.alg.mul(&scalar, &basis_mv);
            Ok(self.alg.add(&accum, &term))
        })
}
/// Finds the shape definition whose name equals the last path segment of `ty`.
///
/// Returns `None` for non-path types, empty paths and unknown names.
fn find_shape_by_type(&self, ty: &syn::Type) -> Option<&ShapeDef> {
    let syn::Type::Path(type_path) = ty else {
        return None;
    };
    let type_ident = &type_path.path.segments.last()?.ident;
    self.spec
        .shape_defs
        .iter()
        .find(|shape| shape.name == *type_ident)
}
fn extract_binding<'a>(
&'a self,
pat_type: &'a syn::PatType,
) -> syn::Result<(&'a syn::Ident, &'a ShapeDef)> {
let ident = match pat_type.pat.as_ref() {
syn::Pat::Ident(pat_ident) => &pat_ident.ident,
other => {
return Err(syn::Error::new(
other.span(),
"unsupported binding pattern in closure parameter",
));
}
};
let shape = self
.find_shape_by_type(pat_type.ty.as_ref())
.ok_or_else(|| {
syn::Error::new(pat_type.ty.span(), "expected shape type for parameter")
})?;
Ok((ident, shape))
}
/// Looks `name` up in the alias scope (scope 0).
///
/// The names `"0"` and `"1"` fall back to the algebra's zero and one when
/// the alias scope does not define them.
fn alias_value(&self, name: &str) -> Option<SynMultivector<FrameType>> {
    let declared = self.scopes.first().and_then(|scope| scope.get(name));
    if let Some(found) = declared {
        return Some(found.clone());
    }
    match name {
        "0" => Some(self.alg.zero()),
        "1" => Some(self.alg.one()),
        _ => None,
    }
}
/// Resolves an alias name to a multivector.
///
/// A name known to `alias_value` is returned directly. Otherwise the name
/// is split with `expand_compound_alias` and the components are resolved
/// recursively and combined left to right with the wedge product.
///
/// Returns `None` when the name cannot be expanded, expands to nothing, or
/// any component fails to resolve.
fn alias_multivector(&self, name: &str) -> Option<SynMultivector<FrameType>> {
    if let Some(value) = self.alias_value(name) {
        // `alias_value` already hands back an owned value; no clone needed.
        return Some(value);
    }
    let mut components = expand_compound_alias(name, &self.basis_aliases)?.into_iter();
    let first = components.next()?;
    let seed = self.alias_multivector(&first)?;
    components.try_fold(seed, |acc, component| {
        let rhs = self.alias_multivector(&component)?;
        Some(self.alg.wedge(&acc, &rhs))
    })
}
pub fn shape_literal(&mut self) -> syn::Result<Expr> {
let mut mv = match self.mv.take() {
Some(mv) => mv,
None => {
return Err(syn::Error::new(
Span::call_site(),
"no multivector available for shape literal construction",
));
}
};
self.alg.drop_zeros(&mut mv);
let mut low_score = usize::MAX;
let mut best_shape = None;
for shape in &self.spec.shape_defs {
if let Some(score) = self.shape_penalty(shape, &mv)? {
if score < low_score {
low_score = score;
best_shape = Some(shape);
}
if score == 0 {
break;
}
}
}
let Some(shape) = best_shape else {
#[cfg(debug_assertions)]
eprintln!(
"shape literal search failed for {} with mv:\n{}",
self.spec.module_name, mv
);
return Err(syn::Error::new(
self.spec.module_name.span(),
"no matching shape literal found",
));
};
if low_score > 0 {
#[cfg(debug_assertions)]
println!(
"shape literal search found suboptimal penalty {} for {}\nConsider adding a new shape definition.",
low_score, shape.name
);
}
Ok(syn::Expr::Struct(self.build_shape_struct(mv, shape)?))
}
fn build_shape_struct(
&self,
mv: SynMultivector<FrameType>,
shape_def: &ShapeDef,
) -> syn::Result<syn::ExprStruct> {
let mod_ident = &self.spec.module_name;
let shape_ident = &shape_def.name;
let mut fields = Punctuated::new();
for field in shape_def.fields.iter() {
let (alias_name, _) = field.alias_name();
let basis_mv = self.alias_multivector(alias_name).unwrap();
let (basis, sign) = basis_mv.vectors.into_iter().next().unwrap();
let coeff = self.alg.coeff(&mv, &basis);
let value = self.alg.scalar_field().mul(&sign, &coeff);
fields.push(FieldValue {
attrs: vec![],
member: Member::Named(field.ident().clone()),
colon_token: Some(Token)),
expr: self.alg.scalar_field().to_expr(&value)?,
});
}
Ok(syn::ExprStruct {
attrs: vec![],
qself: None,
path: syn::parse_quote! { #mod_ident :: #shape_ident },
brace_token: shape_def.braces.clone(),
fields,
dot2_token: None,
rest: None,
})
}
/// Scores how well `shape` fits `mv`; lower is better.
///
/// A zero `mv` scores the number of fields of the shape. Otherwise the
/// blades of `mv` are compared with the blades covered by the shape's
/// fields: a blade present in both costs nothing, a shape blade missing
/// from `mv` costs 1, and an `mv` blade missing from the shape costs
/// `u16::MAX`.
///
/// # Errors
/// Fails when a field refers to an alias that cannot be resolved.
fn shape_penalty(
    &self,
    shape: &ShapeDef,
    mv: &SynMultivector<FrameType>,
) -> syn::Result<Option<usize>> {
    if self.alg.is_zero(mv) {
        return Ok(Some(shape.fields.len()));
    }

    // Sum of the field aliases; only its set of blades is used below.
    let mut coverage = self.alg.zero();
    for field in shape.fields.iter() {
        let (alias_name, alias_span) = field.alias_name();
        match self.alias_multivector(alias_name) {
            Some(blade_mv) => {
                self.alg.add_assign(&mut coverage, &blade_mv);
            }
            None => {
                return Err(syn::Error::new(
                    alias_span,
                    format!("unknown alias `{alias_name}` in shape literal"),
                ));
            }
        }
    }

    let penalty = mv
        .vectors
        .keys()
        .merge_join_by(coverage.vectors.keys(), |a, b| a.cmp(b))
        .map(|either| match either {
            EitherOrBoth::Both(_, _) => 0,
            EitherOrBoth::Right(_) => 1,
            EitherOrBoth::Left(_) => u16::MAX as usize,
        })
        .sum::<usize>();
    Ok(Some(penalty))
}
fn evaluate_expr(&mut self, expr: &Expr) -> syn::Result<()> {
match expr {
Expr::Binary(bin) => {
let lhs = self.require_value(&bin.left)?;
let rhs = self.require_value(&bin.right)?;
let value = match bin.op {
BinOp::Add(_) => self.alg.add(&lhs, &rhs),
BinOp::Sub(_) => self.alg.sub(&lhs, &rhs),
BinOp::Mul(_) => self.alg.mul(&lhs, &rhs),
BinOp::BitXor(_) => self.alg.wedge(&lhs, &rhs),
BinOp::BitOr(_) => self.alg.inner(&lhs, &rhs),
BinOp::BitAnd(_) => self.alg.antiwedge(&lhs, &rhs),
_ => {
return Err(syn::Error::new(
bin.op.span(),
"unsupported operator in Clifford expression",
));
}
};
self.set_value(expr, value);
}
Expr::Unary(unary) => {
let inner = self.require_value(&unary.expr)?;
let value = match unary.op {
syn::UnOp::Neg(_) => self.alg.neg(&inner),
syn::UnOp::Deref(_) => inner,
_ => {
return Err(syn::Error::new(
unary.op.span(),
"unsupported unary operator in Clifford expression",
));
}
};
self.set_value(expr, value);
}
Expr::Path(path) => {
if let Some(value) = path
.path
.get_ident()
.and_then(|ident| self.lookup_binding(&ident.to_string()))
{
self.set_value(expr, value);
} else {
self.set_value(expr, self.scalar_mv(syn::Expr::Path(path.clone())));
}
}
Expr::Lit(lit) => {
self.set_value(expr, self.scalar_mv(syn::Expr::Lit(lit.clone())));
}
Expr::Paren(paren) => {
let inner = self.require_value(&paren.expr)?;
self.set_value(expr, inner);
}
Expr::Group(group) => {
let inner = self.require_value(&group.expr)?;
self.set_value(expr, inner);
}
Expr::MethodCall(call) => {
let receiver = self.require_value(&call.receiver)?;
let method = call.method.to_string();
let value = match method.as_str() {
"exp" => {
arity_args!(0, &call.args);
self.alg.exp(&receiver)
}
"reverse" => {
arity_args!(0, &call.args);
self.alg.reverse(receiver.clone())
}
"sandwich" => {
let args = arity_args!(1, &call.args);
let target_value = self.require_value(&args[0])?;
self.alg.sandwich(&receiver, &target_value)
}
"dual" => {
arity_args!(0, &call.args);
self.alg.right_dual(&receiver)
}
"complement" => {
arity_args!(0, &call.args);
self.alg.complement(&receiver)
}
"sqrt" => {
arity_args!(0, &call.args);
self.alg.sqrt(&receiver)
}
"norm_squared" => {
arity_args!(0, &call.args);
let scalar_val = self.alg.norm_squared(&receiver);
self.alg.scalar(scalar_val)
}
"conjugate" => {
arity_args!(0, &call.args);
self.alg.conjugate(receiver.clone())
}
"automorphism" => {
arity_args!(0, &call.args);
self.alg.automorphism(receiver.clone())
}
"left_contract" => {
let args = arity_args!(1, &call.args);
let target_value = self.require_value(&args[0])?;
self.alg.contract_onto(&receiver, &target_value)
}
"right_contract" => {
let args = arity_args!(1, &call.args);
let target_value = self.require_value(&args[0])?;
self.alg.contract_by(&receiver, &target_value)
}
"scalar" => {
arity_args!(0, &call.args);
match self.alg.just_scalar(&receiver) {
Some(scalar_val) => self.alg.scalar(scalar_val),
None => self.alg.scalar(self.alg.syn.zero()),
}
}
_ => {
return Err(syn::Error::new(
call.method.span(),
format!("unsupported method `{method}` in Clifford expression"),
));
}
};
self.set_value(expr, value);
}
Expr::Block(block) => {
let Some(Stmt::Expr(last_expr, None)) = block.block.stmts.last() else {
return Err(syn::Error::new(
block.block.span(),
"unsupported block in Clifford expression",
));
};
let value = self.require_value(last_expr)?;
self.set_value(expr, value);
}
Expr::Closure(_) => {
}
Expr::Array(_)
| Expr::Assign(_)
| Expr::Async(_)
| Expr::Await(_)
| Expr::Call(_)
| Expr::Cast(_)
| Expr::Field(_)
| Expr::ForLoop(_)
| Expr::If(_)
| Expr::Index(_)
| Expr::Let(_) | Expr::Loop(_)
| Expr::Macro(_)
| Expr::Match(_)
| Expr::Range(_)
| Expr::Reference(_)
| Expr::Repeat(_)
| Expr::Struct(_)
| Expr::Try(_)
| Expr::TryBlock(_)
| Expr::Tuple(_)
| Expr::While(_)
| Expr::Yield(_)|Expr::Break(_) | Expr::Continue(_) | Expr::Return(_) | Expr::Verbatim(_) => {
return Err(syn::Error::new(
expr.span(),
"unsupported syntax in Clifford expression",
));
}
_ => {
return Err(syn::Error::new(
expr.span(),
"unsupported syntax in Clifford expression",
));
}
}
Ok(())
}
}
/// Returns from the enclosing visitor method when `$elf.err` is already set,
/// so that no further work is done after the first recorded error.
macro_rules! v_guard {
($elf:ident) => {
if $elf.err.is_some() {
return;
}
};
}
/// Records `$error` in `$elf.err` and returns from the enclosing visitor
/// method. The assignment evaluates to `()`, which is the value returned.
macro_rules! v_throw {
($elf:ident, $error:expr) => {
return $elf.err = Some($error)
};
}
/// `?`-like helper for visitor methods, which return `()`: evaluates to the
/// `Ok` value of `$expr`, or records the error via `v_throw!` and returns.
macro_rules! v_try {
($elf:ident, $expr:expr) => {
match $expr {
Ok(value) => value,
Err(err) => v_throw!($elf, err),
}
};
}
impl Visit<'_> for Optimizer {
/// Visits a block inside a binding scope of its own.
fn visit_block(&mut self, block: &syn::Block) {
v_guard!(self);
// Bindings made inside the block go into a fresh scope that is discarded
// once the block has been visited.
self.scopes.push(HashMap::new());
visit::visit_block(self, block);
self.scopes.pop();
}
fn visit_local(&mut self, local: &syn::Local) {
v_guard!(self);
visit::visit_local(self, local);
v_guard!(self);
let Some(init) = &local.init else {
let msg = "`let` bindings in expr! closures must have an initializer";
v_throw!(self, syn::Error::new(local.span(), msg));
};
if init.diverge.is_some() {
let msg = "`let ... else` is not supported in expr!";
v_throw!(self, syn::Error::new(local.span(), msg));
}
let value = v_try!(self, self.require_value(&init.expr));
v_try!(self, self.bind_pattern(&local.pat, value));
}
/// Evaluates an expression bottom-up and records any evaluation error.
fn visit_expr(&mut self, expr: &Expr) {
v_guard!(self);
// Children first, so their cached values exist when `evaluate_expr` runs.
visit::visit_expr(self, expr);
v_guard!(self);
v_try!(self, self.evaluate_expr(expr));
}
fn visit_expr_closure(&mut self, closure: &ExprClosure) {
v_guard!(self);
self.scopes.push(HashMap::new());
if let Err(err) = self.bind_closure_inputs(closure) {
self.scopes.pop();
v_throw!(self, err);
}
visit::visit_expr_closure(self, closure);
self.scopes.pop();
v_guard!(self);
if let Some(value) = self.value_of(&closure.body).cloned() {
self.mv = Some(value);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use std::collections::HashMap;
/// Parses `tokens` as an [`AlgebraSpec`], panicking when parsing fails.
fn fixture_spec(tokens: TokenStream) -> AlgebraSpec {
    let parsed: syn::Result<AlgebraSpec> = syn::parse2(tokens);
    parsed.expect("failed to parse algebra spec")
}
/// Builds the multivector of every alias, keyed by alias name, without
/// consuming the alias map.
fn alias_values_map(
    alg: &SynAlgebra<FrameType>,
    aliases: &BTreeMap<Alias, BasisElem>,
) -> HashMap<String, SynMultivector<FrameType>> {
    let mut values = HashMap::new();
    for (alias, elem) in aliases {
        let terms = elem.vectors.iter().map(|(mask, coeff)| {
            let blade = alg.basis_blade(*mask);
            let value = coeff.value();
            let literal: syn::Expr = syn::parse_quote! { #value };
            let scalar = alg.syn.wrap_numeric_expr(literal);
            (blade, scalar)
        });
        values.insert(alias.name.clone(), alg.from_terms(terms));
    }
    values
}
/// Reduces a multivector to a comparable form: for every term, the keys of
/// its blade and the token text of its coefficient.
fn multivector_signature(mv: &SynMultivector<FrameType>) -> Vec<(Vec<FrameType>, String)> {
    let mut signature = Vec::new();
    for (blade, coeff) in mv.vectors.iter() {
        let masks: Vec<_> = blade.vectors.keys().cloned().collect();
        let coeff_str = coeff.to_token_stream().to_string();
        signature.push((masks, coeff_str));
    }
    signature
}
/// `e12` is not declared, so it must resolve as the compound alias `e1 ^ e2`.
#[test]
fn compound_alias_builds_wedge_product() {
    let spec = fixture_spec(quote! {
        f32, 3, 0
        mod fixture {
            bases! { e1 = P0; e2 = P1; e3 = P2; }
        }
    });
    let alias_map = spec.build_alias_mapping().expect("alias map");
    let closure: ExprClosure = syn::parse_str("|| e12 ^ e3").expect("closure");
    let mut optimizer = Optimizer::new(spec, alias_map.clone()).expect("optimizer");

    optimizer.visit_expr_closure(&closure);
    assert!(
        optimizer.err.is_none(),
        "unexpected optimizer error: {:?}",
        optimizer.err
    );
    let result = optimizer.mv.clone().expect("multivector result");

    // Expected value: (e1 ^ e2) ^ e3 built directly from the basis aliases.
    let alias_values = alias_values_map(&optimizer.alg, &alias_map);
    let basis = |name: &str| alias_values.get(name).expect(name);
    let e12 = optimizer.alg.wedge(basis("e1"), basis("e2"));
    let expected = optimizer.alg.wedge(&e12, basis("e3"));
    assert_eq!(
        multivector_signature(&result),
        multivector_signature(&expected)
    );
}
/// A `let` binding named like a compound alias is the value used afterwards.
#[test]
fn scoped_binding_overrides_alias_lookup() {
    let spec = fixture_spec(quote! {
        f32, 3, 0
        mod fixture {
            bases! { e1 = P0; e2 = P1; e3 = P2; }
        }
    });
    let alias_map = spec.build_alias_mapping().expect("alias map");
    let closure: ExprClosure =
        syn::parse_str("|| { let e12 = e1 ^ e2; e12 ^ e3 }").expect("closure");
    let mut optimizer = Optimizer::new(spec, alias_map.clone()).expect("optimizer");

    optimizer.visit_expr_closure(&closure);
    assert!(
        optimizer.err.is_none(),
        "unexpected optimizer error: {:?}",
        optimizer.err
    );
    let result = optimizer.mv.clone().expect("multivector result");

    // Expected value: the locally bound `e1 ^ e2`, wedged with `e3`.
    let alias_values = alias_values_map(&optimizer.alg, &alias_map);
    let basis = |name: &str| alias_values.get(name).expect(name);
    let scoped = optimizer.alg.wedge(basis("e1"), basis("e2"));
    let expected = optimizer.alg.wedge(&scoped, basis("e3"));
    assert_eq!(
        multivector_signature(&result),
        multivector_signature(&expected)
    );
}
/// A shape-typed closure parameter must show up in the result coefficients
/// as field accesses on the parameter.
#[test]
fn lorentz_like_expression_keeps_parameter_in_scope() {
    let spec = fixture_spec(quote! {
        f32, 1, 3
        mod fixture {
            basis!(t = P0);
            basis!(e1 = N0);
            basis!(e2 = N1);
            basis!(e3 = N2);
            shape!(Event { t, e1, e2, e3 });
        }
    });
    let alias_map = spec.build_alias_mapping().expect("alias map");
    let closure: ExprClosure =
        syn::parse_str("|ev: fixture::Event| (a * e1t).exp().sandwich(ev)").expect("closure");
    let mut optimizer = Optimizer::new(spec, alias_map).expect("optimizer");

    optimizer.visit_expr_closure(&closure);
    assert!(
        optimizer.err.is_none(),
        "optimizer emitted error: {:?}",
        optimizer.err
    );

    let result = optimizer.mv.expect("multivector result");
    let mut coeff_tokens = Vec::new();
    for coeff in result.vectors.values() {
        coeff_tokens.push(coeff.to_token_stream().to_string());
    }
    let references_param = coeff_tokens.iter().any(|tok| tok.contains("ev . t"));
    assert!(
        references_param,
        "expected parameter coefficient to reference ev.t: {:?}",
        coeff_tokens
    );
}
}