use crate::FrameType;
use crate::basis_grammar::Alias;
use crate::clifford::SynBasis;
use crate::optimizer::Optimizer;
use crate::spec::AlgebraSpec;
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use quote::quote;
use std::collections::BTreeMap;
use std::mem::take;
use syn::parse::{Parse, ParseStream};
/// Expands the algebra attribute macro.
///
/// The attribute arguments and the annotated module are concatenated into a single
/// token stream, parsed as an [`AlgebraSpec`], and the spec's `module` is emitted.
///
/// # Errors
/// Returns the `syn::Error` produced when the combined tokens do not parse as an
/// `AlgebraSpec`.
pub fn do_algebra_macro(
    attr_ts: TokenStream,
    mod_ts: TokenStream,
) -> Result<TokenStream, syn::Error> {
    // The spec is parsed from the attribute arguments immediately followed by the module.
    let combined = quote!(#attr_ts #mod_ts);
    let module = syn::parse2::<AlgebraSpec>(combined)?.module;
    Ok(module.into_token_stream())
}
/// Expands the build-expression macro.
///
/// Parses the input as a [`BuildExpr`] (an algebra spec followed by a closure), runs
/// [`BuildExpr::optimize`] over the closure, and emits the rewritten closure.
///
/// The closure body is wrapped in a block that imports `core::ops::{Add, Mul, Div, Sub, Neg}`
/// so that trait-method calls produced by the optimizer (e.g. `.add(...)`) resolve without
/// the caller having those traits in scope.
///
/// # Errors
/// Returns a `syn::Error` if the input does not parse as a `BuildExpr` or if optimization
/// fails.
pub fn do_build_expr_macro(ts: TokenStream) -> Result<TokenStream, syn::Error> {
    let build_expr: BuildExpr = syn::parse2(ts)?;
    // `optimize` hands back an owned `BuildExpr`; move the closure out instead of cloning it.
    let mut fun = build_expr.optimize()?.fun;

    // Interpolate the existing body by reference; no deep clone of the expression tree needed.
    let original_body = &fun.body;
    let wrapped_body: syn::Expr = syn::parse_quote!({
        use core::ops::{Add, Mul, Div, Sub, Neg};
        #original_body
    });
    fun.body = Box::new(wrapped_body);

    let call_site = Span::call_site();
    // NOTE(review): only top-level token trees are respanned here. For a `Group`,
    // `set_span` changes just the delimiter spans, not the tokens inside it. Confirm
    // whether a recursive respan was intended.
    let ts: TokenStream = fun
        .into_token_stream()
        .into_iter()
        .map(|mut token| {
            token.set_span(token.span().resolved_at(call_site));
            token
        })
        .collect();
    Ok(ts)
}
/// Basis element type stored as the values of the alias mapping: `SynBasis`
/// specialised to this crate's `FrameType`.
pub type BasisElem = SynBasis<FrameType>;
/// Parsed input of the build-expression macro: an algebra spec followed by a closure.
pub struct BuildExpr {
/// The algebra specification parsed from the front of the macro input.
pub spec: AlgebraSpec,
/// The closure whose body is rewritten by [`BuildExpr::optimize`].
pub fun: syn::ExprClosure,
/// Alias -> basis element table derived from `spec` via `build_alias_mapping`.
/// It is moved into the optimizer by `optimize`, so it is empty afterwards.
pub alias_mapping: BTreeMap<Alias, BasisElem>,
}
impl Parse for BuildExpr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut build = BuildExpr {
spec: input.parse()?,
fun: input.parse()?,
alias_mapping: BTreeMap::new(),
};
build.alias_mapping = build.spec.build_alias_mapping()?;
Ok(build)
}
}
impl BuildExpr {
    /// Runs both optimizer passes over the closure and returns the rewritten expression.
    ///
    /// The alias mapping is moved into the [`Optimizer`], so the returned value's
    /// `alias_mapping` is empty. The spec is threaded through the optimizer, and the
    /// copy it holds after the second pass is stored back into `self.spec`.
    ///
    /// # Errors
    /// Propagates any error from constructing the optimizer or from either pass.
    pub fn optimize(mut self) -> syn::Result<Self> {
        let aliases = take(&mut self.alias_mapping);
        let optimizer = Optimizer::new(self.spec, aliases)?
            .run_pass1(&mut self.fun)?
            .run_pass2(&mut self.fun)?;
        self.spec = optimizer.spec;
        Ok(self)
    }
}
// Unit tests for the macro entry points and for `BuildExpr::optimize`.
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::{Expr, ExprClosure};
/// A well-formed attribute (`f32, 2, 1`) plus a module containing `basis!`/`shape!`
/// declarations must expand without error.
#[test]
fn test_do_algebra_macro_valid() {
let attr_ts: TokenStream = quote! { f32, 2, 1 };
let mod_ts: TokenStream = quote! {
mod testmod {
basis!(e1 = P0 + N0);
basis!(e2 = P1);
shape!(Bar { e1, e2 });
}
};
let result = do_algebra_macro(attr_ts, mod_ts);
result.unwrap();
}
/// Attribute arguments that are not of the `f32, 2, 1` form used above must be rejected.
#[test]
fn test_do_algebra_macro_invalid_attr() {
let attr_ts: TokenStream = quote! { invalid, attr };
let mod_ts: TokenStream = quote! {
mod testmod {
basis! { e1 = P0 + P1 }
}
};
let result = do_algebra_macro(attr_ts, mod_ts);
assert!(result.is_err());
}
/// Tokens that are not a module item must be rejected even with a valid attribute.
#[test]
fn test_do_algebra_macro_invalid_mod() {
let attr_ts: TokenStream = quote! { f32, 2, 1 };
let mod_ts: TokenStream = quote! { invalid mod syntax };
let result = do_algebra_macro(attr_ts, mod_ts);
assert!(result.is_err());
}
/// Builds a `BuildExpr` from a spec token stream and closure source text.
/// It assembles the struct directly rather than going through `Parse`.
/// Panics if the spec, the closure, or the alias mapping fails to build.
fn build_fixture(spec_tokens: TokenStream, closure_src: &str) -> BuildExpr {
let spec: AlgebraSpec = syn::parse2(spec_tokens).expect("spec");
let fun: ExprClosure = syn::parse_str(closure_src).expect("closure");
let alias_mapping = spec.build_alias_mapping().expect("alias map");
BuildExpr {
spec,
fun,
alias_mapping,
}
}
/// An argument-less closure whose body combines basis aliases with numeric
/// coefficients must be rewritten into a `fixture::Vector` struct literal with
/// three fields.
#[test]
fn constant_closure_rewrites_to_shape_literal() {
let spec_tokens = quote! {
f32, 3, 0
mod fixture {
basis!(e1 = P0);
basis!(e2 = P1);
basis!(e3 = P2);
shape!(Vector { e1, e2, e3 });
}
};
let build_expr = build_fixture(spec_tokens, "|| e1 + 2.0 * e2 - 3.0 * e3");
let optimized = build_expr.optimize().expect("optimize");
assert!(optimized.fun.inputs.is_empty(), "closure inputs mutated");
match optimized.fun.body.as_ref() {
Expr::Struct(struct_expr) => {
let path = struct_expr.path.to_token_stream().to_string();
assert!(
path.ends_with("fixture :: Vector"),
"unexpected path: {path}"
);
assert_eq!(struct_expr.fields.len(), 3, "expected three vector fields");
}
other => panic!("expected struct literal, got {:?}", other),
}
}
/// An identifier that is not a declared basis alias is treated as a scalar: the
/// body becomes a one-field `fixture::Scalar` struct literal.
#[test]
fn unknown_alias_passes_through_as_scalar() {
let spec_tokens = quote! {
f32, 3, 0
mod fixture {
basis!(e1 = P0);
basis!(e2 = P1);
basis!(e3 = P2);
shape!(Vector { e1, e2, e3 });
shape!(Scalar { 1 });
}
};
let build_expr = build_fixture(spec_tokens, "|| unknown_alias");
let optimized = build_expr.optimize().expect("optimize");
assert!(optimized.fun.inputs.is_empty(), "closure inputs mutated");
match optimized.fun.body.as_ref() {
Expr::Struct(struct_expr) => {
let path = struct_expr.path.to_token_stream().to_string();
assert!(
path.ends_with("fixture :: Scalar"),
"unexpected path: {path}"
);
assert_eq!(struct_expr.fields.len(), 1, "expected one scalar field");
}
other => panic!("expected struct literal, got {:?}", other),
}
}
/// A closure argument typed as a declared shape is expanded into its fields.
/// Returning `ev` must yield a `fixture::Event` literal with named fields `t` and
/// `e1`, and the closure's argument list must be left intact.
#[test]
fn closure_args_properly_bound() {
let spec_tokens = quote! {
f32, 1, 1
mod fixture {
basis!(t = P0);
basis!(e1 = N0);
shape!(Event { t, e1 });
}
};
let build_expr = build_fixture(spec_tokens, "|ev: fixture::Event| ev");
let optimized = build_expr.optimize().expect("optimize");
assert_eq!(optimized.fun.inputs.len(), 1, "closure inputs mutated");
match optimized.fun.body.as_ref() {
Expr::Struct(struct_expr) => {
let path = struct_expr.path.to_token_stream().to_string();
assert!(
path.ends_with("fixture :: Event"),
"unexpected path: {path}"
);
assert_eq!(struct_expr.fields.len(), 2, "expected two event fields");
let field_names: Vec<_> = struct_expr
.fields
.iter()
.map(|f| match &f.member {
syn::Member::Named(ident) => ident.to_string(),
syn::Member::Unnamed(_) => panic!("expected named field"),
})
.collect();
assert!(field_names.contains(&"t".to_string()), "missing field: t");
assert!(field_names.contains(&"e1".to_string()), "missing field: e1");
}
other => panic!("expected struct literal, got {:?}", other),
}
}
/// Mixes an unknown scalar, a basis alias and a shape-typed argument. The test
/// pins the exact token output, including the `.add(...)` method-call form that
/// the optimizer emits for the `t` component.
#[test]
fn closure_with_unknown_alias_and_arg() {
let spec_tokens = quote! {
f32, 1, 1
mod fixture {
basis!(t = P0);
basis!(e1 = N0);
shape!(Event { t, e1 });
shape!(Scalar { 1 });
}
};
let build_expr = build_fixture(spec_tokens, "|ev: fixture::Event| unknown_alias * t + ev");
let optimized = build_expr.optimize().expect("optimize");
assert_eq!(optimized.fun.inputs.len(), 1, "closure inputs mutated");
let expected = "fixture :: Event { t : (unknown_alias) . add (ev . t) , e1 : ev . e1 }";
let actual = optimized.fun.body.to_token_stream().to_string();
assert_eq!(actual, expected, "optimized closure body mismatch");
}
}