use crate::FrameType;
use crate::basis_grammar::{Alias, BasisAlias};
use crate::builder::BasisElem;
use crate::clifford::{Metric, SynAlgebra};
use abstalg::CommuntativeMonoid;
use itertools::Either;
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, format_ident, quote};
use std::collections::BTreeMap;
use syn::{
Item, Token, braced,
parse::{Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
};
/// Parsed algebra specification: `Scalar, P, N mod name { ... }`.
///
/// While parsing, `basis!`/`bases!` directives are removed from the module body and
/// collected into `basis_indices`; `shape!`/`shapes!` directives are removed, collected
/// into `shape_defs`, and replaced in the module by generated structs.
#[derive(Debug)]
pub struct AlgebraSpec {
    /// Scalar type used for the algebra and for the fields of generated shape structs.
    pub scalar_ty: syn::Type,
    pub comma1: Token![,],
    /// Count passed as the positive argument of `Metric::new`.
    pub positive_count: syn::LitInt,
    pub comma2: Token![,],
    /// Count passed as the negative argument of `Metric::new`.
    pub negative_count: syn::LitInt,
    /// Identifier of the module following the header (copied from `module.ident`).
    pub module_name: syn::Ident,
    /// Basis alias definitions gathered from `basis!` and `bases!`.
    pub basis_indices: Punctuated<BasisDef, Token![;]>,
    /// Shape definitions gathered from `shape!` and `shapes!`.
    pub shape_defs: Vec<ShapeDef>,
    /// The module after directive processing.
    pub module: syn::ItemMod,
}
/// One term of a basis definition, written `[+|-] [float] [*] [ident]`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BasisIndex {
    /// Sign of the term; when no sign is written, the parser records `+`.
    pub sign: Either<Token![+], Token![-]>,
    /// Optional floating-point coefficient such as `0.5`.
    pub scalar: Option<syn::LitFloat>,
    /// Optional `*` between the coefficient and the identifier.
    pub mul_token: Option<Token![*]>,
    /// Optional generator identifier such as `P0` or `N0`.
    pub ident: Option<syn::Ident>,
}
/// A single basis alias definition: `name = term term ...`, as written inside
/// `basis!` or `bases!`.
#[derive(Debug)]
pub struct BasisDef {
    /// Alias being defined (left-hand side).
    pub name: syn::Ident,
    pub eq_token: Token![=],
    /// Terms of the right-hand side, in source order.
    pub indices: Vec<BasisIndex>,
}
/// One entry in a shape definition: `[const] name [: alias]`.
#[derive(Debug, Clone)]
pub struct ShapeDefField {
    /// Optional leading `const` marker.
    pub const_token: Option<Token![const]>,
    /// Field name; its identifier names the generated struct field.
    pub name: ShapeFieldName,
    /// Optional `: alias`; when present it overrides the name reported by `alias_name`.
    pub alias: Option<(Token![:], ShapeFieldName)>,
}
impl ShapeDefField {
    /// Identifier used for the generated struct field.
    pub fn ident(&self) -> &syn::Ident {
        &self.name.ident
    }

    /// Source span of the field name.
    pub fn span(&self) -> Span {
        self.ident().span()
    }

    /// Alias text and its span: the explicit `: alias` when one was written,
    /// otherwise the field name itself.
    pub fn alias_name(&self) -> (&str, Span) {
        let source = match &self.alias {
            Some((_, explicit)) => explicit,
            None => &self.name,
        };
        (source.alias(), source.span())
    }
}
/// Name of a shape field: either an identifier or the integer literal `1`.
#[derive(Debug, Clone)]
pub struct ShapeFieldName {
    /// Identifier for the field; for the literal `1` this is the synthetic `_1`.
    pub ident: syn::Ident,
    /// Name text as written: the identifier's string, or the literal's base-10 digits.
    pub alias: String,
    /// True when the name was parsed from an integer literal.
    pub is_literal: bool,
}
impl ShapeFieldName {
    /// Identifier used as the struct field name (`_1` for the literal `1`).
    pub fn ident(&self) -> &syn::Ident {
        &self.ident
    }

    /// Name text exactly as written in the shape definition.
    pub fn alias(&self) -> &str {
        self.alias.as_str()
    }

    /// Span of the underlying identifier.
    pub fn span(&self) -> Span {
        self.ident().span()
    }
}
impl ToTokens for ShapeFieldName {
    /// Emits the name as the user wrote it: the identifier, or the integer
    /// literal for literal names.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        if !self.is_literal {
            self.ident.to_tokens(tokens);
            return;
        }
        syn::LitInt::new(&self.alias, self.span()).to_tokens(tokens);
    }
}
/// A shape definition `Name { field, ... }` from `shape!` or `shapes!`.
#[derive(Debug, Clone)]
pub struct ShapeDef {
    /// Name of the shape, reused as the generated struct's name.
    pub name: syn::Ident,
    /// Braces around the field list (kept so their span can be reused).
    pub braces: syn::token::Brace,
    /// Comma-separated field entries.
    pub fields: Punctuated<ShapeDefField, Token![,]>,
}
/// Body of a `shapes!` invocation: shape definitions written back to back.
pub struct ShapesMacroBody(pub Vec<ShapeDef>);
impl Parse for BasisIndex {
fn parse(input: ParseStream) -> syn::Result<Self> {
let sign = if input.peek(Token![+]) {
Either::Left(input.parse()?)
} else if input.peek(Token![-]) {
Either::Right(input.parse()?)
} else {
Either::Left(Token))
};
let scalar = if input.peek(syn::LitFloat) {
Some(input.parse()?)
} else {
None
};
let mul_token = if input.peek(Token![*]) {
Some(input.parse()?)
} else {
None
};
let ident = if input.peek(syn::Ident) {
Some(input.parse()?)
} else {
None
};
Ok(BasisIndex {
sign,
scalar,
mul_token,
ident,
})
}
}
impl Parse for BasisDef {
    /// Parses `name = term term ...`, stopping at a `;` or at end of input.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let name = input.parse()?;
        let eq_token = input.parse()?;
        let mut terms = Vec::new();
        while !(input.is_empty() || input.peek(Token![;])) {
            terms.push(input.parse()?);
        }
        Ok(Self {
            name,
            eq_token,
            indices: terms,
        })
    }
}
impl Parse for ShapeFieldName {
    /// Parses a shape field name: an identifier, or the integer literal `1`, which
    /// gets the synthetic identifier `_1` and is flagged as a literal.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        if input.peek(syn::LitInt) {
            let lit: syn::LitInt = input.parse()?;
            let digits = lit.base10_digits();
            if digits != "1" {
                return Err(syn::Error::new(
                    lit.span(),
                    "only the numeric literal `1` is supported in shape definitions",
                ));
            }
            return Ok(Self {
                ident: format_ident!("_{}", digits, span = lit.span()),
                alias: digits.to_owned(),
                is_literal: true,
            });
        }
        if !input.peek(syn::Ident) {
            return Err(syn::Error::new(
                input.span(),
                "expected identifier or numeric literal in shape definition",
            ));
        }
        let ident: syn::Ident = input.parse()?;
        let alias = ident.to_string();
        Ok(Self {
            ident,
            alias,
            is_literal: false,
        })
    }
}
impl Parse for ShapeDefField {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(ShapeDefField {
const_token: if input.peek(Token![const]) {
Some(input.parse()?)
} else {
None
},
name: input.parse()?,
alias: if input.peek(Token![:]) {
let colon = input.parse()?;
let alias = input.parse()?;
Some((colon, alias))
} else {
None
},
})
}
}
impl Parse for ShapeDef {
fn parse(input: ParseStream) -> syn::Result<Self> {
let content;
Ok(ShapeDef {
name: input.parse()?,
braces: braced!(content in input),
fields: content.parse_terminated(ShapeDefField::parse, Token![,])?,
})
}
}
impl Parse for ShapesMacroBody {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut shapes = Vec::new();
while !input.is_empty() {
shapes.push(input.parse()?);
}
Ok(ShapesMacroBody(shapes))
}
}
impl Parse for AlgebraSpec {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut spec = AlgebraSpec {
scalar_ty: input.parse()?,
comma1: input.parse()?,
positive_count: input.parse()?,
comma2: input.parse()?,
negative_count: input.parse()?,
module_name: format_ident!("_"),
basis_indices: Punctuated::new(),
shape_defs: Vec::new(),
module: input.parse()?,
};
spec.module_name = spec.module.ident.clone();
let Some((brace, items)) = spec.module.content.take() else {
return Err(syn::Error::new_spanned(
&spec.module,
"Module has no content",
));
};
let mut new_items: Vec<syn::Item> = vec![];
for item in items.into_iter() {
match item {
syn::Item::Macro(mac) if mac.mac.path.is_ident("basis") => {
spec.basis_indices.push(mac.mac.parse_body()?);
}
syn::Item::Macro(mac) if mac.mac.path.is_ident("bases") => {
let defs: Punctuated<BasisDef, Token![;]> =
mac.mac.parse_body_with(Punctuated::parse_terminated)?;
spec.basis_indices.extend(defs.into_iter());
}
syn::Item::Macro(mac) if mac.mac.path.is_ident("shape") => {
spec.shape_defs.push(mac.mac.parse_body()?);
let def = spec.shape_defs.last().unwrap();
new_items.push(spec.shape_struct(&mac, def)?);
}
syn::Item::Macro(mac) if mac.mac.path.is_ident("shapes") => {
let parsed: ShapesMacroBody = mac.mac.parse_body()?;
for def in parsed.0 {
new_items.push(spec.shape_struct(&mac, &def)?);
spec.shape_defs.push(def);
}
}
other => new_items.push(other),
}
}
let mac = syn::ItemMacro::from(&spec);
new_items.push(syn::Item::Macro(mac));
new_items.push(syn::Item::Use(syn::parse_quote! {
pub(crate) use expr;
}));
spec.module.content = Some((brace, new_items));
Ok(spec)
}
}
impl ToTokens for BasisDef {
    /// Re-emits `name = term term ...`.
    ///
    /// Every term's sign is written except a `+` on the first term, which stays
    /// implicit. Coefficient, `*` and identifier are emitted when present.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.name.to_tokens(tokens);
        self.eq_token.to_tokens(tokens);
        for (position, term) in self.indices.iter().enumerate() {
            match (&term.sign, position) {
                (Either::Left(_), 0) => {}
                (Either::Left(plus), _) => plus.to_tokens(tokens),
                (Either::Right(minus), _) => minus.to_tokens(tokens),
            }
            // `Option<T>` emits nothing for `None`.
            term.scalar.to_tokens(tokens);
            term.mul_token.to_tokens(tokens);
            term.ident.to_tokens(tokens);
        }
    }
}
impl ToTokens for ShapeDefField {
fn to_tokens(&self, tokens: &mut TokenStream) {
if let Some(const_token) = &self.const_token {
const_token.to_tokens(tokens);
}
self.name.to_tokens(tokens);
if let Some((colon, alias)) = &self.alias {
colon.to_tokens(tokens);
alias.to_tokens(tokens);
}
}
}
impl ToTokens for ShapeDef {
    /// Re-emits `Name { field, ... }`, reusing the original brace span.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self {
            name,
            braces,
            fields,
        } = self;
        name.to_tokens(tokens);
        braces.surround(tokens, |inner| fields.to_tokens(inner));
    }
}
impl ToTokens for AlgebraSpec {
fn to_tokens(&self, tokens: &mut TokenStream) {
self.scalar_ty.to_tokens(tokens);
self.comma1.to_tokens(tokens);
self.positive_count.to_tokens(tokens);
self.comma2.to_tokens(tokens);
self.negative_count.to_tokens(tokens);
let mod_name = &self.module_name;
let bases = &self.basis_indices;
let shapes = &self.shape_defs;
quote!(mod #mod_name { bases! { #bases } shapes! { #(#shapes)* } }).to_tokens(tokens);
}
}
impl AlgebraSpec {
fn shape_struct(&self, mac: &syn::ItemMacro, def: &ShapeDef) -> Result<Item, syn::Error> {
Ok(syn::Item::Struct(syn::ItemStruct {
attrs: mac.attrs.clone(),
vis: syn::Visibility::Public(Token)),
struct_token: Token),
ident: def.name.clone(),
generics: syn::Generics::default(),
fields: self.shape_fields(def)?,
semi_token: mac.semi_token,
}))
}
fn shape_fields(&self, def: &ShapeDef) -> Result<syn::Fields, syn::Error> {
let brace_token = def.braces.clone();
let mut named = syn::punctuated::Punctuated::new();
for pair in def.fields.pairs() {
match pair.into_tuple() {
(field_def, comma) => {
if field_def.const_token.is_none() {
let field_ident = field_def.name.ident().clone();
named.push_value(syn::Field {
attrs: vec![],
vis: syn::Visibility::Public(Token)),
mutability: syn::FieldMutability::None,
ident: Some(field_ident),
colon_token: Some(Token)),
ty: self.scalar_ty.clone(),
});
} else {
return Err(syn::Error::new_spanned(
field_def.const_token.unwrap(),
"Const fields are not yet supported in shape definitions",
));
}
if let Some(comma) = comma {
named.push_punct(comma.clone());
}
}
}
}
let fields = syn::Fields::Named(syn::FieldsNamed { brace_token, named });
Ok(fields)
}
pub fn build_alias_mapping(&self) -> syn::Result<BTreeMap<Alias, BasisElem>> {
use crate::clifford::Scalar;
let pos = self.positive_count.base10_parse()?;
let neg = self.negative_count.base10_parse()?;
let basis = self.algebra()?.basis;
let mut mapping = BTreeMap::new();
for basis_def in self.basis_indices.iter() {
let alias = BasisAlias::try_from(basis_def)?;
let mut accumulator = basis.zero();
for slot in alias.terms.iter() {
let coeff_value = slot.coeff();
let mask = slot.to_mask(pos, neg)?;
let coeff = Scalar::new(coeff_value);
let term = BasisElem::from_terms([(mask, coeff)]);
basis.add_assign(&mut accumulator, &term);
}
mapping.insert(alias.alias.clone(), accumulator);
}
Ok(mapping)
}
pub fn algebra(&self) -> syn::Result<SynAlgebra<FrameType>> {
Ok(SynAlgebra::with_scalar_type(
Metric::new(
self.positive_count.base10_parse()?,
self.negative_count.base10_parse()?,
),
self.scalar_ty.clone(),
))
}
}
impl From<&AlgebraSpec> for syn::ItemMacro {
fn from(spec: &AlgebraSpec) -> Self {
syn::parse_quote! {
macro_rules! expr {
($($body:tt)*) => {
reefer::build_expr!(#spec $($body)*)
}
}
}
}
}
impl From<AlgebraSpec> for syn::File {
fn from(spec: AlgebraSpec) -> Self {
syn::File {
shebang: None,
attrs: vec![],
items: vec![syn::Item::Mod(spec.module)],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis_grammar::Alias;
    use abstalg::Semigroup;
    use num_traits::ConstOne;

    /// Absolute tolerance for coefficient comparisons. The expected values
    /// (`0.5`, `±1.0`) are exact in binary floating point.
    const TOLERANCE: f64 = 1e-10;

    /// Spec shared by the idempotent tests.
    ///
    /// It uses three positive generators and defines `e = 0.5 + 0.5 * P0`, plus a
    /// `Point` shape over `e`.
    fn idempotent_spec() -> AlgebraSpec {
        let spec_tokens = quote! {
            f32, 3, 0
            mod fixture {
                basis! { e = 0.5 + 0.5 * P0 }
                shape! { Point { e } }
            }
        };
        syn::parse2::<AlgebraSpec>(spec_tokens).expect("algebra spec")
    }

    /// Looks up the alias `name` in `mapping`, panicking if it is absent.
    fn alias_entry<'a>(mapping: &'a BTreeMap<Alias, BasisElem>, name: &str) -> &'a BasisElem {
        mapping
            .get(&Alias::new(name, Span::call_site()))
            .unwrap_or_else(|| panic!("missing alias `{name}`"))
    }

    /// Coefficient of `elem` on the blade `mask`; panics naming `what` if absent.
    fn coeff(elem: &BasisElem, mask: FrameType, what: &str) -> f64 {
        elem.vectors
            .get(&mask)
            .unwrap_or_else(|| panic!("missing {what}"))
            .value()
    }

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f64, expected: f64, what: &str) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "{what}: expected {expected}, got {actual}"
        );
    }

    #[test]
    fn alias_mapping_preserves_signs() {
        let spec = syn::parse2::<AlgebraSpec>(quote! {
            f32, 3, 1
            mod fixture {
                basis! { e0 = P0 - N0 }
            }
        })
        .expect("algebra spec");
        let mapping = spec.build_alias_mapping().expect("alias mapping");
        let alias = alias_entry(&mapping, "e0");
        assert_eq!(alias.vectors.len(), 2, "expected two basis components");
        // With three positive generators, this test expects N0 on bit 3.
        let p0_mask = FrameType::ONE << 0;
        let n0_mask = FrameType::ONE << 3;
        assert_close(coeff(alias, p0_mask, "P0 contribution"), 1.0, "P0 coefficient");
        assert_close(coeff(alias, n0_mask, "N0 contribution"), -1.0, "N0 coefficient");
    }

    #[test]
    fn idempotent_basis_element_squares_to_itself() {
        let spec = idempotent_spec();
        let mapping = spec.build_alias_mapping().expect("alias mapping");
        let e = alias_entry(&mapping, "e");
        let algebra = spec.algebra().expect("algebra");
        let e_squared = algebra.basis.mul(e, e);
        assert_eq!(
            e_squared.vectors.len(),
            e.vectors.len(),
            "e² should have same number of terms as e"
        );
        let scalar_mask: FrameType = 0;
        let p0_mask = FrameType::ONE << 0;
        assert_close(coeff(e, scalar_mask, "e scalar term"), 0.5, "e scalar");
        assert_close(coeff(&e_squared, scalar_mask, "e² scalar term"), 0.5, "e² scalar");
        assert_close(coeff(e, p0_mask, "e P0 term"), 0.5, "e P0 coeff");
        assert_close(coeff(&e_squared, p0_mask, "e² P0 term"), 0.5, "e² P0 coeff");
    }

    #[test]
    fn point_shape_with_idempotent_basis_squares_to_itself() {
        let spec = idempotent_spec();

        // The `shape!` directive is recorded with its single field...
        assert_eq!(spec.shape_defs.len(), 1, "expected exactly one shape definition");
        let shape = &spec.shape_defs[0];
        assert_eq!(shape.name, "Point");
        let field_names: Vec<String> = shape
            .fields
            .iter()
            .map(|field| field.ident().to_string())
            .collect();
        assert_eq!(field_names, ["e"]);

        // ...and expanded into a `Point` struct inside the module.
        let (_, items) = spec.module.content.as_ref().expect("module content");
        let point_struct = items
            .iter()
            .find_map(|item| match item {
                syn::Item::Struct(s) if s.ident == "Point" => Some(s),
                _ => None,
            })
            .expect("generated `Point` struct");
        assert_eq!(point_struct.fields.len(), 1, "Point should have a single field");

        // A point built on the idempotent `e` squares to itself.
        let mapping = spec.build_alias_mapping().expect("alias mapping");
        let point = alias_entry(&mapping, "e");
        let algebra = spec.algebra().expect("algebra");
        let point_squared = algebra.basis.mul(point, point);
        assert_eq!(
            point_squared.vectors.len(),
            point.vectors.len(),
            "point² should have same number of terms as point"
        );
        for (mask, label) in [(0, "scalar"), (FrameType::ONE << 0, "P0")] {
            let expected = coeff(point, mask, &format!("point {label} term"));
            let actual = coeff(&point_squared, mask, &format!("point² {label} term"));
            assert_close(actual, expected, &format!("point² {label} coeff should equal point"));
        }
    }
}