ngeom_macros/
lib.rs

1extern crate proc_macro;
2use core::cmp::Ordering;
3use core::ops::{AddAssign, Mul};
4use proc_macro2::Span;
5use proc_macro2::TokenStream;
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::parse::{Error, Parse, ParseStream, Result};
8use syn::punctuated::Punctuated;
9use syn::{
10    parse2, AttrStyle, Attribute, Fields, FieldsNamed, Ident, Item, ItemMacro, LitInt, Meta, Path,
11    Token,
12};
13
// Describes one multivector struct: the identifier of the struct itself and
// the identifiers of its components.
// NOTE(review): presumably `components` are basis-element names supplied by
// the macro input -- confirm against the parsing code.
struct MultivectorStruct {
    ident: Ident,
    components: Vec<Ident>,
}
18
// Sorts `l` into ascending order in place with bubble sort and returns how
// many adjacent swaps were performed while doing so.
fn bubble_sort_count_swaps(l: &mut [usize]) -> usize {
    let mut swap_count = 0usize;
    // Each pass moves the largest value of the unsorted prefix to its end,
    // so the prefix that still needs scanning shrinks by one per pass.
    let mut unsorted_end = l.len();
    while unsorted_end > 0 {
        unsorted_end -= 1;
        for left in 0..unsorted_end {
            let right = left + 1;
            if l[left] > l[right] {
                l.swap(left, right);
                swap_count += 1;
            }
        }
    }
    swap_count
}
31
// Maps a swap count to a sign: +1 for an even count, -1 for an odd count.
fn sign_from_parity(swaps: usize) -> isize {
    if swaps % 2 == 0 {
        1
    } else {
        -1
    }
}
39
// A symbolic sum: the wrapped vector holds the product terms that are added
// together. The default value is the empty sum, which this type's `ToTokens`
// impl emits as `T::default()`.
#[derive(Default, Clone)]
struct SymbolicSumExpr(Vec<SymbolicProdExpr>);
42
// A symbolic product: an integer coefficient (field 0) multiplied by a list of
// symbols (field 1).
#[derive(PartialEq, Eq, Clone)]
struct SymbolicProdExpr(isize, Vec<Symbol>);
45
// A single symbol that can appear in a symbolic product: either a bare
// variable or a field access on a variable.
#[derive(PartialOrd, Ord, PartialEq, Eq, Clone)]
enum Symbol {
    Scalar(Ident),
    StructField(Ident, Ident), // Two Idents: a var and a field (i.e. var.field)
}
51
52impl ToTokens for Symbol {
53    fn to_tokens(&self, tokens: &mut TokenStream) {
54        match self {
55            Symbol::StructField(var, field) => {
56                var.to_tokens(tokens);
57                <Token![.]>::default().to_tokens(tokens);
58                field.to_tokens(tokens);
59            }
60            Symbol::Scalar(var) => {
61                var.to_tokens(tokens);
62            }
63        }
64    }
65}
66
67impl PartialOrd for SymbolicProdExpr {
68    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
69        Some(self.cmp(other))
70    }
71}
72
73impl Ord for SymbolicProdExpr {
74    fn cmp(&self, SymbolicProdExpr(other_coef, other_symbols): &Self) -> Ordering {
75        let SymbolicProdExpr(self_coef, self_symbols) = self;
76        self_symbols
77            .cmp(&other_symbols)
78            .then_with(|| self_coef.cmp(&other_coef))
79    }
80}
81
82impl Mul<SymbolicProdExpr> for SymbolicProdExpr {
83    type Output = SymbolicProdExpr;
84    fn mul(mut self, SymbolicProdExpr(r_coef, mut r_symbols): Self) -> SymbolicProdExpr {
85        let SymbolicProdExpr(l_coef, l_symbols) = &mut self;
86        *l_coef *= r_coef;
87        l_symbols.append(&mut r_symbols);
88        self
89    }
90}
91
92impl Mul<isize> for SymbolicProdExpr {
93    type Output = SymbolicProdExpr;
94    fn mul(mut self, r: isize) -> SymbolicProdExpr {
95        let SymbolicProdExpr(l_coef, _) = &mut self;
96        *l_coef *= r;
97        self
98    }
99}
100
101impl SymbolicProdExpr {
102    fn simplify(mut self) -> Self {
103        let SymbolicProdExpr(coef, symbols) = &mut self;
104        // Sort expression
105        if *coef == 0 {
106            symbols.clear();
107        } else {
108            symbols.sort();
109        }
110        self
111    }
112}
113
114impl ToTokens for SymbolicSumExpr {
115    fn to_tokens(&self, tokens: &mut TokenStream) {
116        let SymbolicSumExpr(terms) = self;
117        if terms.len() == 0 {
118            tokens.append_all(quote! { T::default() });
119        } else {
120            for (count, prod_expr) in terms.iter().enumerate() {
121                let SymbolicProdExpr(coef, prod_terms) = prod_expr;
122                let coef = *coef;
123
124                if coef >= 0 {
125                    if count != 0 {
126                        tokens.append_all(quote! { + });
127                    }
128                } else {
129                    tokens.append_all(quote! { - });
130                }
131                let coef = coef.abs();
132
133                if prod_terms.len() == 0 {
134                    // If there are no symbols in the product, then this is a scalar
135                    if coef == 0 {
136                        tokens.append_all(quote! { T::default() });
137                    } else if coef == 1 {
138                        tokens.append_all(quote! {
139                            T::one()
140                        });
141                    } else {
142                        panic!("Scalar was not 0, -1 or 1");
143                    }
144                } else {
145                    // There are symbols in the product
146                    if coef == 0 {
147                        tokens.append_all(quote! { T::default() * });
148                    } else if coef == 1 {
149                        // No token needed if coefficient is unity
150                    } else if coef == 2 {
151                        tokens.append_all(quote! { (T::one() + T::one()) * });
152                    } else {
153                        panic!("No representation for large coefficient {}", coef);
154                    }
155                    for (sym_count, sym) in prod_terms.iter().enumerate() {
156                        if sym_count > 0 {
157                            tokens.append_all(quote! { * });
158                        }
159                        sym.to_tokens(tokens);
160                    }
161                }
162            }
163        }
164    }
165}
166
167impl SymbolicSumExpr {
168    fn simplify(self) -> Self {
169        let SymbolicSumExpr(terms) = self;
170
171        // Simplify all products
172        let mut terms: Vec<_> = terms.into_iter().map(|prod| prod.simplify()).collect();
173
174        // Sort expression by symbolic values
175        terms.sort();
176
177        // Combine adjacent terms whose symbolic parts are equal
178        let mut new_expression = vec![];
179        let mut prev_coef = 0;
180        let mut prev_symbols = vec![];
181        for SymbolicProdExpr(coef, symbols) in terms.into_iter() {
182            if prev_symbols == symbols {
183                prev_coef += coef;
184            } else {
185                new_expression.push(SymbolicProdExpr(prev_coef, prev_symbols));
186                prev_coef = coef;
187                prev_symbols = symbols;
188            }
189        }
190        new_expression.push(SymbolicProdExpr(prev_coef, prev_symbols));
191
192        let mut terms = new_expression;
193
194        // Remove all products with coefficient = 0
195        terms.retain(|SymbolicProdExpr(coef, _)| *coef != 0);
196
197        SymbolicSumExpr(terms)
198    }
199}
200
201impl AddAssign<SymbolicProdExpr> for SymbolicSumExpr {
202    fn add_assign(&mut self, r_term: SymbolicProdExpr) {
203        let SymbolicSumExpr(l_terms) = self;
204        l_terms.push(r_term);
205    }
206}
207
208impl AddAssign<SymbolicSumExpr> for SymbolicSumExpr {
209    fn add_assign(&mut self, SymbolicSumExpr(mut r_terms): SymbolicSumExpr) {
210        let SymbolicSumExpr(l_terms) = self;
211        l_terms.append(&mut r_terms);
212    }
213}
214
215impl Mul<SymbolicProdExpr> for SymbolicSumExpr {
216    type Output = SymbolicSumExpr;
217    fn mul(self, r: SymbolicProdExpr) -> SymbolicSumExpr {
218        let SymbolicSumExpr(l) = self;
219        SymbolicSumExpr(l.into_iter().map(|lp| lp * r.clone()).collect())
220    }
221}
222
223impl Mul<isize> for SymbolicSumExpr {
224    type Output = SymbolicSumExpr;
225    fn mul(self, r: isize) -> SymbolicSumExpr {
226        let SymbolicSumExpr(l) = self;
227        SymbolicSumExpr(l.into_iter().map(|lp| lp * r).collect())
228    }
229}
230
// Right complement of basis element `i`, returned as (coef, complement_ix).
// The complement index mirrors `i` within the sign table, and the incoming
// coefficient is multiplied by the sign stored for element `i`.
fn right_complement(right_complement_signs: &Vec<isize>, coef: isize, i: usize) -> (isize, usize) {
    let mirrored_ix = right_complement_signs.len() - 1 - i;
    let sign = right_complement_signs[i];
    (coef * sign, mirrored_ix)
}
237
// Left complement of basis element `i` (the inverse of the right complement),
// returned as (coef, complement_ix). The complement index mirrors `i` within
// the sign table, and the incoming coefficient is multiplied by the sign
// stored for that mirrored element.
fn left_complement(right_complement_signs: &Vec<isize>, coef: isize, i: usize) -> (isize, usize) {
    let mirrored_ix = right_complement_signs.len() - 1 - i;
    let sign = right_complement_signs[mirrored_ix];
    (coef * sign, mirrored_ix)
}
245
// We will represent a multivector as an array of coefficients, one per basis element.
// e.g. for 2D geometry (which uses a 3-dimensional algebra), there are
// 1 + 3 + 3 + 1 = 8 basis elements,
// and a full multivector uses all of them: [1, 1, 1, 1, 1, 1, 1, 1]
// An object such as a Bivector would only use a few of them: [0, 0, 0, 0, 1, 1, 1, 0]
250
// A kind of value that generated operators can take or return: either the
// bare scalar (emitted as the generic type `T`) or a struct type.
#[derive(PartialEq)]
enum Object {
    Scalar,
    Struct(StructObject),
}
256
// Description of a struct-backed object.
#[derive(PartialEq)]
struct StructObject {
    // Name of the struct type (emitted as `name<T>`)
    name: Ident,
    // Indexed by basis element: `None` when the struct has no component for
    // that element, otherwise the field identifier and its coefficient.
    // NOTE(review): the coefficient appears to be a +1/-1 sign relating the
    // field to the basis element -- confirm where these are constructed.
    select_components: Vec<Option<(Ident, isize)>>,
    // Flag consulted by the binary-operator generator to avoid returning
    // compound objects unintentionally
    is_compound: bool,
}
263
264impl Object {
265    fn type_name(&self) -> TokenStream {
266        match self {
267            Object::Scalar => quote! { T },
268            Object::Struct(StructObject { name, .. }) => {
269                quote! { #name < T > }
270            }
271        }
272    }
273    fn type_name_colons(&self) -> TokenStream {
274        match self {
275            Object::Scalar => quote! { T },
276            Object::Struct(StructObject { name, .. }) => {
277                quote! { #name :: < T > }
278            }
279        }
280    }
281    fn has_component(&self, i: usize) -> bool {
282        match self {
283            Object::Scalar => i == 0,
284            Object::Struct(StructObject {
285                select_components, ..
286            }) => select_components[i].is_some(),
287        }
288    }
289    fn is_compound(&self) -> bool {
290        match self {
291            Object::Scalar => false,
292            Object::Struct(StructObject { is_compound, .. }) => *is_compound,
293        }
294    }
295    fn select_components(&self, var: Ident, len: usize) -> Vec<Option<(Symbol, isize)>> {
296        match self {
297            Object::Scalar => {
298                let mut result = vec![None; len];
299                result[0] = Some((Symbol::Scalar(var), 1));
300                result
301            }
302            Object::Struct(StructObject {
303                select_components, ..
304            }) => select_components
305                .iter()
306                .map(|select_component| {
307                    select_component.as_ref().map(|(field, coef)| {
308                        (Symbol::StructField(var.clone(), field.clone()), *coef)
309                    })
310                })
311                .collect(),
312        }
313    }
314}
315
316// Function to generate a simple unary operator on an object,
317// where entries can be rearranged and coefficients can be modified
318// (e.g. for implementing neg, right_complement, reverse)
319// The result will be a list of symbolic expressions,
320// one for each coefficient value in the resultant multivector.
321fn generate_symbolic_rearrangement<F: Fn(isize, usize) -> (isize, usize)>(
322    select_components: &[Option<(Symbol, isize)>],
323    op: F,
324) -> Vec<SymbolicSumExpr> {
325    // Generate the sum
326    let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components.len()];
327
328    for (i, is_selected) in select_components.iter().enumerate() {
329        if let Some((symbol, coef)) = is_selected {
330            let (coef, result_basis_ix) = op(*coef, i);
331            result[result_basis_ix] += SymbolicProdExpr(coef, vec![symbol.clone()]);
332        }
333    }
334
335    result.into_iter().map(|expr| expr.simplify()).collect()
336}
337
338fn generate_symbolic_norm<F: Fn(isize, usize, isize, usize) -> (isize, usize)>(
339    select_components: &[Option<(Symbol, isize)>],
340    product: F,
341    sqrt: bool,
342) -> Vec<SymbolicSumExpr> {
343    // Generate the product
344    let mut expressions: Vec<SymbolicSumExpr> = vec![Default::default(); select_components.len()];
345    for (i, (i_symbol, i_coef)) in select_components
346        .iter()
347        .enumerate()
348        .filter_map(|(i, selected)| selected.as_ref().map(|selected| (i, selected)))
349    {
350        for (j, (j_symbol, j_coef)) in select_components
351            .iter()
352            .enumerate()
353            .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
354        {
355            let (coef, ix) = product(*i_coef, i, *j_coef, j);
356
357            expressions[ix] += SymbolicProdExpr(coef, vec![i_symbol.clone(), j_symbol.clone()]);
358        }
359    }
360
361    let expressions: Vec<_> = expressions
362        .into_iter()
363        .map(|expr| expr.simplify())
364        .collect();
365
366    if sqrt {
367        // See if we can take the square root symbolically
368        // Otherwise, return an empty expression (which will cause no code to be generated)
369        {
370            let is_scalar = expressions
371                .iter()
372                .enumerate()
373                .all(|(i, expr)| i == 0 || expr.0.len() == 0);
374            let is_anti_scalar = expressions
375                .iter()
376                .enumerate()
377                .all(|(i, expr)| i == expressions.len() - 1 || expr.0.len() == 0);
378            let expression = if is_scalar {
379                Some(expressions[0].clone())
380            } else if is_anti_scalar {
381                Some(expressions[expressions.len() - 1].clone())
382            } else {
383                None
384            };
385            if let Some(expression) = expression {
386                let SymbolicSumExpr(terms) = &expression;
387                if terms.len() == 1 {
388                    let SymbolicProdExpr(coef, terms) = &terms[0];
389                    if *coef == 1 && terms.len() == 2 && terms[0] == terms[1] {
390                        let sqrt_expression =
391                            SymbolicSumExpr(vec![SymbolicProdExpr(1, vec![terms[0].clone()])]);
392                        let target_ix = if is_scalar {
393                            0
394                        } else if is_anti_scalar {
395                            expressions.len() - 1
396                        } else {
397                            panic!("Took sqrt of something that wasn't a scalar or antiscalar");
398                        };
399                        Some(
400                            (0..select_components.len())
401                                .map(|i| {
402                                    if i == target_ix {
403                                        sqrt_expression.clone()
404                                    } else {
405                                        Default::default()
406                                    }
407                                })
408                                .collect(),
409                        )
410                    } else {
411                        None // Expression is not a square
412                    }
413                } else {
414                    None // Multiple terms in the sum
415                }
416            } else {
417                None // Squared norm is not a scalar or antiscalar
418            }
419        }
420        .unwrap_or(vec![Default::default(); select_components.len()])
421    } else {
422        // Return the squared norm
423        expressions
424    }
425}
426
427// Function to generate a sum of two objects, e.g. for overloading + or -
428// The result will be a list of symbolic expressions,
429// one for each coefficient value in the resultant multivector.
430fn generate_symbolic_sum(
431    select_components_a: &[Option<(Symbol, isize)>],
432    select_components_b: &[Option<(Symbol, isize)>],
433    coef_a: isize,
434    coef_b: isize,
435) -> Vec<SymbolicSumExpr> {
436    // Generate the sum
437    let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components_a.len()];
438
439    for (i, (a_selected, b_selected)) in select_components_a
440        .iter()
441        .zip(select_components_b.iter())
442        .enumerate()
443    {
444        if let Some((symbol_a, coef_symbol_a)) = a_selected {
445            result[i] += SymbolicProdExpr(coef_symbol_a * coef_a, vec![symbol_a.clone()]);
446        }
447        if let Some((symbol_b, coef_symbol_b)) = b_selected {
448            result[i] += SymbolicProdExpr(coef_symbol_b * coef_b, vec![symbol_b.clone()]);
449        }
450    }
451
452    result.into_iter().map(|expr| expr.simplify()).collect()
453}
454
455// Function to generate a product of two objects, e.g. geometric product, wedge product, etc.
456// The result will be a list of symbolic expressions,
457// one for each coefficient value in the resultant multivector.
458fn generate_symbolic_product<F: Fn(isize, usize, isize, usize) -> (isize, usize)>(
459    select_components_a: &[Option<(Symbol, isize)>],
460    select_components_b: &[Option<(Symbol, isize)>],
461    product: F,
462) -> Vec<SymbolicSumExpr> {
463    // Generate the product
464    let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components_a.len()];
465    for (i, (i_symbol, i_coef)) in select_components_a
466        .iter()
467        .enumerate()
468        .filter_map(|(i, selected)| selected.as_ref().map(|selected| (i, selected)))
469    {
470        for (j, (j_symbol, j_coef)) in select_components_b
471            .iter()
472            .enumerate()
473            .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
474        {
475            let (coef, ix) = product(*i_coef, i, *j_coef, j);
476
477            result[ix] += SymbolicProdExpr(coef, vec![i_symbol.clone(), j_symbol.clone()]);
478        }
479    }
480
481    result.into_iter().map(|expr| expr.simplify()).collect()
482}
483
484// Function to generate a double product of two objects, e.g. sandwich product, project,
485// etc.
486// The result will be a list of symbolic expressions,
487// one for each coefficient value in the resultant multivector.
488// The resulting code will implement the product in the following order:
489// (B PRODUCT1 A) PRODUCT2 B
490fn generate_symbolic_double_product<
491    F1: Fn(isize, usize, isize, usize) -> (isize, usize),
492    F2: Fn(isize, usize, isize, usize) -> (isize, usize),
493>(
494    select_components_a: &[Option<(Symbol, isize)>],
495    select_components_b: &[Option<(Symbol, isize)>],
496    product_1: F1,
497    product_2: F2,
498) -> Vec<SymbolicSumExpr> {
499    // Generate the first intermediate product B PRODUCT1 A
500    // where i maps to components of B, and j maps to components of A
501    let mut intermediate_result: Vec<SymbolicSumExpr> =
502        vec![Default::default(); select_components_a.len()];
503    for (i, (i_symbol, i_coef)) in select_components_b
504        .iter()
505        .enumerate()
506        .filter_map(|(i, selected)| selected.as_ref().map(|selected| (i, selected)))
507    {
508        for (j, (j_symbol, j_coef)) in select_components_a
509            .iter()
510            .enumerate()
511            .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
512        {
513            let (coef, ix) = product_1(*i_coef, i, *j_coef, j);
514            intermediate_result[ix] +=
515                SymbolicProdExpr(coef, vec![i_symbol.clone(), j_symbol.clone()]);
516        }
517    }
518    let intermediate_result: Vec<_> = intermediate_result
519        .into_iter()
520        .map(|expr| expr.simplify())
521        .collect();
522
523    // Generate the final product (B PRODUCT1 A) PRODUCT2 B
524    // where i maps to components of the intermediate result B PRODUCT1 A
525    // and j maps to components of B.
526    let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components_a.len()];
527    for (i, intermediate_term) in intermediate_result.iter().enumerate() {
528        for (j, (j_symbol, j_coef)) in select_components_b
529            .iter()
530            .enumerate()
531            .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
532        {
533            let (coef, ix) = product_2(1, i, *j_coef, j);
534            let new_term = SymbolicProdExpr(coef, vec![j_symbol.clone()]);
535            let result_term = intermediate_term.clone() * new_term;
536            result[ix] += result_term;
537        }
538    }
539
540    result.into_iter().map(|expr| expr.simplify()).collect()
541}
542
543fn find_output_object<'a>(
544    objects: &'a [Object],
545    output_expressions: &[SymbolicSumExpr],
546) -> Option<&'a Object> {
547    let select_output_components: Vec<_> = output_expressions
548        .iter()
549        .map(|SymbolicSumExpr(e)| e.len() != 0)
550        .collect();
551    objects.iter().find(|o| {
552        select_output_components
553            .iter()
554            .enumerate()
555            .find(|&(i, &out_c)| out_c && !o.has_component(i))
556            .is_none()
557    })
558}
559
560fn gen_unary_operator(
561    objects: &[Object],
562    op_trait: TokenStream,
563    op_fn: Ident,
564    obj: &Object,
565    expressions: &[SymbolicSumExpr],
566    alias: Option<(TokenStream, Ident)>,
567) -> TokenStream {
568    if matches!(obj, Object::Scalar) {
569        // Do not generate operations with the scalar being the LHS--
570        // typically because these would violate rust's orphan rule
571        // or result in conflicting trait implementations
572        return quote! {};
573    };
574
575    // Figure out what the type of the output is
576    let output_object = find_output_object(&objects, &expressions);
577
578    let Some(output_object) = output_object else {
579        // No output object matches the result we got,
580        // so don't generate any code
581        return quote! {};
582    };
583
584    if matches!(output_object, Object::Scalar) && expressions[0].0.len() == 0 {
585        // This operation unconditionally returns 0,
586        // so invoking it is probably a type error--
587        // do not generate code for it
588        return quote! {};
589    }
590
591    let output_type_name = &output_object.type_name_colons();
592    let type_name = &obj.type_name();
593
594    let return_expr = match output_object {
595        Object::Scalar => {
596            let expr = &expressions[0];
597            quote! { #expr }
598        }
599        Object::Struct(output_struct_object) => {
600            let output_fields: TokenStream = output_struct_object
601                .select_components
602                .iter()
603                .zip(expressions.iter())
604                .map(|(select_component, expr)| {
605                    if let Some((field, coef)) = select_component {
606                        let expr = expr.clone() * *coef;
607                        quote! { #field: #expr, }
608                    } else {
609                        return quote! {};
610                    }
611                })
612                .collect();
613            quote! {
614                #output_type_name {
615                    #output_fields
616                }
617            }
618        }
619    };
620
621    let associated_output_type = quote! { type Output = #output_type_name; };
622
623    let code = quote! {
624        impl < T: Ring > #op_trait for #type_name {
625            #associated_output_type
626
627            fn #op_fn (self) -> #output_type_name {
628                #return_expr
629            }
630        }
631    };
632
633    let alias_code = if let Some((alias_trait, alias_fn)) = alias {
634        quote! {
635            impl < T: Ring > #alias_trait for #type_name {
636                #associated_output_type
637
638                fn #alias_fn (self) -> #output_type_name {
639                    self.#op_fn()
640                }
641            }
642        }
643    } else {
644        quote! {}
645    };
646
647    quote! {
648        #code
649        #alias_code
650    }
651}
652
653fn gen_binary_operator<
654    F: Fn(&[Option<(Symbol, isize)>], &[Option<(Symbol, isize)>]) -> Vec<SymbolicSumExpr>,
655>(
656    basis_element_count: usize,
657    objects: &[Object],
658    op_trait: TokenStream,
659    op_fn: Ident,
660    lhs_obj: &Object,
661    op: F,
662    implicit_promotion_to_compound: bool,
663    alias: Option<(TokenStream, Ident)>,
664) -> TokenStream {
665    objects
666        .iter()
667        .map(|rhs_obj| {
668            if matches!(lhs_obj, Object::Scalar) {
669                // Do not generate operations with the scalar being the LHS--
670                // these violate rust's orphan rule
671                // or result in conflicting trait implementations.
672                // Technically we could allow this for custom ops such as T::cross(r: Vector<T>)
673                // but we won't for the sake of consistency.
674                // Use Vector<T>::cross(r: T) instead.
675                return quote! {};
676            };
677
678            let expressions = op(
679                &lhs_obj
680                    .select_components(Ident::new("self", Span::call_site()), basis_element_count),
681                &rhs_obj.select_components(Ident::new("r", Span::call_site()), basis_element_count),
682            );
683
684            // Figure out what the type of the output is
685            let output_object = find_output_object(&objects, &expressions);
686
687            let Some(output_object) = output_object else {
688                // No output object matches the result we got,
689                // so don't generate any code
690                return quote! {};
691            };
692
693            let rhs_type_name = &rhs_obj.type_name();
694            let lhs_type_name = &lhs_obj.type_name();
695            let output_type_name = &output_object.type_name_colons();
696
697            if !implicit_promotion_to_compound
698                && output_object.is_compound()
699                && !(lhs_obj.is_compound() || rhs_obj.is_compound())
700            {
701                // Do not create compound objects unintentionally.
702                // Only allow returning a compound object
703                // when taking products of compound objects and other objects
704                return quote! {};
705            }
706
707            if matches!(output_object, Object::Scalar) && expressions[0].0.len() == 0 {
708                // This operation unconditionally returns 0,
709                // so invoking it is probably a type error--
710                // do not generate code for it.
711
712                return quote! {};
713            }
714
715            let return_expr = match output_object {
716                Object::Scalar => {
717                    let expr = &expressions[0];
718                    quote! { #expr }
719                }
720                Object::Struct(output_struct_object) => {
721                    let output_fields: TokenStream = output_struct_object
722                        .select_components
723                        .iter()
724                        .zip(expressions.iter())
725                        .map(|(select_component, expr)| {
726                            if let Some((field, coef)) = select_component {
727                                let expr = expr.clone() * *coef;
728                                quote! { #field: #expr, }
729                            } else {
730                                return quote! {};
731                            }
732                        })
733                        .collect();
734                    quote! {
735                        #output_type_name {
736                            #output_fields
737                        }
738                    }
739                }
740            };
741
742            let associated_output_type = quote! { type Output = #output_type_name; };
743
744            let code = quote! {
745                impl < T: Ring > #op_trait < #rhs_type_name >  for #lhs_type_name {
746                    #associated_output_type
747
748                    fn #op_fn (self, r: #rhs_type_name) -> #output_type_name {
749                        #return_expr
750                    }
751                }
752            };
753
754            let alias_code = if let Some((alias_trait, alias_fn)) = &alias {
755                quote! {
756                    impl < T: Ring > #alias_trait < #rhs_type_name > for #lhs_type_name {
757                        #associated_output_type
758
759                        fn #alias_fn (self, r: #rhs_type_name) -> #output_type_name {
760                            self.#op_fn(r)
761                        }
762                    }
763                }
764            } else {
765                quote! {}
766            };
767
768            quote! {
769                #code
770                #alias_code
771            }
772        })
773        .collect()
774}
775
776fn gen_antiscalar_operator(
777    basis_element_count: usize,
778    op_trait: Ident,
779    anti_op_trait: Ident,
780    op_fns: &[(Ident, Ident)],
781    obj: &Object,
782) -> TokenStream {
783    let is_antiscalar = (0..basis_element_count).all(|i| {
784        if i == basis_element_count - 1 {
785            obj.has_component(i)
786        } else {
787            !obj.has_component(i)
788        }
789    });
790
791    if !is_antiscalar {
792        // Do not generate anti-operations on types which are not the antiscalar
793        return quote! {};
794    }
795
796    let Object::Struct(struct_obj) = obj else {
797        // Do not generate operations on the scalar--
798        // these violate rust's orphan rule
799        // or result in conflicting trait implementations.
800        // This should never happen--the scalar should never be the antiscalar
801        return quote! {};
802    };
803
804    let (field, coef) = struct_obj.select_components[basis_element_count - 1]
805        .as_ref()
806        .unwrap();
807
808    if *coef != 1 {
809        // Do not generate antiscalar operations if the antiscalar field in the struct
810        // is actually the negative antiscalar.
811        // We could generate these, but this is likely an error and is very confusing
812        // e.g. x.abs() would return a negative field value
813        return quote! {};
814    }
815
816    let field_expr = quote! {
817        self . #field
818    };
819
820    let type_name = &obj.type_name();
821    let struct_name = struct_obj.name.clone();
822
823    let functions_code: TokenStream = op_fns
824        .iter()
825        .map(|(fn_ident, anti_fn_ident)| {
826            quote! {
827                fn #anti_fn_ident (self) -> Self::Output {
828                    #struct_name {
829                        #field: #field_expr . #fn_ident ()
830                    }
831                }
832            }
833        })
834        .collect();
835
836    quote! {
837        impl < T: #op_trait > #anti_op_trait for #type_name {
838            type Output = #struct_name < < T as #op_trait >::Output >;
839
840            #functions_code
841        }
842    }
843}
844
845fn implement_geometric_algebra(
846    basis_vector_idents: Vec<Ident>,
847    metric: Vec<isize>,
848    multivector_structs: Vec<MultivectorStruct>,
849) -> Result<TokenStream> {
850    // Sanity checks
851    if basis_vector_idents.len() == 0 {
852        return Err(Error::new(Span::call_site(), "Basis vector set is empty"));
853    }
854    if basis_vector_idents.len() != metric.len() {
855        return Err(Error::new(
856            Span::call_site(),
857            "Metric and basis are different sizes",
858        ));
859    }
860    if multivector_structs.len() == 0 {
861        return Err(Error::new(
862            Span::call_site(),
863            "No multivector structs defined",
864        ));
865    }
866
867    // The number of dimensions in the algebra
868    // (e.g. use a 4D algebra to represent 3D geometry)
869    let dimension = metric.len();
870
871    let basis = {
872        // We will represent a basis element in the algebra as a Vec<usize>
873        // e.g. vec![2] = e_2, vec![1,2,3] = e_123, vec![] = 1
874
875        // To generate all the basis elements, we will iterate through the k-vectors.
876        // For each k, we right-multiply the set of basis vectors
877        // onto the basis elements of grade k-1,
878        // removing any results that are already
879        // represented in the basis element set.
880
881        // Start with the 0-vector, i.e. scalar, represented as vec![]
882        let mut basis_km1_vectors = vec![vec![]];
883        let mut basis = vec![vec![]];
884
885        for _ in 1..=dimension {
886            let mut basis_k_vectors = vec![];
887            for b1 in basis_km1_vectors {
888                for be2 in 0..dimension {
889                    // Insert be2 into b1 in sorted order
890                    match b1.binary_search(&be2) {
891                        Ok(_) => {}
892                        Err(pos) => {
893                            let mut b = b1.clone();
894                            b.insert(pos, be2);
895                            if !basis_k_vectors.contains(&b) {
896                                basis_k_vectors.push(b.clone());
897                                basis.push(b);
898                            }
899                        }
900                    }
901                }
902            }
903            basis_km1_vectors = basis_k_vectors;
904        }
905
906        // We now have a set of basis components in the vec `basis`.
907        // Each one is a product of basis vectors, in sorted order.
908        basis
909    };
910
911    let basis_element_count = basis.len();
912
913    // Analyze the basis vector identifiers to find the common prefix,
914    // as well as the part that varies
915    let (ident_prefix, ident_variants) = {
916        let first = basis_vector_idents
917            .first()
918            .map(|x| x.to_string())
919            .unwrap_or_default();
920
921        let len = first.chars().count();
922        assert!(len >= 1, "Identifier should be non-zero length");
923        let prefix = first.chars().take(len - 1).collect::<String>();
924
925        let variants = basis_vector_idents
926            .iter()
927            .map(|basis_ident| {
928                let basis_ident_str = basis_ident.to_string();
929                if basis_ident_str.chars().count() != len || !basis_ident_str.starts_with(&prefix) {
930                    return Err(Error::new(
931                        basis_ident.span(),
932                        "Bad identifier name: must be common prefix + 1 char",
933                    ));
934                }
935
936                Ok(basis_ident_str.chars().nth(len - 1).unwrap())
937            })
938            .collect::<Result<Vec<_>>>()?;
939        (prefix, variants)
940    };
941    let ident_prefix_len = ident_prefix.chars().count();
942
943    // Start with the scalar type
944    let mut objects = vec![Object::Scalar];
945    let struct_objects = multivector_structs
946        .iter()
947        .map(|multivector_struct| {
948            let name = multivector_struct.ident.clone();
949            let mut select_components = vec![None; basis_element_count];
950
951            // Put together the select_components array
952            for component_ident in multivector_struct.components.iter() {
953                let component_ident_str = component_ident.to_string();
954
955                if component_ident_str.chars().count() <= ident_prefix_len
956                    || !component_ident_str.starts_with(&ident_prefix)
957                {
958                    return Err(Error::new(
959                        component_ident.span(),
960                        "Bad identifier: must be common prefix + 1 or more chars (and there is already a scalar field)",
961                    ));
962                }
963                let variant_product = component_ident_str
964                    .chars()
965                    .skip(ident_prefix_len)
966                    .collect::<Vec<_>>();
967                let mut b = variant_product
968                    .iter()
969                    .map(|c| {
970                        ident_variants
971                            .iter()
972                            .position(|c2| c2 == c)
973                            .ok_or(Error::new(
974                                component_ident.span(),
975                                "Bad identifier: must be composed of basis vectors (and there is already a scalar field)",
976                            ))
977                    })
978                    .collect::<Result<Vec<_>>>();
979
980                // Don't immediately throw the error if we get something that is not composed of basis vectors--we will special-case the scalar value
981                // which can be any identifier.
982                if let Err(e) = b {
983                    if select_components[0].is_some() {
984                        return Err(e);
985                    }
986                    b = Ok(vec![]);
987                }
988
989                let mut b = b?;
990
991                let sign = sign_from_parity(bubble_sort_count_swaps(&mut b));
992                let basis_index = basis.iter().position(|b2| b2 == &b).ok_or(Error::new(
993                    component_ident.span(),
994                    "Bad identifier: cannot repeat basis vectors",
995                ))?;
996
997                if select_components[basis_index].is_some() {
998                    return Err(Error::new(
999                        component_ident.span(),
1000                        "Bad identifier: duplicate",
1001                    ));
1002                }
1003                select_components[basis_index] = Some((component_ident.clone(), sign));
1004            }
1005
1006            // Check if object is mixed-grade
1007            let grades = select_components
1008                .iter()
1009                .enumerate()
1010                .filter_map(|(i, component)| component.as_ref().map(|_| basis[i].len()))
1011                .collect::<Vec<_>>();
1012            let is_compound = grades.windows(2).any(|g| g[0] != g[1]);
1013
1014            Ok(Object::Struct(StructObject {
1015                name,
1016                select_components,
1017                is_compound,
1018            }))
1019        })
1020        .collect::<Result<Vec<_>>>()?;
1021
1022    objects.extend(struct_objects);
1023
1024    let right_complement_signs: Vec<_> = (0..basis_element_count)
1025        .map(|i| {
1026            let dual_i = basis_element_count - i - 1;
1027
1028            // Compute the product of the basis element and its dual basis element
1029            // and figure out what the sign needs so that the product equals I
1030            // (rather than -I)
1031            let mut product: Vec<usize> = basis[i]
1032                .iter()
1033                .cloned()
1034                .chain(basis[dual_i].iter().cloned())
1035                .collect();
1036            sign_from_parity(bubble_sort_count_swaps(product.as_mut()))
1037        })
1038        .collect();
1039
1040    // Generate dot product multiplication table for the basis elements.
1041    // The dot product always produces a scalar.
1042    let dot_product_multiplication_table: Vec<Vec<isize>> = {
1043        let multiply_basis_vectors = |ei: usize, ej: usize| {
1044            // This would need to get more complicated for CGA
1045            // where the metric g is not a diagonal matrix
1046            if ei == ej {
1047                metric[ei]
1048            } else {
1049                0
1050            }
1051        };
1052
1053        let multiply_basis_elements = |i: usize, j: usize| {
1054            // The scalar product of bivectors, trivectors, etc.
1055            // can be found using the Gram determinant.
1056
1057            let bi = &basis[i];
1058            let bj = &basis[j];
1059            if bi.len() != bj.len() {
1060                return 0;
1061            }
1062            if bi.len() == 0 {
1063                return 1; // 1 • 1 = 1
1064            }
1065
1066            let gram_matrix: Vec<Vec<isize>> = bi
1067                .iter()
1068                .map(|&ei| {
1069                    bj.iter()
1070                        .map(|&ej| multiply_basis_vectors(ei, ej))
1071                        .collect()
1072                })
1073                .collect();
1074
1075            fn determinant(m: &Vec<Vec<isize>>) -> isize {
1076                if m.len() == 1 {
1077                    m[0][0]
1078                } else {
1079                    let n = m.len();
1080                    (0..n)
1081                        .map(move |j| {
1082                            let i = 0;
1083                            let sign = match (i + j) % 2 {
1084                                0 => 1,
1085                                1 => -1,
1086                                _ => panic!("Expected parity to be 0 or 1"),
1087                            };
1088
1089                            let minor: Vec<Vec<_>> = (0..n)
1090                                .flat_map(|i2| {
1091                                    if i2 == i {
1092                                        None
1093                                    } else {
1094                                        Some(
1095                                            (0..n)
1096                                                .flat_map(|j2| {
1097                                                    if j2 == j {
1098                                                        None
1099                                                    } else {
1100                                                        Some(m[i2][j2])
1101                                                    }
1102                                                })
1103                                                .collect(),
1104                                        )
1105                                    }
1106                                })
1107                                .collect();
1108
1109                            sign * m[i][j] * determinant(&minor)
1110                        })
1111                        .sum()
1112                }
1113            }
1114            determinant(&gram_matrix)
1115        };
1116
1117        (0..basis_element_count)
1118            .map(|i| {
1119                (0..basis_element_count)
1120                    .map(move |j| multiply_basis_elements(i, j))
1121                    .collect()
1122            })
1123            .collect()
1124    };
1125
1126    let dot_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1127        let coef_mul = dot_product_multiplication_table[i][j];
1128        (coef_i * coef_j * coef_mul, 0)
1129    };
1130
1131    let anti_dot_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1132        let (i_coef, i) = right_complement(&right_complement_signs, coef_i, i);
1133        let (j_coef, j) = right_complement(&right_complement_signs, coef_j, j);
1134        let (coef, ix) = dot_product_f(i_coef, i, j_coef, j);
1135        let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1136        (coef, ix)
1137    };
1138
1139    // Generate geometric product multiplication table for the basis elements
1140    // Each entry is a tuple of the (coefficient, basis_index)
1141    // e.g. (1, 0) means the multiplication result is 1 * scalar = 1
1142    let geometric_product_multiplication_table: Vec<Vec<(isize, usize)>> = {
1143        let multiply_basis_elements = |i: usize, j: usize| {
1144            let mut product: Vec<_> = basis[i]
1145                .iter()
1146                .cloned()
1147                .chain(basis[j].iter().cloned())
1148                .collect();
1149            let swaps = bubble_sort_count_swaps(product.as_mut());
1150            let mut coef = match swaps % 2 {
1151                0 => 1,
1152                1 => -1,
1153                _ => panic!("Expected parity to be 0 or 1"),
1154            };
1155
1156            // Remove repeated elements in the product
1157            let mut new_product = vec![];
1158            let mut prev_e = None;
1159            for e in product.into_iter() {
1160                if Some(e) == prev_e {
1161                    coef *= metric[e];
1162                    prev_e = None;
1163                } else {
1164                    if let Some(prev_e) = prev_e {
1165                        new_product.push(prev_e);
1166                    }
1167                    prev_e = Some(e);
1168                }
1169            }
1170            if let Some(prev_e) = prev_e {
1171                new_product.push(prev_e);
1172            }
1173
1174            // Figure out which basis element this corresponds to
1175            basis
1176                .iter()
1177                .enumerate()
1178                .find_map(|(i, b)| {
1179                    let mut b_sorted = b.clone();
1180                    let swaps = bubble_sort_count_swaps(b_sorted.as_mut());
1181                    (new_product == b_sorted).then(|| {
1182                        let coef = coef
1183                            * match swaps % 2 {
1184                                0 => 1,
1185                                1 => -1,
1186                                _ => panic!("Expected parity to be 0 or 1"),
1187                            };
1188                        (coef, i)
1189                    })
1190                })
1191                .expect("Product of basis elements not found in basis set")
1192        };
1193
1194        (0..basis_element_count)
1195            .map(|i| {
1196                (0..basis_element_count)
1197                    .map(move |j| multiply_basis_elements(i, j))
1198                    .collect()
1199            })
1200            .collect()
1201    };
1202
1203    let geometric_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1204        let (coef_mul, ix) = geometric_product_multiplication_table[i][j];
1205        (coef_i * coef_j * coef_mul, ix)
1206    };
1207
1208    let geometric_antiproduct_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1209        let (coef_i, i) = right_complement(&right_complement_signs, coef_i, i);
1210        let (coef_j, j) = right_complement(&right_complement_signs, coef_j, j);
1211        let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1212        let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1213        (coef, ix)
1214    };
1215
1216    let impl_code: TokenStream = objects
1217        .iter()
1218        .map(|obj| {
1219            // Derive From / Into
1220            let from_code: TokenStream = {
1221                if matches!(obj, Object::Scalar) {
1222                    // Do not implement From on the scalar--
1223                    // typically because these would violate rust's orphan rule
1224                    // and are also not useful
1225                    return quote! {};
1226                }
1227
1228                objects.iter().map(|other_obj| {
1229                    // See if the object we are converting from
1230                    // contains a non-empty strict subset of the components in our object
1231
1232                    let is_subset = (0..basis_element_count).all(|i| obj.has_component(i) || !other_obj.has_component(i));
1233                    if !is_subset { return quote! {}; }
1234
1235                    let is_same_object = (0..basis_element_count).all(|i| obj.has_component(i) == other_obj.has_component(i));
1236                    if is_same_object { return quote! {}; }
1237
1238                    let is_not_empty = (0..basis_element_count).any(|i| obj.has_component(i) && other_obj.has_component(i));
1239                    if !is_not_empty { return quote! {}; }
1240
1241                    let my_type_name = obj.type_name();
1242                    let my_type_name_colons = obj.type_name_colons();
1243                    let other_type_name = other_obj.type_name();
1244
1245                    let expressions: Vec<_> = other_obj.select_components(Ident::new("value", Span::call_site()), basis_element_count).iter().map(|select_component| {
1246                        match select_component {
1247                            Some((symbol, coef)) => SymbolicSumExpr(vec![SymbolicProdExpr(*coef, vec![symbol.clone()])]),
1248                            None => Default::default(),
1249                        }
1250                    }).collect();
1251
1252                    let return_expr = match obj {
1253                        Object::Scalar => {
1254                            let expr = &expressions[0];
1255                            quote! { #expr }
1256                        }
1257                        Object::Struct(output_struct_object) => {
1258                            let output_fields: TokenStream = output_struct_object
1259                                .select_components
1260                                .iter()
1261                                .zip(expressions.iter())
1262                                .map(|(select_component, expr)| {
1263                                    if let Some((field, coef)) = select_component {
1264                                        let expr = expr.clone() * *coef;
1265                                        quote! { #field: #expr, }
1266                                    } else {
1267                                        return quote! {};
1268                                    }
1269                                })
1270                                .collect();
1271                            quote! {
1272                                #my_type_name_colons {
1273                                    #output_fields
1274                                }
1275                            }
1276                        }
1277                    };
1278
1279                    quote! {
1280                        impl<T: core::default::Default> From<#other_type_name> for #my_type_name {
1281                            fn from(value: #other_type_name) -> #my_type_name {
1282                                #return_expr
1283                            }
1284                        }
1285                    }
1286                }).collect()
1287            };
1288
1289            let obj_self_components = &obj.select_components(Ident::new("self", Span::call_site()), basis_element_count);
1290
1291            // Add a method anti_abs() on the antiscalar
1292            let anti_abs_code = gen_antiscalar_operator(
1293                basis_element_count,
1294                Ident::new("Abs", Span::call_site()),
1295                Ident::new("AntiAbs", Span::call_site()),
1296                &[(
1297                    Ident::new("abs", Span::call_site()),
1298                    Ident::new("anti_abs", Span::call_site()),
1299                )],
1300                obj,
1301            );
1302
1303            // Add a method anti_recip() on the antiscalar
1304            let anti_recip_code = gen_antiscalar_operator(
1305                basis_element_count,
1306                Ident::new("Recip", Span::call_site()),
1307                Ident::new("AntiRecip", Span::call_site()),
1308                &[(
1309                    Ident::new("recip", Span::call_site()),
1310                    Ident::new("anti_recip", Span::call_site()),
1311                )],
1312                obj,
1313            );
1314
1315            // Add a method anti_sqrt() on the antiscalar
1316            let anti_sqrt_code = gen_antiscalar_operator(
1317                basis_element_count,
1318                Ident::new("Sqrt", Span::call_site()),
1319                Ident::new("AntiSqrt", Span::call_site()),
1320                &[(
1321                    Ident::new("sqrt", Span::call_site()),
1322                    Ident::new("anti_sqrt", Span::call_site()),
1323                )],
1324                obj,
1325            );
1326
1327            // Add a method anti_trig() on the antiscalar
1328            let anti_trig_code = gen_antiscalar_operator(
1329                basis_element_count,
1330                Ident::new("Trig", Span::call_site()),
1331                Ident::new("AntiTrig", Span::call_site()),
1332                &[
1333                    (
1334                        Ident::new("cos", Span::call_site()),
1335                        Ident::new("anti_cos", Span::call_site()),
1336                    ),
1337                    (
1338                        Ident::new("sin", Span::call_site()),
1339                        Ident::new("anti_sin", Span::call_site()),
1340                    ),
1341                    (
1342                        Ident::new("sinc", Span::call_site()),
1343                        Ident::new("anti_sinc", Span::call_site()),
1344                    ),
1345                ],
1346                obj,
1347            );
1348
1349            // Overload unary -
1350            let op_trait = quote! { core::ops::Neg };
1351            let op_fn = Ident::new("neg", Span::call_site());
1352            let neg_code = gen_unary_operator(
1353                &objects,
1354                op_trait,
1355                op_fn,
1356                &obj,
1357                &generate_symbolic_rearrangement(&obj_self_components, |coef: isize, i: usize| (-coef, i)),
1358                None,
1359            );
1360
1361            // Add a method A.reverse()
1362            let op_trait = quote! { Reverse };
1363            let op_fn = Ident::new("reverse", Span::call_site());
1364            let reverse_f = |coef: isize, i: usize| {
1365                    let coef_rev = sign_from_parity((basis[i].len() / 2) % 2);
1366                    (coef * coef_rev, i)
1367                };
1368            let reverse_expressions = generate_symbolic_rearrangement(&obj_self_components, reverse_f);
1369            let reverse_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &reverse_expressions, None);
1370
1371            // Add a method A.anti_reverse()
1372            let op_trait = quote! { AntiReverse };
1373            let op_fn = Ident::new("anti_reverse", Span::call_site());
1374            let alias_trait = quote! { InverseTransformation };
1375            let alias_fn = Ident::new("inverse_transformation", Span::call_site());
1376            let anti_reverse_f = |coef: isize, i: usize| {
1377                    let (coef, i) = right_complement(&right_complement_signs, coef, i);
1378                    let (coef, i) = reverse_f(coef, i);
1379                    let (coef, i) = left_complement(&right_complement_signs, coef, i);
1380                    (coef, i)
1381                };
1382            let anti_reverse_expressions = generate_symbolic_rearrangement(&obj_self_components, anti_reverse_f);
1383            let anti_reverse_code = gen_unary_operator(
1384                &objects,
1385                op_trait,
1386                op_fn,
1387                &obj,
1388                &anti_reverse_expressions,
1389                Some((alias_trait, alias_fn)), // alias
1390            );
1391
1392            // Add a method A.right_complement()
1393            let op_trait = quote! { RightComplement };
1394            let op_fn = Ident::new("right_complement", Span::call_site());
1395            let right_complement_expressions = generate_symbolic_rearrangement(&obj_self_components, |coef: isize, i: usize| right_complement(&right_complement_signs, coef, i));
1396            let right_complement_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &right_complement_expressions, None);
1397
1398            // Add a method A.left_complement()
1399            let op_trait = quote! { LeftComplement };
1400            let op_fn = Ident::new("left_complement", Span::call_site());
1401            let left_complement_expressions = generate_symbolic_rearrangement(&obj_self_components, |coef: isize, i: usize| left_complement(&right_complement_signs, coef, i));
1402            let left_complement_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &left_complement_expressions, None);
1403
1404            // Add a method A.bulk()
1405            let op_trait = quote! { Bulk };
1406            let op_fn = Ident::new("bulk", Span::call_site());
1407            let bulk_f = |coef: isize, i: usize| {
1408                // This would need to be changed for CGA where the dot product multiplication table
1409                // is not diagonal
1410                let (zero_or_one, _) = dot_product_f(1, i, 1, i);
1411                (coef * zero_or_one, i)
1412            };
1413            let bulk_expressions = generate_symbolic_rearrangement(&obj_self_components, bulk_f);
1414            let bulk_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_expressions, None);
1415
1416            // Add a method A.weight()
1417            let op_trait = quote! { Weight };
1418            let op_fn = Ident::new("weight", Span::call_site());
1419            let weight_f = |coef: isize, i: usize| {
1420                let (coef, i) = right_complement(&right_complement_signs, coef, i);
1421                let (coef, i) = bulk_f(coef, i);
1422                let (coef, i) = left_complement(&right_complement_signs, coef, i);
1423                (coef, i)
1424            };
1425            let weight_expressions = generate_symbolic_rearrangement(&obj_self_components, weight_f);
1426            let weight_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_expressions, None);
1427
1428            // Add a method A.bulk_dual() which computes A★
1429            let op_trait = quote! { BulkDual };
1430            let op_fn = Ident::new("bulk_dual", Span::call_site());
1431            let bulk_dual_f = |coef: isize, i: usize| {
1432                let (coef, i) = bulk_f(coef, i);
1433                let (coef, i) = right_complement(&right_complement_signs, coef, i);
1434                (coef, i)
1435            };
1436            let bulk_dual_expressions = generate_symbolic_rearrangement(&obj_self_components, bulk_dual_f);
1437            let bulk_dual_code =
1438                gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_dual_expressions, None);
1439
1440            // Add a method A.weight_dual() which computes A☆
1441            let op_trait = quote! { WeightDual };
1442            let op_fn = Ident::new("weight_dual", Span::call_site());
1443            let alias_trait = quote! { Normal };
1444            let alias_fn = Ident::new("normal", Span::call_site());
1445            let weight_dual_f = |coef: isize, i: usize| {
1446                let (coef, i) = weight_f(coef, i);
1447                let (coef, i) = right_complement(&right_complement_signs, coef, i);
1448                (coef, i)
1449            };
1450            let weight_dual_expressions = generate_symbolic_rearrangement(&obj_self_components, weight_dual_f);
1451            let weight_dual_code =
1452                gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_dual_expressions, Some((alias_trait, alias_fn)));
1453
1454
1455            // Add a method A.bulk_norm_squared()
1456            let op_trait = quote! { BulkNormSquared };
1457            let op_fn = Ident::new("bulk_norm_squared", Span::call_site());
1458            // Squared norm uses the product A • A
1459
1460            let bulk_norm_squared_expressions = generate_symbolic_norm(&obj_self_components, dot_product_f, false);
1461            let bulk_norm_squared_code =
1462                gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_norm_squared_expressions, None);
1463
1464            let bulk_norm_code = if !bulk_norm_squared_code.is_empty() {
1465                // Add a method A.bulk_norm() if it is possible to symbolically take the sqrt
1466                let op_trait = quote! { BulkNorm };
1467                let op_fn = Ident::new("bulk_norm", Span::call_site());
1468                let bulk_norm_expressions = generate_symbolic_norm(&obj_self_components, dot_product_f, true);
1469                let bulk_norm_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_norm_expressions, None);
1470                if !bulk_norm_code.is_empty() {
1471                    // We found a way to symbolically take the sqrt
1472                    bulk_norm_code
1473                } else {
1474                    // Take the square root of the norm numerically
1475                    let type_name = obj.type_name();
1476                    quote! {
1477                        impl < T: Ring > BulkNorm for #type_name
1478                            where <Self as BulkNormSquared>::Output: Sqrt {
1479                            type Output = <<Self as BulkNormSquared>::Output as Sqrt>::Output;
1480
1481                            fn bulk_norm (self) -> Self::Output {
1482                                self.bulk_norm_squared().sqrt()
1483                            }
1484                        }
1485                    }
1486                }
1487            } else {
1488                quote! {}  // There is no squared bulk norm, so return no impls for norm
1489            };
1490
1491            let op_trait = quote! { WeightNormSquared };
1492            let op_fn = Ident::new("weight_norm_squared", Span::call_site());
1493            let weight_norm_squared_expressions = generate_symbolic_norm(&obj_self_components, anti_dot_product_f, false);
1494
1495            let weight_norm_squared_code =
1496                gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_norm_squared_expressions, None);
1497
1498            let weight_norm_code = if !weight_norm_squared_code.is_empty() {
1499                // Add a method A.weight_norm() if it is possible to symbolically take the sqrt
1500                let op_trait = quote! { WeightNorm };
1501                let op_fn = Ident::new("weight_norm", Span::call_site());
1502                let weight_norm_expressions = generate_symbolic_norm(&obj_self_components, anti_dot_product_f, true);
1503                let weight_norm_code =
1504                    gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_norm_expressions, None);
1505                if !weight_norm_code.is_empty() {
1506                    // We found a way to symbolically take the sqrt
1507                    weight_norm_code
1508                } else {
1509                    // Take the square root of the norm numerically
1510                    let type_name = obj.type_name();
1511                    quote! {
1512                        impl < T: Ring > WeightNorm for #type_name
1513                            where <Self as WeightNormSquared>::Output: AntiSqrt {
1514                            type Output = <<Self as WeightNormSquared>::Output as AntiSqrt>::Output;
1515
1516                            fn weight_norm (self) -> Self::Output {
1517                                self.weight_norm_squared().anti_sqrt()
1518                            }
1519                        }
1520                    }
1521                }
1522            } else {
1523                quote! {}  // There is no squared norm, so return no impls for norm
1524            };
1525
1526            /*
1527            // Add a method A.geometric_norm_squared()
1528            let op_trait = quote! { GeometricNormSquared };
1529            let op_fn = Ident::new("geometric_norm_squared", Span::call_site());
1530            let geometric_norm_product_f =|i: usize, j: usize| {
1531                // Compute the product A * ~A
1532
1533                let reverse_coef = match (basis[j].len() / 2) % 2 {
1534                    0 => 1,
1535                    1 => -1,
1536                    _ => panic!("Expected parity to be 0 or 1"),
1537                };
1538
1539                let (coef, ix_result) = multiplication_table[i][j];
1540
1541                (coef * reverse_coef, ix_result)
1542            };
1543            let geometric_norm_squared_expressions = generate_symbolic_norm(&basis, &obj.select_components, geometric_norm_product_f, false);
1544            let geometric_norm_squared_code =
1545                gen_unary_operator(&basis, &objects, op_trait, op_fn, &obj, &geometric_norm_squared_expressions
1546                    , false);
1547
1548            let geometric_norm_code = if !geometric_norm_squared_code.is_empty() {
1549                // Add a method A.geometric_norm() if it is possible to symbolically take the sqrt
1550                let op_trait = quote! { GeometricNorm };
1551                let op_fn = Ident::new("geometric_norm", Span::call_site());
1552                let geometric_norm_code =
1553                    gen_unary_operator(&basis, &objects, op_trait, op_fn, &obj, &generate_symbolic_norm(&basis, &obj.select_components, geometric_norm_product_f, true)
1554                        , false);
1555                if !geometric_norm_code.is_empty() {
1556                    // We found a way to symbolically take the sqrt
1557                    geometric_norm_code
1558                } else {
1559                    // Take the square root of the norm numerically
1560                    let type_name = obj.type_name();
1561                    quote! {
1562                        impl < T: Ring > GeometricNorm for #type_name
1563                            where <Self as GeometricNormSquared>::Output: Sqrt {
1564                            type Output = <Self as GeometricNormSquared>::Output;
1565
1566                            fn geometric_norm (self) -> Self::Output {
1567                                self.geometric_norm_squared().sqrt()
1568                            }
1569                        }
1570                    }
1571                }
1572            } else {
1573                quote! {}  // There is no squared norm, so return no impls for norm
1574            };
1575
1576            let op_trait = quote! { GeometricInfNormSquared };
1577            let op_fn = Ident::new("geometric_inf_norm_squared", Span::call_site());
1578            let geometric_inf_norm_product_f = |i: usize, j: usize| {
1579                let (coef_i, i) = right_complement(&right_complement_signs, i);
1580                let (coef_j, j) = right_complement(&right_complement_signs, j);
1581                let (coef_prod, ix) = geometric_norm_product_f(i, j);
1582                (coef_i * coef_j * coef_prod, ix)
1583            };
1584            let geometric_inf_norm_squared_expressions = generate_symbolic_norm(&basis, &obj.select_components, geometric_inf_norm_product_f, false);
1585            let geometric_inf_norm_squared_code =
1586                gen_unary_operator(&basis, &objects, op_trait, op_fn, &obj, &geometric_inf_norm_squared_expressions
1587                    , false);
1588
1589            let geometric_inf_norm_code = if !geometric_inf_norm_squared_code.is_empty() {
1590                // Add a method A.geometric_inf_norm() if it is possible to symbolically take the sqrt
1591                let op_trait = quote! { GeometricInfNorm };
1592                let op_fn = Ident::new("geometric_inf_norm", Span::call_site());
1593                let geometric_inf_norm_code =
1594                    gen_unary_operator(&basis, &objects, op_trait, op_fn, &obj, &generate_symbolic_norm(&basis, &obj.select_components, geometric_inf_norm_product_f, true)
1595                        , false);
1596                if !geometric_inf_norm_code.is_empty() {
1597                    // We found a way to symbolically take the sqrt
1598                    geometric_inf_norm_code
1599                } else {
1600                    // Take the square root of the norm numerically
1601                    let type_name = obj.type_name();
1602                    quote! {
1603                        impl < T: Ring > GeometricInfNorm for #type_name
1604                            where <Self as GeometricInfNormSquared>::Output: Sqrt {
1605                            type Output = <Self as GeometricInfNormSquared>::Output;
1606
1607                            fn geometric_inf_norm (self) -> Self::Output {
1608                                self.geometric_inf_norm_squared().sqrt()
1609                            }
1610                        }
1611                    }
1612                }
1613            } else {
1614                quote! {}  // There is no squared norm, so return no impls for norm
1615            };
1616
1617            */
1618
            // Implement .normalized() which returns a scaled copy of the object
            // where the bulk norm has been made equal to 1.
            // Also implement .unitized() which returns a scaled copy of the object
            // where the weight norm has been made equal to 𝟙.
1623            let hat_code = if !matches!(obj, Object::Scalar) {
1624                let type_name = obj.type_name();
1625                quote! {
1626                    impl<T: Ring + Recip<Output=T>> Normalized for #type_name
1627                    where
1628                        Self: BulkNorm<Output=T>
1629                    {
1630                        type Output = Self;
1631                        fn normalized(self) -> Self {
1632                            self * self.bulk_norm().recip()
1633                        }
1634                    }
1635
1636                    impl<T: Ring + Recip<Output=T>> Unitized for #type_name
1637                    where
1638                        Self: WeightNorm<Output=AntiScalar<T>> // TODO: Figure out anti-scalar
1639                                                               // handling
1640                    {
1641                        type Output = Self;
1642                        fn unitized(self) -> Self {
1643                            self.anti_mul(self.weight_norm().anti_recip())
1644                        }
1645                    }
1646                }
1647            } else {
1648                quote! {}
1649            };
1650
1651            // Overload +
1652            let op_trait = quote! { core::ops::Add };
1653            let op_fn = Ident::new("add", Span::call_site());
1654            let add_code = gen_binary_operator(
1655                basis_element_count,
1656                &objects,
1657                op_trait,
1658                op_fn,
1659                &obj,
1660                |a, b| generate_symbolic_sum(a, b, 1, 1),
1661                true,  // implicit_promotion_to_compound
1662                None, // alias
1663            );
1664
1665            // Overload -
1666            let op_trait = quote! { core::ops::Sub };
1667            let op_fn = Ident::new("sub", Span::call_site());
1668            let sub_code = gen_binary_operator(
1669                basis_element_count,
1670                &objects,
1671                op_trait,
1672                op_fn,
1673                &obj,
1674                |a, b| generate_symbolic_sum(a, b, 1, -1),
1675                true,  // implicit_promotion_to_compound
1676                None, // alias
1677            );
1678
1679            // Add a method A.wedge(B) which computes A ∧ B
1680            let op_trait = quote! { Wedge };
1681            let op_fn = Ident::new("wedge", Span::call_site());
1682            let alias_trait = quote! { Join };
1683            let alias_fn = Ident::new("join", Span::call_site());
1684            let wedge_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1685                let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1686                // Select grade s + t
1687                let s = basis[i].len();
1688                let t = basis[j].len();
1689                let u = basis[ix].len();
1690                let coef = if s + t == u { coef } else { 0 };
1691                (coef, ix)
1692            };
1693            let wedge_product_code = gen_binary_operator(
1694                basis_element_count,
1695                &objects,
1696                op_trait,
1697                op_fn,
1698                &obj,
1699                |a, b| generate_symbolic_product(a, b, wedge_product_f),
1700                false, // implicit_promotion_to_compound
1701                Some((alias_trait, alias_fn)), // alias
1702            );
1703
1704            // Add a method A.anti_wedge(B) which computes A ∨ B
1705            let op_trait = quote! { AntiWedge };
1706            let op_fn = Ident::new("anti_wedge", Span::call_site());
1707            let alias_trait = quote! { Meet };
1708            let alias_fn = Ident::new("meet", Span::call_site());
1709            let anti_wedge_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1710                let (coef_i, i) = right_complement(&right_complement_signs, coef_i, i);
1711                let (coef_j, j) = right_complement(&right_complement_signs, coef_j, j);
1712
1713                let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1714                // Select grade s + t
1715                let s = basis[i].len();
1716                let t = basis[j].len();
1717                let u = basis[ix].len();
1718                let coef = if s + t == u { coef } else { 0 };
1719
1720                let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1721                (coef, ix)
1722            };
1723            let anti_wedge_product_code = gen_binary_operator(
1724                basis_element_count,
1725                &objects,
1726                op_trait,
1727                op_fn,
1728                &obj,
1729                |a, b| generate_symbolic_product(a, b, anti_wedge_product_f),
1730                false, // implicit_promotion_to_compound
1731                Some((alias_trait, alias_fn)), // alias
1732            );
1733
1734            // Add a method A.dot(B) which computes A • B
1735            let op_trait = quote! { Dot };
1736            let op_fn = Ident::new("dot", Span::call_site());
1737
1738            let dot_product_code = gen_binary_operator(
1739                basis_element_count,
1740                &objects,
1741                op_trait,
1742                op_fn,
1743                &obj,
1744                |a, b| generate_symbolic_product(a, b, dot_product_f),
1745                false, // implicit_promotion_to_compound
1746                None, // alias
1747            );
1748
1749            // Add a method A.anti_dot(B) which computes A ∘ B
1750            let op_trait = quote! { AntiDot };
1751            let op_fn = Ident::new("anti_dot", Span::call_site());
1752
1753            let anti_dot_product_code = gen_binary_operator(
1754                basis_element_count,
1755                &objects,
1756                op_trait,
1757                op_fn,
1758                &obj,
1759                |a, b| generate_symbolic_product(a, b, anti_dot_product_f),
1760                false, // implicit_promotion_to_compound
1761                None, // alias
1762            );
1763
1764            // Implement the geometric product ⟑
1765            let op_trait = quote! { WedgeDot };
1766            let op_fn = Ident::new("wedge_dot", Span::call_site());
1767
1768            let wedge_dot_product_code = gen_binary_operator(
1769                basis_element_count,
1770                &objects,
1771                op_trait,
1772                op_fn,
1773                &obj,
1774                |a, b| generate_symbolic_product(a, b, geometric_product_f),
1775                false, // implicit_promotion_to_compound
1776                None, // alias
1777            );
1778
1779            // Implement the geometric antiproduct ⟇
1780            let op_trait = quote! { AntiWedgeDot };
1781            let op_fn = Ident::new("anti_wedge_dot", Span::call_site());
1782            let alias_trait = quote! { Compose };
1783            let alias_fn = Ident::new("compose", Span::call_site());
1784
1785            let anti_wedge_dot_product_code = gen_binary_operator(
1786                basis_element_count,
1787                &objects,
1788                op_trait,
1789                op_fn,
1790                &obj,
1791                |a, b| generate_symbolic_product(a, b, geometric_antiproduct_f),
1792                false, // implicit_promotion_to_compound
1793                Some((alias_trait, alias_fn)), // alias
1794            );
1795
1796            // Overload * for scalar multiplication
1797            let op_trait = quote! { core::ops::Mul };
1798            let op_fn = Ident::new("mul", Span::call_site());
1799            let scalar_product = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1800                let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1801                // Require s or t be a scalar
1802                let s = basis[i].len();
1803                let t = basis[j].len();
1804                let coef = if s == 0 || t == 0 { coef } else { 0 };
1805                (coef, ix)
1806            };
1807            let scalar_product_code = gen_binary_operator(
1808                basis_element_count,
1809                &objects,
1810                op_trait,
1811                op_fn,
1812                &obj,
1813                |a, b| generate_symbolic_product(a, b, scalar_product),
1814                true, // implicit_promotion_to_compound
1815                None, // alias
1816            );
1817
1818            // Implement A.anti_mul(B) for anti-scalar multiplication
1819            let op_trait = quote! { AntiMul };
1820            let op_fn = Ident::new("anti_mul", Span::call_site());
1821            let anti_scalar_f = |coef_i: isize, i: usize, coef_j: isize, j: usize|  {
1822                let (coef_i, i) = right_complement(&right_complement_signs, coef_i, i);
1823                let (coef_j, j) = right_complement(&right_complement_signs, coef_j, j);
1824
1825                let (coef, ix) = scalar_product(coef_i, i, coef_j, j);
1826
1827                let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1828                (coef, ix)
1829            };
1830            let anti_scalar_product_code = gen_binary_operator(
1831                basis_element_count,
1832                &objects,
1833                op_trait,
1834                op_fn,
1835                &obj,
1836                |a, b| generate_symbolic_product(a, b, anti_scalar_f),
1837                true, // implicit_promotion_to_compound
1838                None, // alias
1839            );
1840
1841            // Add a method A.bulk_expansion(B) which computes A ∧ B★
1842            let op_trait = quote! { BulkExpansion };
1843            let op_fn = Ident::new("bulk_expansion", Span::call_site());
1844            let bulk_expansion_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1845                let (coef_j, j) = bulk_dual_f(coef_j, j);
1846                let (coef, ix) = wedge_product_f(coef_i, i, coef_j, j);
1847                (coef, ix)
1848            };
1849            let bulk_expansion_code = gen_binary_operator(
1850                basis_element_count,
1851                &objects,
1852                op_trait,
1853                op_fn,
1854                &obj,
1855                |a, b| generate_symbolic_product(a, b, bulk_expansion_f),
1856                false, // implicit_promotion_to_compound
1857                None, // alias
1858            );
1859
1860            // Add a method A.weight_expansion(B) which computes A ∧ B☆
1861            let op_trait = quote! { WeightExpansion };
1862            let op_fn = Ident::new("weight_expansion", Span::call_site());
1863            let alias_trait = quote! { SupersetOrthogonalTo };
1864            let alias_fn = Ident::new("superset_orthogonal_to", Span::call_site());
1865            let weight_expansion_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1866                let (coef_j, j) = weight_dual_f(coef_j, j);
1867                let (coef, ix) = wedge_product_f(coef_i, i, coef_j, j);
1868                (coef, ix)
1869            };
1870            let weight_expansion_code = gen_binary_operator(
1871                basis_element_count,
1872                &objects,
1873                op_trait,
1874                op_fn,
1875                &obj,
1876                |a, b| generate_symbolic_product(a, b, weight_expansion_f),
1877                false, // implicit_promotion_to_compound
1878                Some((alias_trait, alias_fn)), // alias
1879            );
1880
1881            // Add a method A.bulk_contraction(B) which computes A ∨ B★
1882            let op_trait = quote! { BulkContraction };
1883            let op_fn = Ident::new("bulk_contraction", Span::call_site());
1884            let bulk_contraction_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1885                let (coef_j, j) = bulk_dual_f(coef_j, j);
1886                let (coef, ix) = anti_wedge_product_f(coef_i, i, coef_j, j);
1887                (coef, ix)
1888            };
1889            let bulk_contraction_code = gen_binary_operator(
1890                basis_element_count,
1891                &objects,
1892                op_trait,
1893                op_fn,
1894                &obj,
1895                |a, b| generate_symbolic_product(a, b, bulk_contraction_f),
1896                false, // implicit_promotion_to_compound
1897                None, // alias
1898            );
1899
1900            // Add a method A.weight_contraction(B) which computes A ∨ B☆
1901            let op_trait = quote! { WeightContraction };
1902            let op_fn = Ident::new("weight_contraction", Span::call_site());
1903            let alias_trait = quote! { SubsetOrthogonalTo };
1904            let alias_fn = Ident::new("subset_orthogonal_to", Span::call_site());
1905            let weight_contraction_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1906                let (coef_j, j) = weight_dual_f(coef_j, j);
1907                let (coef, ix) = anti_wedge_product_f(coef_i, i, coef_j, j);
1908                (coef, ix)
1909            };
1910            let weight_contraction_code = gen_binary_operator(
1911                basis_element_count,
1912                &objects,
1913                op_trait,
1914                op_fn,
1915                &obj,
1916                |a, b| generate_symbolic_product(a, b, weight_contraction_f),
1917                false, // implicit_promotion_to_compound
1918                Some((alias_trait, alias_fn)), // alias
1919            );
1920
1921            // Add a method A.anti_commutator(B) which computes (A ⟇ B - B ⟇ A) / 2
1922            let op_trait = quote! { AntiCommutator };
1923            let op_fn = Ident::new("anti_commutator", Span::call_site());
1924            let commutator_product = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1925                let (coef_1, ix) = geometric_antiproduct_f(coef_i, i, coef_j, j);
1926                let (coef_2, ix_2) = geometric_antiproduct_f(coef_j, j, coef_i, i);
1927
1928                let coef = coef_1 - coef_2;
1929                assert!(ix == ix_2);
1930                assert!(coef % 2 == 0);
1931                let coef = coef / 2;
1932                (coef, ix)
1933            };
1934            let commutator_product_code = gen_binary_operator(
1935                basis_element_count,
1936                &objects,
1937                op_trait,
1938                op_fn,
1939                &obj,
1940                |a, b| generate_symbolic_product(a, b, commutator_product),
1941                false, // implicit_promotion_to_compound
1942                None, // alias
1943            );
1944
1945            // Add a method A.transform(B) which computes B̰ ⟇ A ⟇ B
1946            let op_trait = quote! { Transform };
1947            let op_fn = Ident::new("transform", Span::call_site());
1948            let transform_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1949                // Compute first half of B̰ ⟇ A ⟇ B
1950                // Where i maps to B, and j maps to A.
1951                // In part 2, we will compute the geometric antiproduct of this intermediate result
1952                // with B
1953
1954                let (coef_i, i) = anti_reverse_f(coef_i, i);
1955                geometric_antiproduct_f(coef_i, i, coef_j, j)
1956            };
1957            // Compute second half of B̰ ⟇ A ⟇ B
1958            // In part 1, we computed the intermediate result B̰ ⟇ A which maps to i here.
1959            // j maps to B.
1960            let transform_2 = geometric_antiproduct_f;
1961
1962            let transform_code = gen_binary_operator(
1963                basis_element_count,
1964                &objects,
1965                op_trait,
1966                op_fn,
1967                &obj,
1968                |a, b| {
1969                    generate_symbolic_double_product(
1970                        a,
1971                        b,
1972                        transform_1,
1973                        transform_2,
1974                    )
1975                },
1976                false, // implicit_promotion_to_compound
1977                None, // alias
1978            );
1979
            // Add a method A.transform_inverse(B) which computes B ⟇ A ⟇ B̰
1981            let op_trait = quote! { TransformInverse };
1982            let op_fn = Ident::new("transform_inverse", Span::call_site());
1983            // Compute first half of B ⟇ A ⟇ B̰
1984            // Where i maps to B, and j maps to A.
1985            // In part 2, we will compute the geometric antiproduct of this intermediate result
1986            // with B̰
1987            let reverse_transform_1 = geometric_antiproduct_f;
1988            let reverse_transform_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1989                // Compute second half of B ⟇ A ⟇ B̰
1990                // In part 1, we computed the intermediate result B ⟇ A which maps to i here.
1991                // j maps to B.
1992
1993                let (coef_j, j) = anti_reverse_f(coef_j, j);
1994                geometric_antiproduct_f(coef_i, i, coef_j, j)
1995            };
1996            let reverse_transform_code = gen_binary_operator(
1997                basis_element_count,
1998                &objects,
1999                op_trait,
2000                op_fn,
2001                &obj,
2002                |a, b| {
2003                    generate_symbolic_double_product(
2004                        a,
2005                        b,
2006                        reverse_transform_1,
2007                        reverse_transform_2,
2008                    )
2009                },
2010                false, // implicit_promotion_to_compound
2011                None, // alias
2012            );
2013
2014            // Implement A.projection(B) which computes B ∨ (A ∧  B☆)
2015            let op_trait = quote! { Projection };
2016            let op_fn = Ident::new("projection", Span::call_site());
2017            let projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
                // Compute first half of B ∨ (A ∧ B☆), i.e. the inner term A ∧ B☆
                // Where i maps to B, and j maps to A.
                // In part 2, we will compute the anti-wedge product of B with this intermediate result
2021                weight_expansion_f(coef_j, j, coef_i, i)
2022            };
2023            let projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2024                // Compute second half of B ∨ (A ∧ B☆)
2025                // In part 1, we computed the intermediate result A ∧ B☆ which maps to i here.
2026                // j maps to B.
2027                anti_wedge_product_f(coef_j, j, coef_i, i)
2028            };
2029            let projection_code = gen_binary_operator(
2030                basis_element_count,
2031                &objects,
2032                op_trait,
2033                op_fn,
2034                &obj,
2035                |a, b| {
2036                    generate_symbolic_double_product(
2037                        a,
2038                        b,
2039                        projection_product_1,
2040                        projection_product_2,
2041                    )
2042                },
2043                false, // implicit_promotion_to_compound
2044                None, // alias
2045            );
2046
2047            // Implement A.anti_projection(B) which computes B ∧ (A ∨ B☆)
2048            let op_trait = quote! { AntiProjection };
2049            let op_fn = Ident::new("anti_projection", Span::call_site());
2050            let anti_projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
                // Compute first half of B ∧ (A ∨ B☆), i.e. the inner term A ∨ B☆
                // Where i maps to B, and j maps to A.
                // In part 2, we will compute the wedge product of B with this intermediate result
2054                weight_contraction_f(coef_j, j, coef_i, i)
2055            };
2056            let anti_projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2057                // Compute second half of B ∧ (A ∨ B☆)
2058                // In part 1, we computed the intermediate result A ∨ B☆ which maps to i here.
2059                // j maps to B.
2060                wedge_product_f(coef_j, j, coef_i, i)
2061            };
2062            let anti_projection_code = gen_binary_operator(
2063                basis_element_count,
2064                &objects,
2065                op_trait,
2066                op_fn,
2067                &obj,
2068                |a, b| {
2069                    generate_symbolic_double_product(
2070                        a,
2071                        b,
2072                        anti_projection_product_1,
2073                        anti_projection_product_2,
2074                    )
2075                },
2076                false, // implicit_promotion_to_compound
2077                None, // alias
2078            );
2079
2080            // Implement A.central_projection(B) which computes B ∨ (A ∧ B★)
2081            let op_trait = quote! { CentralProjection };
2082            let op_fn = Ident::new("central_projection", Span::call_site());
2083            let central_projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
                // Compute first half of B ∨ (A ∧ B★), i.e. the inner term A ∧ B★
                // Where i maps to B, and j maps to A.
                // In part 2, we will compute the anti-wedge product of B with this intermediate result
2087                bulk_expansion_f(coef_j, j, coef_i, i)
2088            };
2089            let central_projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2090                // Compute second half of B ∨ (A ∧ B★)
2091                // In part 1, we computed the intermediate result A ∧ B★ which maps to i here.
2092                // j maps to B.
2093                anti_wedge_product_f(coef_j, j, coef_i, i)
2094            };
2095            let central_projection_code = gen_binary_operator(
2096                basis_element_count,
2097                &objects,
2098                op_trait,
2099                op_fn,
2100                &obj,
2101                |a, b| {
2102                    generate_symbolic_double_product(
2103                        a,
2104                        b,
2105                        central_projection_product_1,
2106                        central_projection_product_2,
2107                    )
2108                },
2109                false, // implicit_promotion_to_compound
2110                None, // alias
2111            );
2112
2113            // Implement A.central_anti_projection(B) which computes B ∧ (A ∨ B★)
2114            let op_trait = quote! { CentralAntiProjection };
2115            let op_fn = Ident::new("central_anti_projection", Span::call_site());
2116            let central_anti_projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
                // Compute first half of B ∧ (A ∨ B★), i.e. the inner term A ∨ B★
                // Where i maps to B, and j maps to A.
                // In part 2, we will compute the wedge product of B with this intermediate result
2120                bulk_contraction_f(coef_j, j, coef_i, i)
2121            };
2122            let central_anti_projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2123                // Compute second half of B ∧ (A ∨ B★)
2124                // In part 1, we computed the intermediate result A ∨ B★ which maps to i here.
2125                // j maps to B.
2126                wedge_product_f(coef_j, j, coef_i, i)
2127            };
2128            let central_anti_projection_code = gen_binary_operator(
2129                basis_element_count,
2130                &objects,
2131                op_trait,
2132                op_fn,
2133                &obj,
2134                |a, b| {
2135                    generate_symbolic_double_product(
2136                        a,
2137                        b,
2138                        central_anti_projection_product_1,
2139                        central_anti_projection_product_2,
2140                    )
2141                },
2142                false, // implicit_promotion_to_compound
2143                None, // alias
2144            );
2145
2146            // Implement motor_to which computes A̰ ⟇ B
2147            let op_trait = quote! { MotorTo };
2148            let op_fn = Ident::new("motor_to", Span::call_site());
2149            let motor_to_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2150                let (coef_i, i) = anti_reverse_f(coef_i, i);
2151                geometric_antiproduct_f(coef_i, i, coef_j, j)
2152            };
2153
2154            let motor_to_code = gen_binary_operator(
2155                basis_element_count,
2156                &objects,
2157                op_trait,
2158                op_fn,
2159                &obj,
2160                |a, b| generate_symbolic_product(a, b, motor_to_f),
2161                true, // implicit_promotion_to_compound
2162                None, // alias
2163            );
2164
2165            quote! {
2166                // ===========================================================================
2167                // #name
2168                // ===========================================================================
2169
2170                #from_code
2171                #anti_abs_code
2172                #anti_recip_code
2173                #anti_sqrt_code
2174                #anti_trig_code
2175                #neg_code
2176                #reverse_code
2177                #anti_reverse_code
2178                #bulk_code
2179                #weight_code
2180                #bulk_dual_code
2181                #weight_dual_code
2182                #right_complement_code
2183                #left_complement_code
2184                #bulk_norm_squared_code
2185                #bulk_norm_code
2186                #weight_norm_squared_code
2187                #weight_norm_code
2188                #hat_code
2189                #add_code
2190                #sub_code
2191                #wedge_product_code
2192                #anti_wedge_product_code
2193                #dot_product_code
2194                #anti_dot_product_code
2195                #wedge_dot_product_code
2196                #anti_wedge_dot_product_code
2197                #scalar_product_code
2198                #anti_scalar_product_code
2199                #commutator_product_code
2200                #bulk_expansion_code
2201                #weight_expansion_code
2202                #bulk_contraction_code
2203                #weight_contraction_code
2204                #projection_code
2205                #anti_projection_code
2206                #central_projection_code
2207                #central_anti_projection_code
2208                #transform_code
2209                #reverse_transform_code
2210                #motor_to_code
2211            }
2212        })
2213        .collect();
2214
2215    Ok(impl_code)
2216}
2217
2218/// A wrapper around Vec<syn::Item> so we can implement Parse on it
2219struct VecItem(Vec<Item>);
2220
2221impl Parse for VecItem {
2222    fn parse(input: ParseStream) -> Result<Self> {
2223        let mut items = Vec::<Item>::new();
2224        while !input.is_empty() {
2225            items.push(input.parse()?);
2226        }
2227        Ok(VecItem(items))
2228    }
2229}
2230
2231struct BasisVectorIdents(Vec<Ident>);
2232
2233impl Parse for BasisVectorIdents {
2234    fn parse(input: ParseStream) -> Result<Self> {
2235        let ident_list = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;
2236        Ok(BasisVectorIdents(ident_list.into_iter().collect()))
2237    }
2238}
2239
2240struct Metric(Vec<isize>);
2241
2242impl Parse for Metric {
2243    fn parse(input: ParseStream) -> Result<Self> {
2244        let lit_list = Punctuated::<LitInt, Token![,]>::parse_terminated(input)?;
2245        Ok(Metric(
2246            lit_list
2247                .into_iter()
2248                .map(|lit| lit.base10_parse::<isize>())
2249                .collect::<Result<Vec<_>>>()?,
2250        ))
2251    }
2252}
2253
2254fn geometric_algebra2(code: TokenStream) -> Result<TokenStream> {
2255    let VecItem(mut items) = parse2(code)?;
2256
2257    let mut metric = None;
2258    let mut basis = None;
2259    let mut multivector_structs = Vec::<MultivectorStruct>::new();
2260
2261    // Idents to match when parsing
2262    let multivector_ident = Ident::new("multivector", Span::call_site());
2263    let metric_ident = Ident::new("metric", Span::call_site());
2264    let basis_ident = Ident::new("basis", Span::call_site());
2265
2266    let mut err: Result<()> = Ok(());
2267
2268    let mut append_err = |new_e: Error| {
2269        err = match &mut err {
2270            Ok(()) => Err(new_e),
2271            Err(old_e) => {
2272                old_e.combine(new_e);
2273                Err(old_e.clone())
2274            }
2275        };
2276    };
2277
2278    items.retain_mut(|item| {
2279        match item {
2280            Item::Macro(ItemMacro {
2281                mac: item_macro, ..
2282            }) => {
2283                let macro_ident = item_macro.path.get_ident();
2284                if macro_ident == Some(&basis_ident) {
2285                    if basis.is_some() {
2286                        append_err(Error::new(
2287                            macro_ident.unwrap().span(),
2288                            "Duplicate basis definition",
2289                        ));
2290                    } else {
2291                        // Parse the basis as a bracket-enclosed comma-separated list of identifiers
2292                        let parsed_basis: Result<BasisVectorIdents> = item_macro.parse_body();
2293                        match parsed_basis {
2294                            Ok(BasisVectorIdents(parsed_basis)) => {
2295                                basis = Some(parsed_basis);
2296                            }
2297                            Err(e) => {
2298                                append_err(e);
2299                            }
2300                        }
2301                    }
2302                    false // Do not retain the basis! macro
2303                } else if macro_ident == Some(&metric_ident) {
2304                    if metric.is_some() {
2305                        append_err(Error::new(
2306                            macro_ident.unwrap().span(),
2307                            "Duplicate metric definition",
2308                        ));
2309                    } else {
2310                        // Parse the metric as a bracket-enclosed comma-separated list of identifiers
2311                        let parsed_metric: Result<Metric> = item_macro.parse_body();
2312                        match parsed_metric {
2313                            Ok(Metric(parsed_metric)) => {
2314                                metric = Some(parsed_metric);
2315                            }
2316                            Err(e) => {
2317                                append_err(e);
2318                            }
2319                        }
2320                    }
2321                    false // Do not retain the metric! macro
2322                } else {
2323                    true // Retain unrecognized macros
2324                }
2325            }
2326            Item::Struct(item_struct) => {
2327                let mut has_multivector_attribute = false;
2328
2329                item_struct.attrs.retain(|attr| {
2330                    if let Attribute {
2331                        style: AttrStyle::Outer,
2332                        meta: Meta::Path(Path { segments, .. }),
2333                        ..
2334                    } = attr
2335                    {
2336                        // Do not retain the #[multivector] attribute
2337                        // but make note that we found it
2338                        if segments.len() == 1
2339                            && segments.first().unwrap().ident == multivector_ident
2340                        {
2341                            has_multivector_attribute = true;
2342                            false
2343                        } else {
2344                            true
2345                        }
2346                    } else {
2347                        true
2348                    }
2349                });
2350
2351                if has_multivector_attribute {
2352                    // This struct is a multivector!
2353                    // Parse its fields and add it to the struct_info list
2354                    if let Fields::Named(FieldsNamed { named: fields, .. }) = &item_struct.fields {
2355                        // TODO check types
2356                        match fields
2357                            .iter()
2358                            .map(|field| Ok(field.ident.as_ref().unwrap().clone()))
2359                            .collect::<Result<Vec<_>>>()
2360                        {
2361                            Ok(components) => {
2362                                multivector_structs.push(MultivectorStruct {
2363                                    ident: item_struct.ident.clone(),
2364                                    components,
2365                                });
2366                            }
2367                            Err(e) => {
2368                                append_err(e);
2369                            }
2370                        }
2371                    } else {
2372                        append_err(Error::new(
2373                            item_struct.ident.span(),
2374                            "Multivector must have named fields",
2375                        ));
2376                    }
2377                }
2378
2379                true // Retain structs
2380            }
2381            _ => {
2382                true // Retain unrecognized items
2383            }
2384        }
2385    });
2386    err?;
2387
2388    let Some(basis) = basis else {
2389        return Err(Error::new(
2390            Span::call_site(),
2391            "Missing basis![..] definition",
2392        ));
2393    };
2394
2395    let Some(metric) = metric else {
2396        return Err(Error::new(
2397            Span::call_site(),
2398            "Missing metric![..] definition",
2399        ));
2400    };
2401
2402    // Generate code
2403    let generated_code = implement_geometric_algebra(basis, metric, multivector_structs)?;
2404
2405    // Add generated code to original code
2406    let mut code = TokenStream::new();
2407    code.append_all(items);
2408    generated_code.to_tokens(&mut code);
2409    Ok(code)
2410}
2411
2412#[proc_macro]
2413pub fn geometric_algebra(code: proc_macro::TokenStream) -> proc_macro::TokenStream {
2414    // Parse input
2415    geometric_algebra2(code.into())
2416        .unwrap_or_else(Error::into_compile_error)
2417        .into()
2418}
2419
2420#[cfg(test)]
2421mod tests {
2422    use super::*;
2423
2424    #[test]
2425    fn derive_geometric_algebra() {
2426        let _result = geometric_algebra2(quote! {
2427            basis![w, x, y];
2428            metric![0, 1, 1];
2429            #[multivector]
2430            struct Vector<T> {
2431                x: T,
2432                y: T,
2433                w: T,
2434            }
2435            #[multivector]
2436            struct Bivector<T> {
2437                wx: T,
2438                wy: T,
2439                xy: T,
2440            }
2441            #[multivector]
2442            struct AntiScalar<T> {
2443                wxy: T,
2444            }
2445            #[multivector]
2446            struct AntiEven<T> {
2447                a: T,
2448                wx: T,
2449                wy: T,
2450                xy: T,
2451            }
2452            #[multivector]
2453            struct AntiOdd<T> {
2454                x: T,
2455                y: T,
2456                w: T,
2457                wxy: T,
2458            }
2459            #[multivector]
2460            struct Multivector<T> {
2461                a: T,
2462                x: T,
2463                y: T,
2464                w: T,
2465                wx: T,
2466                wy: T,
2467                xy: T,
2468                wxy: T,
2469            }
2470        })
2471        .unwrap();
2472        //println!("{}", result);
2473        //panic!();
2474    }
2475}