use super::basis::{BasisSpace, Frame};
use super::metric::Metric;
use super::{
DotProduct, Duality, ExteriorProduct, Graded, InnerProduct, Involution, LeftContraction,
RegressiveProduct, RightContraction, ScalarProduct,
};
use crate::clifford::{
AnticommutatorProduct, CliffordAlgebra, CommutatorProduct, Divisibility, Exponential,
GeometricProducts,
};
use crate::{
clifford::{ScalarSpace, multivector::Multivector},
symbol_field::{ActiveField, CoefficientField},
};
use abstalg::{
AbelianGroup, BoundedOrder, CommuntativeMonoid, Domain, Monoid, SemiRing, Semigroup,
UnitaryRing,
};
use itertools::{EitherOrBoth, Itertools};
use proc_macro2::Span;
use quote::{ToTokens, quote};
use std::{
collections::{BTreeMap, BTreeSet},
fmt::{Display, Write},
};
use syn::{
BinOp, Expr, ExprLit, ExprMethodCall, ExprPath, ExprUnary, Lit, parse_quote, spanned::Spanned,
};
/// Basis element: a combination of bitmask blades `B` with `ScalarSpace` coefficients
/// (displayed as `<c*b+...>`).
pub type SynBasis<B> = Multivector<B, ScalarSpace>;
/// Symbolic multivector: [`SynBasis`] blades paired with `ActiveField` (symbolic) coefficients.
pub type SynMultivector<B> = Multivector<SynBasis<B>, ActiveField>;
impl<B: Frame + Display> Display for SynBasis<B> {
    /// Renders the terms as `coeff*blade`, separated by `+`, wrapped in `<` and `>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_char('<')?;
        let mut first = true;
        for (basis, coeff) in self.vectors.iter() {
            if !first {
                f.write_char('+')?;
            }
            first = false;
            write!(f, "{}*{}", coeff, basis)?;
        }
        f.write_char('>')
    }
}
impl<B: Frame + Display> Display for SynMultivector<B> {
    /// Renders the terms as `coeff * basis`, separated by ` ++ `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (position, (basis, coeff)) in self.vectors.iter().enumerate() {
            if position > 0 {
                f.write_str(" ++ ")?;
            }
            write!(f, "{} * {}", coeff, basis)?;
        }
        Ok(())
    }
}
/// Symbolic Clifford algebra used to evaluate Clifford expressions parsed with `syn`.
///
/// Blade products are delegated to `basis`; coefficients are symbolic values of the `syn` field.
#[derive(Debug, Clone)]
pub struct SynAlgebra<B: Frame> {
    /// Blade algebra over the metric, with `ScalarSpace` coefficients.
    pub basis: BasisSpace<B, ScalarSpace>,
    /// Field of symbolic coefficients.
    pub syn: ActiveField,
}
// Marker impls: no items are supplied here (empty bodies).
impl<B: Frame> GeometricProducts for SynAlgebra<B> {}
impl<B: Frame> CliffordAlgebra for SynAlgebra<B> {}
/// Inherent API of [`SynAlgebra`]: constructors, blade helpers, dualities, and the parser that
/// lowers a `syn::Expr` into a symbolic multivector.
#[allow(dead_code)]
impl<B: Frame> SynAlgebra<B> {
    /// Creates an algebra over `metric` whose coefficients live in `syn`.
    pub fn new(metric: Metric<B>, syn: ActiveField) -> Self {
        let basis = BasisSpace::new(metric, ScalarSpace);
        Self { basis, syn }
    }

    /// Coefficient of the blade `basis` in `mv`, or the field's zero when the blade is absent.
    pub fn coeff(
        &self,
        mv: &SynMultivector<B>,
        basis: &SynBasis<B>,
    ) -> <ActiveField as Domain>::Elem {
        // Build the zero element only when it is actually needed.
        mv.vectors
            .get(basis)
            .cloned()
            .unwrap_or_else(|| self.syn.zero())
    }

    /// Creates an algebra over `metric` with a fresh `ActiveField` for the scalar type `scalar_ty`.
    pub fn with_scalar_type(metric: Metric<B>, scalar_ty: syn::Type) -> Self {
        Self {
            basis: BasisSpace::new(metric, ScalarSpace),
            syn: ActiveField::new(scalar_ty),
        }
    }

    /// The metric the algebra is built on.
    pub fn frame(&self) -> &Metric<B> {
        self.basis.frame()
    }

    /// The coefficient field.
    pub fn scalar_field(&self) -> &ActiveField {
        &self.syn
    }

    /// The zero multivector (no terms).
    pub fn zero_multivector(&self) -> SynMultivector<B> {
        self.zero()
    }

    /// Builds a multivector from `(blade, coefficient)` pairs.
    ///
    /// Coefficients of repeated blades are summed (the last pair does not overwrite earlier
    /// ones). Terms with a zero coefficient or a zero blade are dropped.
    pub fn from_terms<I>(&self, terms: I) -> SynMultivector<B>
    where
        I: IntoIterator<Item = (SynBasis<B>, <ActiveField as Domain>::Elem)>,
    {
        let mut vectors = BTreeMap::new();
        for (blade, coeff) in terms {
            // `accumulate_term` sums collisions and removes entries that cancel to zero.
            self.accumulate_term(&mut vectors, blade, coeff, false);
        }
        SynMultivector { vectors }
    }

    /// Embeds the field value `value` as a grade-0 multivector (empty if `value` is zero).
    pub fn scalar(&self, value: <ActiveField as Domain>::Elem) -> SynMultivector<B> {
        let mut vectors = BTreeMap::new();
        let basis_one = self.basis.one();
        vectors.insert(basis_one, value);
        self.drop_zero_terms(&mut vectors);
        SynMultivector { vectors }
    }

    /// The basis element for the bitmask `blade`, with unit `ScalarSpace` coefficient.
    pub fn basis_blade(&self, blade: B) -> SynBasis<B> {
        SynBasis::from_terms([(blade, self.basis.scalar.one())])
    }

    /// Removes zero terms from `mv` in place.
    pub fn drop_zeros(&self, mv: &mut SynMultivector<B>) {
        self.drop_zero_terms(&mut mv.vectors);
    }

    /// Keeps only the entries whose blade and coefficient are both non-zero.
    fn drop_zero_terms(&self, vectors: &mut BTreeMap<SynBasis<B>, <ActiveField as Domain>::Elem>) {
        vectors.retain(|basis, coeff| !self.basis.is_zero(basis) && !self.syn.is_zero(coeff));
    }

    /// Adds `coeff` to (or, when `is_negative`, subtracts it from) the entry for `blade` in
    /// `storage`.
    ///
    /// Zero inputs are ignored, and an entry that cancels to zero is removed.
    fn accumulate_term(
        &self,
        storage: &mut BTreeMap<SynBasis<B>, <ActiveField as Domain>::Elem>,
        blade: SynBasis<B>,
        coeff: <ActiveField as Domain>::Elem,
        is_negative: bool,
    ) {
        if self.syn.is_zero(&coeff) || self.basis.is_zero(&blade) {
            return;
        }
        use std::collections::btree_map::Entry;
        match storage.entry(blade) {
            Entry::Occupied(mut occ) => {
                if is_negative {
                    self.syn.sub_assign(occ.get_mut(), &coeff);
                } else {
                    self.syn.add_assign(occ.get_mut(), &coeff);
                }
                if self.syn.is_zero(occ.get()) {
                    occ.remove_entry();
                }
            }
            Entry::Vacant(vacant) => {
                let value = if is_negative {
                    self.syn.neg(&coeff)
                } else {
                    coeff
                };
                if !self.syn.is_zero(&value) {
                    vacant.insert(value);
                }
            }
        }
    }

    /// Multiplies two coefficients.
    ///
    /// Zero and one are short-circuited so the generated expression avoids trivial `0 * x`
    /// and `1 * x` factors.
    fn mul_coeff(
        &self,
        lhs: &<ActiveField as Domain>::Elem,
        rhs: &<ActiveField as Domain>::Elem,
    ) -> <ActiveField as Domain>::Elem {
        if self.syn.is_zero(lhs) || self.syn.is_zero(rhs) {
            return self.syn.zero();
        }
        let one = self.syn.one();
        if self.syn.equals(lhs, &one) {
            return rhs.clone();
        }
        if self.syn.equals(rhs, &one) {
            return lhs.clone();
        }
        self.syn.mul(lhs, rhs)
    }

    /// Adds `coeff * blades` into `storage`.
    ///
    /// Each `(mask, c)` component of the basis-space result `blades` is rewritten as the unit
    /// blade for `mask`, with `coeff * c` as its field coefficient.
    fn expand_and_accumulate(
        &self,
        storage: &mut BTreeMap<SynBasis<B>, <ActiveField as Domain>::Elem>,
        blades: &SynBasis<B>,
        coeff: &<ActiveField as Domain>::Elem,
    ) {
        if self.syn.is_zero(coeff) {
            return;
        }
        for (mask, blade_coeff) in blades.vectors.iter() {
            if self.basis.scalar().is_zero(blade_coeff) {
                continue;
            }
            let canonical = self.basis_blade(*mask);
            // Carry the basis-space scalar into the symbolic field as a literal expression.
            let value = blade_coeff.value();
            let scalar_literal: Expr = parse_quote! { #value };
            let blade_factor = CoefficientField::embed_expr(&self.syn, scalar_literal).unwrap();
            if self.syn.is_zero(&blade_factor) {
                continue;
            }
            let combined = self.mul_coeff(coeff, &blade_factor);
            if self.syn.is_zero(&combined) {
                continue;
            }
            self.accumulate_term(storage, canonical, combined, false);
        }
    }

    /// Unit blade for `mask`, with coefficient one.
    #[inline]
    pub fn pseudoscalar_with_mask(&self, mask: B) -> SynMultivector<B> {
        let blade = SynBasis::from_terms([(mask, self.basis.scalar.one())]);
        self.from_terms([(blade, self.syn.one())])
    }

    /// Unit pseudoscalar of the frame (blade for `frame().max()`).
    #[inline]
    pub fn pseudoscalar(&self) -> SynMultivector<B> {
        self.pseudoscalar_with_mask(self.frame().max())
    }

    /// Bitmask complement of `mask` relative to the full frame mask.
    #[inline]
    pub fn complement_mask(&self, mask: B) -> B {
        self.frame().sym_diff(self.frame().max(), mask)
    }

    /// Reordering parity of `mask * complement(mask)`, as reported by `Metric::mul_parity`.
    #[inline]
    pub fn complement_sign(&self, mask: B) -> bool {
        self.frame().mul_parity(mask, self.complement_mask(mask))
    }

    /// Reordering parity of `complement(mask) * mask`, as reported by `Metric::mul_parity`.
    #[inline]
    pub fn uncomplement_sign(&self, mask: B) -> bool {
        self.frame().mul_parity(self.complement_mask(mask), mask)
    }

    /// Parity of the number of `imagimum` vectors contained in `mask`.
    #[inline]
    pub fn metric_sign(&self, mask: B) -> bool {
        (mask & self.frame().imagimum).parity()
    }

    /// Right dual with respect to the blade `mask`: `reverse(elem) * P_mask`.
    pub fn right_dual_with_mask(&self, elem: &SynMultivector<B>, mask: B) -> SynMultivector<B> {
        let pseudoscalar = self.pseudoscalar_with_mask(mask);
        let reversed = self.reverse(elem.clone());
        self.mul(&reversed, &pseudoscalar)
    }

    /// Left dual with respect to the blade `mask`: `P_mask * reverse(elem)`.
    pub fn left_dual_with_mask(&self, elem: &SynMultivector<B>, mask: B) -> SynMultivector<B> {
        let pseudoscalar = self.pseudoscalar_with_mask(mask);
        let reversed = self.reverse(elem.clone());
        self.mul(&pseudoscalar, &reversed)
    }

    /// Right dual with respect to the full pseudoscalar.
    #[inline]
    pub fn right_dual(&self, elem: &SynMultivector<B>) -> SynMultivector<B> {
        self.right_dual_with_mask(elem, self.frame().max())
    }

    /// Left dual with respect to the full pseudoscalar.
    #[inline]
    pub fn left_dual(&self, elem: &SynMultivector<B>) -> SynMultivector<B> {
        self.left_dual_with_mask(elem, self.frame().max())
    }

    /// Maps every blade through `BasisSpace::complement`, keeping its coefficient.
    pub fn complement(&self, elem: &SynMultivector<B>) -> SynMultivector<B> {
        let mut result = self.zero();
        for (basis, coeff) in elem.vectors.iter() {
            let complemented = self.basis.complement(basis);
            self.accumulate_term(&mut result.vectors, complemented, coeff.clone(), false);
        }
        result
    }

    /// `Some(s)` when `elem` has no non-zero term besides its scalar part `s`; `None` otherwise.
    ///
    /// `s` is zero when `elem` has no scalar term. The scalar blade is only looked for at the
    /// front of the map, which assumes it sorts first.
    pub fn just_scalar(&self, elem: &SynMultivector<B>) -> Option<<ActiveField as Domain>::Elem> {
        let mut iter = elem.vectors.iter().peekable();
        let scalar;
        if iter.peek().is_some_and(|(b, _)| self.basis.is_one(b)) {
            let (_, s) = iter.next().unwrap();
            scalar = s.clone();
        } else {
            scalar = self.syn.zero();
        }
        iter.all(|(_, c)| self.syn.is_zero(c)).then(|| scalar)
    }

    /// Grade-0 coefficient of `elem` (zero when there is no scalar term).
    pub fn get_scalar(&self, elem: &SynMultivector<B>) -> <ActiveField as Domain>::Elem {
        elem.vectors
            .get(&self.basis.one())
            .cloned()
            .unwrap_or_else(|| self.syn.zero())
    }

    /// Scalar part of `inner(elem, reverse(elem))`.
    pub fn norm_squared(&self, elem: &SynMultivector<B>) -> <ActiveField as Domain>::Elem {
        let inner_prod = self.inner(elem, &self.reverse(elem.clone()));
        self.get_scalar(&inner_prod)
    }

    /// Meet computed as `right_dual(right_dual(lhs) ∧ right_dual(rhs))`.
    pub fn meet(&self, lhs: &SynMultivector<B>, rhs: &SynMultivector<B>) -> SynMultivector<B> {
        let lhs_dual = self.right_dual(lhs);
        let rhs_dual = self.right_dual(rhs);
        let join_dual = self.wedge(&lhs_dual, &rhs_dual);
        self.right_dual(&join_dual)
    }

    /// Embeds a literal as a scalar multivector.
    ///
    /// # Errors
    /// Char, byte-string, C-string, verbatim and any future literal kinds are rejected.
    pub fn parse_literal(&self, literal: &ExprLit) -> syn::Result<SynMultivector<B>> {
        match &literal.lit {
            // NOTE(review): bool, string and byte literals are embedded verbatim as scalar
            // coefficients; confirm this is intended rather than an error.
            Lit::Int(_) | Lit::Float(_) | Lit::Bool(_) | Lit::Str(_) | Lit::Byte(_) => {
                let expr = Expr::Lit(literal.clone());
                Ok(self.scalar(CoefficientField::embed_expr(&self.syn, expr).unwrap()))
            }
            // `Lit` is non-exhaustive, so one wildcard arm covers every other literal kind.
            _ => Err(syn::Error::new(
                literal.span(),
                "unsupported literal in Clifford expression",
            )),
        }
    }

    /// Recursively lowers a Clifford expression into a symbolic multivector.
    ///
    /// Operators: `+`, `-`, `*` (geometric product), `^` (wedge), `|` (addition) and `&`
    /// (inner product). Also accepts unary `-`, a no-op unary `*`, parentheses, basis
    /// identifiers, literals and the method calls handled by `parse_method_call`.
    ///
    /// # Errors
    /// Returns a spanned `syn::Error` for unsupported syntax.
    pub fn parse_expr(&self, expr: &Expr) -> syn::Result<SynMultivector<B>> {
        match expr {
            Expr::Binary(bin) => {
                let lhs = self.parse_expr(&bin.left)?;
                let rhs = self.parse_expr(&bin.right)?;
                match bin.op {
                    BinOp::Add(_) => Ok(self.add(&lhs, &rhs)),
                    BinOp::Sub(_) => Ok(self.sub(&lhs, &rhs)),
                    BinOp::Mul(_) => Ok(self.mul(&lhs, &rhs)),
                    BinOp::BitXor(_) => Ok(self.wedge(&lhs, &rhs)),
                    // NOTE(review): `|` means addition here (same as `+`), and the inner product
                    // is `&`. Many GA DSLs use `|` for the inner product; confirm this is intended.
                    BinOp::BitOr(_) => Ok(self.add(&lhs, &rhs)),
                    BinOp::BitAnd(_) => Ok(self.inner(&lhs, &rhs)),
                    _ => Err(syn::Error::new(
                        bin.op.span(),
                        "unsupported operator in Clifford expression",
                    )),
                }
            }
            Expr::Unary(ExprUnary { op, expr, .. }) => match op {
                syn::UnOp::Neg(_) => {
                    let value = self.parse_expr(expr)?;
                    Ok(self.neg(&value))
                }
                // `*x` is accepted and treated as `x`.
                syn::UnOp::Deref(_) => self.parse_expr(expr),
                _ => Err(syn::Error::new(op.span(), "unsupported unary operator")),
            },
            Expr::Paren(paren) => self.parse_expr(&paren.expr),
            Expr::Group(group) => self.parse_expr(&group.expr),
            Expr::Path(path) => self.parse_path_expr(path),
            Expr::Lit(literal) => self.parse_literal(literal),
            Expr::MethodCall(call) => self.parse_method_call(call),
            Expr::Call(call) => Err(syn::Error::new(
                call.span(),
                "function calls are not supported in Clifford expressions",
            )),
            Expr::Tuple(tuple) => Err(syn::Error::new(
                tuple.span(),
                "tuple expressions are not supported in Clifford expressions",
            )),
            Expr::Block(block) => Err(syn::Error::new(
                block.span(),
                "block expressions are not supported in Clifford expressions",
            )),
            other => Err(syn::Error::new(
                other.span(),
                "unsupported syntax in Clifford expression",
            )),
        }
    }

    /// Fails with the standard "does not accept arguments" error unless `call` has no
    /// arguments.
    fn ensure_no_method_args(call: &ExprMethodCall, method: &str) -> syn::Result<()> {
        if call.args.is_empty() {
            Ok(())
        } else {
            Err(syn::Error::new(
                call.span(),
                format!("`{method}` does not accept arguments in Clifford expressions"),
            ))
        }
    }

    /// Returns the single argument of `call`.
    ///
    /// `missing_msg` is reported when there is no argument; a fixed message is reported when
    /// there are several.
    fn single_method_arg<'a>(
        call: &'a ExprMethodCall,
        method: &str,
        missing_msg: &str,
    ) -> syn::Result<&'a Expr> {
        let mut args = call.args.iter();
        let first = args
            .next()
            .ok_or_else(|| syn::Error::new(call.span(), missing_msg))?;
        if args.next().is_some() {
            return Err(syn::Error::new(
                call.span(),
                format!("`{method}` expects exactly one argument"),
            ));
        }
        Ok(first)
    }

    /// Lowers `receiver.method(args…)` for the supported multivector methods.
    ///
    /// # Errors
    /// Unknown methods, wrong argument counts and not-yet-supported methods (`norm`) produce a
    /// spanned `syn::Error`.
    fn parse_method_call(&self, call: &ExprMethodCall) -> syn::Result<SynMultivector<B>> {
        let receiver = self.parse_expr(&call.receiver)?;
        let method = call.method.to_string();
        match method.as_str() {
            "exp" => {
                Self::ensure_no_method_args(call, "exp")?;
                Ok(self.exp(&receiver))
            }
            "reverse" => {
                Self::ensure_no_method_args(call, "reverse")?;
                Ok(self.reverse(receiver))
            }
            "sandwich" => {
                let target_expr = Self::single_method_arg(
                    call,
                    "sandwich",
                    "`sandwich` expects exactly one argument: the multivector to transform",
                )?;
                let target = self.parse_expr(target_expr)?;
                Ok(self.sandwich(&receiver, &target))
            }
            "dual" => {
                Self::ensure_no_method_args(call, "dual")?;
                Ok(self.right_dual(&receiver))
            }
            "undual" => {
                Self::ensure_no_method_args(call, "undual")?;
                Ok(self.left_dual(&receiver))
            }
            "complement" => {
                Self::ensure_no_method_args(call, "complement")?;
                Ok(self.complement(&receiver))
            }
            "sqrt" => {
                Self::ensure_no_method_args(call, "sqrt")?;
                Ok(self.sqrt(&receiver))
            }
            "norm_squared" => {
                Self::ensure_no_method_args(call, "norm_squared")?;
                let scalar_val = self.norm_squared(&receiver);
                Ok(self.scalar(scalar_val))
            }
            "norm" => {
                Self::ensure_no_method_args(call, "norm")?;
                // Report a spanned compile error: a panic (`todo!`) here would abort the whole
                // macro expansion with an unhelpful message.
                Err(syn::Error::new(
                    call.method.span(),
                    "`norm` is not yet supported in Clifford expressions; use `norm_squared`",
                ))
            }
            "conjugate" => {
                Self::ensure_no_method_args(call, "conjugate")?;
                Ok(self.conjugate(receiver))
            }
            "automorphism" => {
                Self::ensure_no_method_args(call, "automorphism")?;
                Ok(self.automorphism(receiver))
            }
            "left_contract" => {
                let target_expr = Self::single_method_arg(
                    call,
                    "left_contract",
                    "`left_contract` expects exactly one argument",
                )?;
                let target = self.parse_expr(target_expr)?;
                Ok(self.contract_onto(&receiver, &target))
            }
            "right_contract" => {
                let target_expr = Self::single_method_arg(
                    call,
                    "right_contract",
                    "`right_contract` expects exactly one argument",
                )?;
                let target = self.parse_expr(target_expr)?;
                Ok(self.contract_by(&receiver, &target))
            }
            "scalar" => {
                Self::ensure_no_method_args(call, "scalar")?;
                // Grade-0 part of the receiver. This is non-zero even when the receiver also
                // has higher-grade terms.
                Ok(self.scalar(self.get_scalar(&receiver)))
            }
            other => Err(syn::Error::new(
                call.method.span(),
                format!("unsupported method `{other}` in Clifford expression"),
            )),
        }
    }

    /// Resolves a path expression by its last segment only (`a::b::e1` → `e1`).
    fn parse_path_expr(&self, path: &ExprPath) -> syn::Result<SynMultivector<B>> {
        let Some(segment) = path.path.segments.last() else {
            return Err(syn::Error::new(path.span(), "empty path in expression"));
        };
        self.parse_ident(&segment.ident, segment.ident.span())
    }

    /// Resolves a basis identifier.
    ///
    /// Accepted forms: `I`/`pss` (pseudoscalar), `zero`, `one`/`er`/`e` (scalar one), or
    /// `e<i>_<j>_…` (product of the listed basis vectors). In the last form, indices listed out
    /// of ascending order pick up the sign of the permutation, so `e2_1 == -e1_2`.
    ///
    /// # Errors
    /// Rejects non-numeric indices, indices outside the frame or the bit width of `B`, and
    /// repeated indices.
    fn parse_ident(
        &self,
        ident: &proc_macro2::Ident,
        span: Span,
    ) -> syn::Result<SynMultivector<B>> {
        let name = ident.to_string();
        match name.as_str() {
            "I" | "pss" => Ok(self.pseudoscalar()),
            "0" | "zero" => Ok(self.zero()),
            "1" | "one" | "er" => Ok(self.one()),
            _ if name.starts_with('e') => {
                let digits = &name[1..];
                if digits.is_empty() {
                    return Ok(self.one());
                }
                let mut mask = B::ZERO;
                // Indices seen so far, and the number of transpositions needed to sort them
                // into ascending order (assumed to be the canonical blade order, as produced by
                // e.g. `e1_2`). Distinct basis vectors anticommute, so each swap flips the sign.
                let mut seen: Vec<usize> = Vec::new();
                let mut swaps = 0usize;
                for chunk in digits.split('_') {
                    if chunk.is_empty() {
                        continue;
                    }
                    let idx: usize = chunk.parse().map_err(|_| {
                        syn::Error::new(span, format!("invalid basis index `{chunk}`"))
                    })?;
                    if idx >= self.frame().dimensions {
                        return Err(syn::Error::new(
                            span,
                            format!("basis index `{idx}` exceeds frame dimensions"),
                        ));
                    }
                    let bit = idx;
                    if bit >= B::BITS as usize {
                        return Err(syn::Error::new(
                            span,
                            format!(
                                "basis index `{idx}` cannot be represented with {} bits",
                                B::BITS
                            ),
                        ));
                    }
                    let blade_bit = B::ONE << bit;
                    if mask & blade_bit != B::ZERO {
                        return Err(syn::Error::new(
                            span,
                            format!("duplicate basis index `{idx}` in `{name}`"),
                        ));
                    }
                    swaps += seen.iter().filter(|&&prev| prev > idx).count();
                    seen.push(idx);
                    mask |= blade_bit;
                }
                let basis = self.basis_blade(mask);
                let blade = self.from_terms([(basis, self.syn.one())]);
                Ok(if swaps % 2 == 1 {
                    self.neg(&blade)
                } else {
                    blade
                })
            }
            _ => Err(syn::Error::new(
                span,
                format!("unknown basis identifier `{name}`"),
            )),
        }
    }
}
impl<B: Frame> Domain for SynAlgebra<B> {
type Elem = SynMultivector<B>;
fn equals(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> bool {
elem1
.vectors
.iter()
.merge_join_by(elem2.vectors.iter(), |(b1, _), (b2, _)| b1.cmp(b2))
.all(|either| match either {
itertools::EitherOrBoth::Both((_, v1), (_, v2)) => self.syn.equals(v1, v2),
itertools::EitherOrBoth::Left((_, v)) => self.syn.is_zero(v),
itertools::EitherOrBoth::Right((_, v)) => self.syn.is_zero(v),
})
}
fn contains(&self, elem: &Self::Elem) -> bool {
elem.vectors
.iter()
.all(|(b, v)| self.basis.contains(b) && self.syn.contains(v))
}
}
impl<B: Frame> CommuntativeMonoid for SynAlgebra<B> {
    /// The multivector with no terms.
    fn zero(&self) -> Self::Elem {
        let vectors = BTreeMap::new();
        SynMultivector { vectors }
    }
    /// Term-wise sum over the union of both blade sets; zero terms are dropped.
    fn add(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        let mut vectors = BTreeMap::new();
        let paired = elem1
            .vectors
            .iter()
            .merge_join_by(elem2.vectors.iter(), |(lhs, _), (rhs, _)| lhs.cmp(rhs));
        for pair in paired {
            let (blade, value) = match pair {
                EitherOrBoth::Both((blade, lhs), (_, rhs)) => (blade, self.syn.add(lhs, rhs)),
                EitherOrBoth::Left((blade, only)) | EitherOrBoth::Right((blade, only)) => {
                    (blade, only.clone())
                }
            };
            vectors.insert(blade.clone(), value);
        }
        self.drop_zero_terms(&mut vectors);
        SynMultivector { vectors }
    }
    /// True when every stored coefficient is zero.
    fn is_zero(&self, elem: &Self::Elem) -> bool {
        !elem.vectors.values().any(|coeff| !self.syn.is_zero(coeff))
    }
    /// In-place `elem1 += elem2`; zero terms are dropped afterwards.
    fn add_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
        use std::collections::btree_map::Entry;
        for (blade, rhs) in &elem2.vectors {
            match elem1.vectors.entry(blade.clone()) {
                Entry::Occupied(mut slot) => {
                    let sum = self.syn.add(slot.get(), rhs);
                    *slot.get_mut() = sum;
                }
                Entry::Vacant(slot) => {
                    slot.insert(rhs.clone());
                }
            }
        }
        self.drop_zero_terms(&mut elem1.vectors);
    }
    /// Doubles every coefficient in place.
    fn double(&self, elem: &mut Self::Elem) {
        for coeff in elem.vectors.values_mut() {
            self.syn.double(coeff);
        }
        self.drop_zero_terms(&mut elem.vectors);
    }
    /// `num`-fold sum, applied coefficient-wise.
    fn times(&self, num: usize, elem: &Self::Elem) -> Self::Elem {
        let mut vectors = BTreeMap::new();
        for (blade, coeff) in &elem.vectors {
            vectors.insert(
                blade.clone(),
                CommuntativeMonoid::times(&self.syn, num, coeff),
            );
        }
        self.drop_zero_terms(&mut vectors);
        SynMultivector { vectors }
    }
}
impl<B: Frame> AbelianGroup for SynAlgebra<B> {
fn neg(&self, elem: &Self::Elem) -> Self::Elem {
let mut vectors = elem
.vectors
.iter()
.map(|(b, v)| (b.clone(), self.syn.neg(v)))
.collect();
self.drop_zero_terms(&mut vectors);
SynMultivector { vectors }
}
fn neg_assign(&self, elem: &mut Self::Elem) {
elem.vectors
.iter_mut()
.for_each(|(_, v)| self.syn.neg_assign(v));
self.drop_zero_terms(&mut elem.vectors);
}
fn sub(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
let mut vectors = elem1
.vectors
.iter()
.merge_join_by(elem2.vectors.iter(), |(b1, _), (b2, _)| b1.cmp(b2))
.map(|either| match either {
EitherOrBoth::Both((b, v1), (_, v2)) => (b.clone(), self.syn.sub(v1, v2)),
EitherOrBoth::Left((b, v)) => (b.clone(), v.clone()),
EitherOrBoth::Right((b, v)) => (b.clone(), self.syn.neg(v)),
})
.collect();
self.drop_zero_terms(&mut vectors);
SynMultivector { vectors }
}
fn sub_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
elem2.vectors.iter().for_each(|(b, v2)| {
elem1
.vectors
.entry(b.clone())
.and_modify(|v1| self.syn.sub_assign(v1, v2))
.or_insert_with(|| self.syn.neg(v2));
});
self.drop_zero_terms(&mut elem1.vectors);
}
fn times(&self, num: isize, elem: &Self::Elem) -> Self::Elem {
let mut vectors = elem
.vectors
.iter()
.map(|(b, v)| (b.clone(), AbelianGroup::times(&self.syn, num, v)))
.collect();
self.drop_zero_terms(&mut vectors);
SynMultivector { vectors }
}
}
impl<B: Frame> Semigroup for SynAlgebra<B> {
    /// Geometric product: every lhs term times every rhs term (lhs-major order). Each blade
    /// product is expanded and accumulated.
    fn mul(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        let mut out = self.zero();
        let pairs = elem1.vectors.iter().cartesian_product(elem2.vectors.iter());
        for ((lhs_blade, lhs_coeff), (rhs_blade, rhs_coeff)) in pairs {
            let product = self.basis.mul(lhs_blade, rhs_blade);
            if self.basis.is_zero(&product) {
                continue;
            }
            let coeff = self.mul_coeff(lhs_coeff, rhs_coeff);
            if !self.syn.is_zero(&coeff) {
                self.expand_and_accumulate(&mut out.vectors, &product, &coeff);
            }
        }
        self.drop_zero_terms(&mut out.vectors);
        out
    }
    /// `elem1 = elem1 * elem2`.
    fn mul_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
        let product = self.mul(elem1, elem2);
        *elem1 = product;
    }
    /// `elem = elem * elem`.
    fn square(&self, elem: &mut Self::Elem) {
        let product = self.mul(elem, elem);
        *elem = product;
    }
}
impl<B: Frame> Monoid for SynAlgebra<B> {
    /// The scalar one: the basis unit with field coefficient one.
    fn one(&self) -> Self::Elem {
        let mut vectors: BTreeMap<_, _> =
            std::iter::once((self.basis.one(), self.syn.one())).collect();
        self.drop_zero_terms(&mut vectors);
        SynMultivector { vectors }
    }
    /// True when `elem` equals the scalar one.
    fn is_one(&self, elem: &Self::Elem) -> bool {
        let unit = self.one();
        self.equals(elem, &unit)
    }
    /// Not implemented yet; always panics.
    fn try_inv(&self, _elem: &Self::Elem) -> Option<Self::Elem> {
        todo!("Implement multivector inversion")
    }
    /// Delegates to `try_inv`, so it currently panics as well.
    fn invertible(&self, elem: &Self::Elem) -> bool {
        matches!(self.try_inv(elem), Some(_))
    }
}
// Marker impls with empty bodies; presumably the required operations come from the
// CommuntativeMonoid / AbelianGroup / Semigroup / Monoid impls of this type.
impl<B: Frame> SemiRing for SynAlgebra<B> {}
impl<B: Frame> UnitaryRing for SynAlgebra<B> {}
/// Division is not implemented yet: `try_div` panics via `todo!` for every input.
impl<B: Frame> Divisibility for SynAlgebra<B> {
    fn try_div(&self, _elem1: &Self::Elem, _elem2: &Self::Elem) -> Option<Self::Elem> {
        // NOTE(review): panics instead of returning `None`, so callers probing divisibility abort.
        todo!("Implement multivector division")
    }
}
impl<B: Frame> Graded for SynAlgebra<B> {
type Grade = Vec<u32>;
type Output = SynMultivector<B>;
fn grade_of(&self, elem: &Self::Elem) -> Self::Grade {
elem.vectors
.keys()
.flat_map(|basis| self.basis.grade_of(basis).into_iter())
.sorted()
.dedup()
.collect()
}
fn grade_by(&self, elem: &Self::Elem, grades: Self::Grade) -> Self::Output {
let filters: BTreeSet<u32> = grades.into_iter().collect();
let mut vectors = elem
.vectors
.iter()
.filter(|(basis, _)| {
self.basis
.grade_of(basis)
.into_iter()
.any(|grade| filters.contains(&grade))
})
.map(|(basis, coeff)| (basis.clone(), coeff.clone()))
.collect();
self.drop_zero_terms(&mut vectors);
SynMultivector { vectors }
}
}
impl<B: Frame> Duality for SynAlgebra<B> {
    type Psuedo = Option<B>;
    type Output = SynMultivector<B>;
    /// Maps every blade through `BasisSpace::dual` (relative to `pseudoscalar`) and keeps its
    /// coefficient.
    fn dual(&self, elem: &Self::Elem, pseudoscalar: &Self::Psuedo) -> Self::Output {
        elem.vectors
            .iter()
            .fold(self.zero(), |mut acc, (blade, coeff)| {
                let image = self.basis.dual(blade, pseudoscalar);
                self.accumulate_term(&mut acc.vectors, image, coeff.clone(), false);
                acc
            })
    }
    /// Maps every blade through `BasisSpace::undual` (relative to `pseudoscalar`) and keeps its
    /// coefficient.
    fn undual(&self, elem: &Self::Elem, pseudoscalar: &Self::Psuedo) -> Self::Output {
        elem.vectors
            .iter()
            .fold(self.zero(), |mut acc, (blade, coeff)| {
                let image = self.basis.undual(blade, pseudoscalar);
                self.accumulate_term(&mut acc.vectors, image, coeff.clone(), false);
                acc
            })
    }
}
impl<B: Frame> Involution for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// Applies `BasisSpace::automorphism` to each blade and keeps its coefficient.
    fn automorphism(&self, elem: Self::Elem) -> Self::Output {
        elem.vectors
            .into_iter()
            .fold(self.zero(), |mut acc, (blade, coeff)| {
                let image = self.basis.automorphism(blade);
                self.accumulate_term(&mut acc.vectors, image, coeff, false);
                acc
            })
    }
    /// Applies `BasisSpace::reverse` to each blade and keeps its coefficient.
    fn reverse(&self, elem: Self::Elem) -> Self::Output {
        elem.vectors
            .into_iter()
            .fold(self.zero(), |mut acc, (blade, coeff)| {
                let image = self.basis.reverse(blade);
                self.accumulate_term(&mut acc.vectors, image, coeff, false);
                acc
            })
    }
    /// Applies `BasisSpace::conjugate` to each blade and keeps its coefficient.
    fn conjugate(&self, elem: Self::Elem) -> Self::Output {
        elem.vectors
            .into_iter()
            .fold(self.zero(), |mut acc, (blade, coeff)| {
                let image = self.basis.conjugate(blade);
                self.accumulate_term(&mut acc.vectors, image, coeff, false);
                acc
            })
    }
}
impl<B: Frame> ExteriorProduct for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// Outer product: pairwise `BasisSpace::wedge` of blades (lhs-major order), with
    /// coefficients multiplied.
    fn wedge(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
        let mut out = self.zero();
        let pairs = lhs.vectors.iter().cartesian_product(rhs.vectors.iter());
        for ((lhs_blade, lhs_coeff), (rhs_blade, rhs_coeff)) in pairs {
            let blade = self.basis.wedge(lhs_blade, rhs_blade);
            if self.basis.is_zero(&blade) {
                continue;
            }
            let coeff = self.mul_coeff(lhs_coeff, rhs_coeff);
            if !self.syn.is_zero(&coeff) {
                self.expand_and_accumulate(&mut out.vectors, &blade, &coeff);
            }
        }
        out
    }
}
impl<B: Frame> InnerProduct for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// Inner product: pairwise `BasisSpace::inner` of blades (lhs-major order), with
    /// coefficients multiplied.
    fn inner(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
        let mut out = self.zero();
        let pairs = lhs.vectors.iter().cartesian_product(rhs.vectors.iter());
        for ((lhs_blade, lhs_coeff), (rhs_blade, rhs_coeff)) in pairs {
            let blade = self.basis.inner(lhs_blade, rhs_blade);
            if self.basis.is_zero(&blade) {
                continue;
            }
            let coeff = self.mul_coeff(lhs_coeff, rhs_coeff);
            if !self.syn.is_zero(&coeff) {
                self.expand_and_accumulate(&mut out.vectors, &blade, &coeff);
            }
        }
        out
    }
}
impl<B: Frame> RegressiveProduct for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// Regressive product: pairwise `BasisSpace::antiwedge` of blades (lhs-major order), with
    /// coefficients multiplied.
    fn antiwedge(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
        let mut out = self.zero();
        let pairs = lhs.vectors.iter().cartesian_product(rhs.vectors.iter());
        for ((lhs_blade, lhs_coeff), (rhs_blade, rhs_coeff)) in pairs {
            let blade = self.basis.antiwedge(lhs_blade, rhs_blade);
            if self.basis.is_zero(&blade) {
                continue;
            }
            let coeff = self.mul_coeff(lhs_coeff, rhs_coeff);
            if !self.syn.is_zero(&coeff) {
                self.expand_and_accumulate(&mut out.vectors, &blade, &coeff);
            }
        }
        out
    }
}
impl<B: Frame> CommutatorProduct for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// `lhs*rhs - rhs*lhs` (not halved).
    fn commutate(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
        let forward = self.mul(lhs, rhs);
        let backward = self.mul(rhs, lhs);
        self.sub(&forward, &backward)
    }
}
impl<B: Frame> AnticommutatorProduct for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// `lhs*rhs + rhs*lhs` (not halved).
    fn anticommutate(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
        let forward = self.mul(lhs, rhs);
        let backward = self.mul(rhs, lhs);
        self.add(&forward, &backward)
    }
}
impl<B: Frame> LeftContraction for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// Left contraction: pairwise `BasisSpace::contract_onto` of blades (lhs-major order), with
    /// coefficients multiplied.
    fn contract_onto(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
        let mut out = self.zero();
        let pairs = lhs.vectors.iter().cartesian_product(rhs.vectors.iter());
        for ((lhs_blade, lhs_coeff), (rhs_blade, rhs_coeff)) in pairs {
            let blade = self.basis.contract_onto(lhs_blade, rhs_blade);
            if self.basis.is_zero(&blade) {
                continue;
            }
            let coeff = self.mul_coeff(lhs_coeff, rhs_coeff);
            if !self.syn.is_zero(&coeff) {
                self.expand_and_accumulate(&mut out.vectors, &blade, &coeff);
            }
        }
        out
    }
}
impl<B: Frame> RightContraction for SynAlgebra<B> {
    type Output = SynMultivector<B>;
    /// Right contraction: pairwise `BasisSpace::contract_by` of blades (lhs-major order), with
    /// coefficients multiplied.
    fn contract_by(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
        let mut out = self.zero();
        let pairs = lhs.vectors.iter().cartesian_product(rhs.vectors.iter());
        for ((lhs_blade, lhs_coeff), (rhs_blade, rhs_coeff)) in pairs {
            let blade = self.basis.contract_by(lhs_blade, rhs_blade);
            if self.basis.is_zero(&blade) {
                continue;
            }
            let coeff = self.mul_coeff(lhs_coeff, rhs_coeff);
            if !self.syn.is_zero(&coeff) {
                self.expand_and_accumulate(&mut out.vectors, &blade, &coeff);
            }
        }
        out
    }
}
impl<B: Frame> ScalarProduct for SynAlgebra<B> {
type Output = SynMultivector<B>;
fn scalar_product(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
self.inner(lhs, rhs)
}
}
impl<B: Frame> DotProduct for SynAlgebra<B> {
type Output = SynMultivector<B>;
fn dot(&self, lhs: &Self::Elem, rhs: &Self::Elem) -> Self::Output {
self.inner(lhs, rhs)
}
}
impl<B: Frame> Exponential for SynAlgebra<B> {
type Output = Self::Elem;
fn exp(&self, elem: &Self::Elem) -> Self::Output {
self.exp(elem)
}
fn ln(&self, _elem: &Self::Elem) -> Self::Output {
todo!("Implement logarithm of multivector")
}
}
#[allow(dead_code)]
impl<B: Frame> SynAlgebra<B> {
/// `elem` raised to the power `exp`.
///
/// Starts from one and right-multiplies by `elem` `exp` times, so `exp == 0` yields one.
pub fn powi(&self, elem: &SynMultivector<B>, exp: u32) -> SynMultivector<B> {
    (0..exp).fold(self.one(), |acc, _| self.mul(&acc, elem))
}
/// Truncated Taylor series `sum_{k=0}^{terms-1} elem^k / k!` (zero when `terms == 0`).
///
/// Each `1/k!` is embedded as a decimal literal with 12 fractional digits, so very small
/// coefficients keep only a few significant digits.
pub fn exp_with_terms(&self, elem: &SynMultivector<B>, terms: usize) -> SynMultivector<B> {
    let mut sum = self.zero();
    // `elem^k` and `k!` are updated incrementally: the series needs `terms` products instead of
    // recomputing every power from scratch (O(terms^2)). The power is built by the same chain of
    // right-multiplications `powi` uses, so the symbolic result is identical.
    let mut power = self.one();
    let mut k_fact: usize = 1;
    for k in 0..terms {
        if k > 0 {
            power = self.mul(&power, elem);
            k_fact = k_fact.saturating_mul(k);
        }
        let coeff_value = if k == 0 { 1.0 } else { 1.0 / (k_fact as f64) };
        let coeff_expr: Expr = syn::parse_str(&format!("{coeff_value:.12}"))
            .unwrap_or_else(|_| syn::parse_str("0.0").unwrap());
        let coeff_mv = self.scalar(CoefficientField::embed_expr(&self.syn, coeff_expr).unwrap());
        let term = self.mul(&coeff_mv, &power);
        sum = self.add(&sum, &term);
    }
    sum
}
/// Exponential of a multivector.
///
/// Strategy, in order:
/// 1. purely scalar `elem`: emit `(s).exp()`;
/// 2. `elem²` is a scalar `A²`: use the closed forms `1 + A` (for `A² = 0`),
///    `A·sin(α)/α + cos(α)` with `α = sqrt(-A²)` (for `A² < 0`), or `A·sinh(α)/α + cosh(α)` with
///    `α = sqrt(A²)` (for `A² > 0`). The sign of `A²` comes from evaluating numeric literals, and
///    otherwise from a syntactic heuristic;
/// 3. otherwise: a 12-term Taylor series via `exp_with_terms`.
pub fn exp(&self, elem: &SynMultivector<B>) -> SynMultivector<B> {
    /// Best-effort sign of a scalar coefficient expression, decided at expansion time.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum ScalarSign {
        Negative,
        Zero,
        Positive,
        Unknown,
    }
    fn sign_from_numeric(value: f64) -> ScalarSign {
        if value > 0.0 {
            ScalarSign::Positive
        } else if value < 0.0 {
            ScalarSign::Negative
        } else {
            ScalarSign::Zero
        }
    }
    /// Value of an int/float literal, possibly negated, grouped or parenthesised.
    fn numeric_value_from_expr(expr: &Expr) -> Option<f64> {
        match expr {
            Expr::Lit(ExprLit { lit, .. }) => match lit {
                Lit::Float(f) => f.base10_parse::<f64>().ok(),
                Lit::Int(i) => i.base10_parse::<i128>().ok().map(|v| v as f64),
                _ => None,
            },
            Expr::Unary(ExprUnary { op, expr, .. }) => match op {
                syn::UnOp::Neg(_) => numeric_value_from_expr(expr).map(|v| -v),
                _ => None,
            },
            Expr::Group(g) => numeric_value_from_expr(&g.expr),
            Expr::Paren(p) => numeric_value_from_expr(&p.expr),
            _ => None,
        }
    }
    /// Syntactic sign heuristic for non-literal expressions.
    fn scalar_sign_from_expr(expr: &Expr) -> ScalarSign {
        match expr {
            Expr::Group(g) => scalar_sign_from_expr(&g.expr),
            Expr::Paren(p) => scalar_sign_from_expr(&p.expr),
            Expr::Unary(ExprUnary { op, expr, .. }) => match op {
                syn::UnOp::Neg(_) => match scalar_sign_from_expr(expr) {
                    ScalarSign::Positive => ScalarSign::Negative,
                    ScalarSign::Negative => ScalarSign::Positive,
                    ScalarSign::Zero => ScalarSign::Zero,
                    // NOTE(review): a negated expression of unknown sign is assumed negative
                    // (right for `-(b*b)`, wrong for `-(x)` with `x < 0`, where the generated
                    // `sqrt` of a negative value yields NaN at runtime).
                    ScalarSign::Unknown => ScalarSign::Negative,
                },
                _ => ScalarSign::Unknown,
            },
            Expr::MethodCall(call) => {
                // `x.powi(even)` cannot be negative.
                if call.method == "powi" {
                    if let Some(first_arg) = call.args.first() {
                        if let Some(exp_val) = numeric_value_from_expr(first_arg) {
                            if exp_val.round() as i64 % 2 == 0 {
                                return ScalarSign::Positive;
                            }
                        }
                    }
                }
                ScalarSign::Unknown
            }
            Expr::Lit(_) => numeric_value_from_expr(expr)
                .map(sign_from_numeric)
                .unwrap_or(ScalarSign::Unknown),
            _ => ScalarSign::Unknown,
        }
    }

    if let Some(scalar) = self.just_scalar(elem) {
        let src = scalar.to_token_stream().to_string();
        let expr_src = format!("({}).exp()", src);
        let expr: Expr =
            syn::parse_str(&expr_src).unwrap_or_else(|_| syn::parse_str("0.0").unwrap());
        return self.scalar(CoefficientField::embed_expr(&self.syn, expr).unwrap());
    }
    let a2 = self.mul(elem, elem);
    let Some(a2_scalar) = self.just_scalar(&a2) else {
        return self.exp_with_terms(elem, 12);
    };
    if self.syn.is_zero(&a2_scalar) {
        return self.add(&self.one(), elem);
    }
    // Preferred path: decide the sign of A² from the field's own expression for it. A numeric
    // value always yields a definite sign, so `Unknown` only arises for non-literal expressions.
    if let Ok(a2_expr) = self.syn.to_expr(&a2_scalar) {
        let sign = numeric_value_from_expr(&a2_expr)
            .map(sign_from_numeric)
            .unwrap_or_else(|| scalar_sign_from_expr(&a2_expr));
        match sign {
            ScalarSign::Zero => return self.add(&self.one(), elem),
            ScalarSign::Negative => return self.exp_closed_form_symbolic(elem, &a2_expr, true),
            ScalarSign::Positive => return self.exp_closed_form_symbolic(elem, &a2_expr, false),
            ScalarSign::Unknown => {}
        }
    }
    // Fallback: the printed tokens of A² may still be a plain (possibly negated or
    // parenthesised) numeric literal.
    let literal_value = syn::parse_str::<Expr>(&a2_scalar.to_token_stream().to_string())
        .ok()
        .and_then(|parsed| numeric_value_from_expr(&parsed));
    if let Some(v) = literal_value {
        return self.exp_closed_form_numeric(elem, v);
    }
    self.exp_with_terms(elem, 12)
}

/// Closed-form `exp(elem)` when `elem² == a2_expr` is a scalar whose sign is known.
///
/// With `negative`: `α = sqrt(-A²)` and the result is `A·sin(α)/α + cos(α)`. Otherwise
/// `α = sqrt(A²)` and the result is `A·sinh(α)/α + cosh(α)`.
///
/// NOTE(review): `α` is only known at runtime; if it evaluates to zero, the generated
/// `sin(α)/α` (or `sinh(α)/α`) is `0/0`, i.e. NaN for floating-point scalar types.
fn exp_closed_form_symbolic(
    &self,
    elem: &SynMultivector<B>,
    a2_expr: &Expr,
    negative: bool,
) -> SynMultivector<B> {
    let parse_or_zero = |tokens: proc_macro2::TokenStream| -> Expr {
        syn::parse2(tokens).unwrap_or_else(|_| syn::parse_str("0.0").unwrap())
    };
    let (s_tokens, c_tokens) = if negative {
        let alpha: Expr = parse_or_zero(quote! { (-(#a2_expr)).sqrt() });
        (
            quote! { (#alpha).sin() / (#alpha) },
            quote! { (#alpha).cos() },
        )
    } else {
        let alpha: Expr = parse_or_zero(quote! { (#a2_expr).sqrt() });
        (
            quote! { (#alpha).sinh() / (#alpha) },
            quote! { (#alpha).cosh() },
        )
    };
    let s_mv =
        self.scalar(CoefficientField::embed_expr(&self.syn, parse_or_zero(s_tokens)).unwrap());
    let c_mv =
        self.scalar(CoefficientField::embed_expr(&self.syn, parse_or_zero(c_tokens)).unwrap());
    let a_s = self.mul(&s_mv, elem);
    self.add(&a_s, &c_mv)
}

/// Closed-form `exp(elem)` when `elem²` is the concrete number `a2`.
///
/// The result is `1 + A` for `a2 == 0`. Otherwise it uses the trigonometric (`a2 < 0`) or
/// hyperbolic form, with `α` baked in as a literal with 12 fractional digits.
fn exp_closed_form_numeric(&self, elem: &SynMultivector<B>, a2: f64) -> SynMultivector<B> {
    if a2 == 0.0 {
        return self.add(&self.one(), elem);
    }
    let (alpha, sin_fn, cos_fn) = if a2 < 0.0 {
        ((-a2).sqrt(), "sin", "cos")
    } else {
        (a2.sqrt(), "sinh", "cosh")
    };
    let alpha_src = format!("{alpha:.12}");
    let s_src = format!("({alpha_src}).{sin_fn}() / ({alpha_src})");
    let c_src = format!("({alpha_src}).{cos_fn}()");
    let parse_or_zero = |src: &str| -> Expr {
        syn::parse_str(src).unwrap_or_else(|_| syn::parse_str("0.0").unwrap())
    };
    let s_mv =
        self.scalar(CoefficientField::embed_expr(&self.syn, parse_or_zero(&s_src)).unwrap());
    let c_mv =
        self.scalar(CoefficientField::embed_expr(&self.syn, parse_or_zero(&c_src)).unwrap());
    let a_s = self.mul(&s_mv, elem);
    self.add(&a_s, &c_mv)
}
/// Conjugates `target` by `rotor`, evaluating `rotor * target * reverse(rotor)` with `self.mul`.
///
/// The work happens strictly left to right:
/// 1. `rotor * target` is formed.
/// 2. The reverse of `rotor` is taken.
/// 3. The intermediate is multiplied by that reverse.
pub fn sandwich(
    &self,
    rotor: &SynMultivector<B>,
    target: &SynMultivector<B>,
) -> SynMultivector<B> {
    // Rust evaluates call arguments left to right, so this keeps the same
    // mul -> reverse -> mul sequence as a statement-by-statement version.
    self.mul(&self.mul(rotor, target), &self.reverse(rotor.clone()))
}
pub fn sqrt(&self, elem: &SynMultivector<B>) -> SynMultivector<B> {
if let Some(scalar) = self.just_scalar(elem) {
let src = scalar.to_token_stream().to_string();
let expr_src = format!("({}).sqrt()", src);
let expr: Expr =
syn::parse_str(&expr_src).unwrap_or_else(|_| syn::parse_str("0.0").unwrap());
return self.scalar(CoefficientField::embed_expr(&self.syn, expr).unwrap());
}
self.zero()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SynAlgebra` over the `Metric::new(3, 0)` algebra with `f32` coefficients.
    use super::*;
    use quote::ToTokens;
    use syn::parse_str;

    type Blade = u8;

    /// Builds the algebra shared by every test: metric `Metric::new(3, 0)`, field type `f32`.
    fn algebra() -> SynAlgebra<Blade> {
        let metric = Metric::new(3, 0);
        let field = ActiveField::new(syn::parse_quote!(f32));
        SynAlgebra::new(metric, field)
    }

    /// Returns the blade selected by `mask` carrying the field's unit coefficient.
    fn basis_vector(algebra: &SynAlgebra<Blade>, mask: Blade) -> SynMultivector<Blade> {
        let blade = algebra.basis_blade(mask);
        let unit = algebra.syn.one();
        algebra.from_terms([(blade, unit)])
    }

    /// Renders a token-producing value to its `to_string` form for textual assertions.
    fn render<T: ToTokens + ?Sized>(value: &T) -> String {
        value.to_token_stream().to_string()
    }

    #[test]
    fn addition_accumulates_coefficients() {
        let alg = algebra();
        let e1 = basis_vector(&alg, 0b001);
        let doubled = alg.add(&e1, &e1);
        assert_eq!(doubled.vectors.len(), 1);
        let only = doubled.vectors.values().next().unwrap();
        assert!(render(only).contains("2 as f32"));
    }

    #[test]
    fn dual_maps_scalar_to_pseudoscalar() {
        let alg = algebra();
        let unit = alg.one();
        let pseudo = alg.pseudoscalar();
        let forward = alg.right_dual(&unit);
        assert!(alg.equals(&forward, &pseudo));
        let backward = alg.right_dual(&pseudo);
        assert!(alg.equals(&backward, &unit));
    }

    #[test]
    fn meet_of_planes_returns_common_line() {
        let alg = algebra();
        let e1 = basis_vector(&alg, 0b001);
        let e2 = basis_vector(&alg, 0b010);
        let e3 = basis_vector(&alg, 0b100);
        let e12 = alg.wedge(&e1, &e2);
        let e13 = alg.wedge(&e1, &e3);
        let common = alg.meet(&e12, &e13);
        assert!(alg.equals(&common, &e1));
    }

    #[test]
    fn inner_product_detects_parallel_and_orthogonal_vectors() {
        let alg = algebra();
        let e1 = basis_vector(&alg, 0b001);
        let e2 = basis_vector(&alg, 0b010);
        let same = alg.inner(&e1, &e1);
        assert_eq!(same.vectors.len(), 1);
        let text = render(same.vectors.values().next().unwrap());
        assert!(matches!(text.as_str(), "1.0" | "1 as f32"));
        let perpendicular = alg.inner(&e1, &e2);
        assert!(perpendicular.vectors.is_empty());
    }

    #[test]
    fn scalar_part_extracts_scalar_component() {
        let alg = algebra();
        let unit = alg.one();
        let e1 = basis_vector(&alg, 0b001);
        let mixed = alg.add(&unit, &e1);
        assert!(alg.just_scalar(&mixed).is_none());
        let text = render(&alg.get_scalar(&mixed));
        assert!(matches!(text.as_str(), "1.0" | "1 as f32"));
    }

    #[test]
    fn parse_expr_supports_method_chains() {
        let alg = algebra();
        let chained: Expr = parse_str("((0.5 * (e1 ^ e2)).exp()).sandwich(e1)").unwrap();
        let parsed = alg.parse_expr(&chained).unwrap();
        let rotor_src: Expr = parse_str("(0.5 * (e1 ^ e2)).exp()").unwrap();
        let rotor = alg.parse_expr(&rotor_src).unwrap();
        let target = alg.parse_expr(&parse_str("e1").unwrap()).unwrap();
        let expected = alg.sandwich(&rotor, &target);
        assert!(alg.equals(&parsed, &expected));
    }

    #[test]
    fn parse_expr_rejects_unknown_identifier() {
        let alg = algebra();
        let source: Expr = parse_str("foo").unwrap();
        let message = alg.parse_expr(&source).unwrap_err().to_string();
        assert!(message.contains("unknown basis identifier"));
    }

    #[test]
    fn parse_expr_rejects_duplicate_basis_indices() {
        let alg = algebra();
        let source: Expr = parse_str("e1_1").unwrap();
        let message = alg.parse_expr(&source).unwrap_err().to_string();
        assert!(message.contains("duplicate basis index"));
    }

    #[test]
    fn scalar_exp_produces_symbolic_exp() {
        let alg = algebra();
        let unit_scalar = alg.scalar(alg.syn.one());
        let exponentiated = alg.exp(&unit_scalar);
        assert!(render(&alg.get_scalar(&exponentiated)).contains("exp"));
    }

    #[test]
    fn sandwich_matches_manual_reverse_product() {
        let alg = algebra();
        let e1 = basis_vector(&alg, 0b001);
        let e2 = basis_vector(&alg, 0b010);
        let e12 = alg.wedge(&e1, &e2);
        let half =
            alg.scalar(CoefficientField::embed_expr(&alg.syn, parse_str("0.5").unwrap()).unwrap());
        let half_e12 = alg.mul(&half, &e12);
        let rotor = alg.add(&alg.one(), &half_e12);
        let target = basis_vector(&alg, 0b001);
        let via_helper = alg.sandwich(&rotor, &target);
        // Same left-to-right order as the helper: rotor*target, reverse(rotor), outer mul.
        let manual = alg.mul(&alg.mul(&rotor, &target), &alg.reverse(rotor.clone()));
        assert!(alg.equals(&via_helper, &manual));
    }
}