1extern crate proc_macro;
2use core::cmp::Ordering;
3use core::ops::{AddAssign, Mul};
4use proc_macro2::Span;
5use proc_macro2::TokenStream;
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::parse::{Error, Parse, ParseStream, Result};
8use syn::punctuated::Punctuated;
9use syn::{
10 parse2, AttrStyle, Attribute, Fields, FieldsNamed, Ident, Item, ItemMacro, LitInt, Meta, Path,
11 Token,
12};
13
/// A multivector struct description: the struct's name together with the
/// identifiers of its component fields.
struct MultivectorStruct {
    /// Name of the struct; used as the generated type's name.
    ident: Ident,
    /// Component field identifiers. Each one is later checked against the
    /// basis-vector naming scheme (common prefix followed by basis characters).
    components: Vec<Ident>,
}
18
/// Sorts `l` in place into ascending order using bubble sort and returns how
/// many adjacent swaps were performed.
///
/// The swap count equals the number of inversions in the input, so its parity
/// is the parity of the permutation that sorts `l`.
fn bubble_sort_count_swaps(l: &mut [usize]) -> usize {
    let mut swap_count = 0usize;
    // Length of the prefix that may still contain out-of-order neighbours.
    let mut unsorted_len = l.len();
    while unsorted_len > 0 {
        unsorted_len -= 1;
        for lower in 0..unsorted_len {
            let upper = lower + 1;
            if l[lower] > l[upper] {
                l.swap(lower, upper);
                swap_count += 1;
            }
        }
    }
    swap_count
}
31
/// Maps a swap count to a permutation sign: `1` for an even count, `-1` for an
/// odd count.
fn sign_from_parity(swaps: usize) -> isize {
    // `swaps % 2` on an unsigned integer is always 0 or 1, so a two-way branch
    // is exhaustive and no panic arm is needed.
    if swaps % 2 == 0 {
        1
    } else {
        -1
    }
}
39
/// A symbolic sum of product terms.
///
/// An empty term list is rendered as `T::default()` by this type's `ToTokens`
/// implementation.
#[derive(Default, Clone)]
struct SymbolicSumExpr(Vec<SymbolicProdExpr>);
42
/// A symbolic product term: an integer coefficient (field 0) multiplied by a
/// list of symbols (field 1).
#[derive(PartialEq, Eq, Clone)]
struct SymbolicProdExpr(isize, Vec<Symbol>);
45
/// A single symbolic factor appearing in a product term.
///
/// The derived ordering places every `Scalar` before every `StructField`.
#[derive(PartialOrd, Ord, PartialEq, Eq, Clone)]
enum Symbol {
    /// A bare variable, rendered as `var`.
    Scalar(Ident),
    /// A field of a variable, rendered as `var.field`.
    StructField(Ident, Ident),
}
51
52impl ToTokens for Symbol {
53 fn to_tokens(&self, tokens: &mut TokenStream) {
54 match self {
55 Symbol::StructField(var, field) => {
56 var.to_tokens(tokens);
57 <Token![.]>::default().to_tokens(tokens);
58 field.to_tokens(tokens);
59 }
60 Symbol::Scalar(var) => {
61 var.to_tokens(tokens);
62 }
63 }
64 }
65}
66
67impl PartialOrd for SymbolicProdExpr {
68 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
69 Some(self.cmp(other))
70 }
71}
72
73impl Ord for SymbolicProdExpr {
74 fn cmp(&self, SymbolicProdExpr(other_coef, other_symbols): &Self) -> Ordering {
75 let SymbolicProdExpr(self_coef, self_symbols) = self;
76 self_symbols
77 .cmp(&other_symbols)
78 .then_with(|| self_coef.cmp(&other_coef))
79 }
80}
81
82impl Mul<SymbolicProdExpr> for SymbolicProdExpr {
83 type Output = SymbolicProdExpr;
84 fn mul(mut self, SymbolicProdExpr(r_coef, mut r_symbols): Self) -> SymbolicProdExpr {
85 let SymbolicProdExpr(l_coef, l_symbols) = &mut self;
86 *l_coef *= r_coef;
87 l_symbols.append(&mut r_symbols);
88 self
89 }
90}
91
92impl Mul<isize> for SymbolicProdExpr {
93 type Output = SymbolicProdExpr;
94 fn mul(mut self, r: isize) -> SymbolicProdExpr {
95 let SymbolicProdExpr(l_coef, _) = &mut self;
96 *l_coef *= r;
97 self
98 }
99}
100
101impl SymbolicProdExpr {
102 fn simplify(mut self) -> Self {
103 let SymbolicProdExpr(coef, symbols) = &mut self;
104 if *coef == 0 {
106 symbols.clear();
107 } else {
108 symbols.sort();
109 }
110 self
111 }
112}
113
114impl ToTokens for SymbolicSumExpr {
115 fn to_tokens(&self, tokens: &mut TokenStream) {
116 let SymbolicSumExpr(terms) = self;
117 if terms.len() == 0 {
118 tokens.append_all(quote! { T::default() });
119 } else {
120 for (count, prod_expr) in terms.iter().enumerate() {
121 let SymbolicProdExpr(coef, prod_terms) = prod_expr;
122 let coef = *coef;
123
124 if coef >= 0 {
125 if count != 0 {
126 tokens.append_all(quote! { + });
127 }
128 } else {
129 tokens.append_all(quote! { - });
130 }
131 let coef = coef.abs();
132
133 if prod_terms.len() == 0 {
134 if coef == 0 {
136 tokens.append_all(quote! { T::default() });
137 } else if coef == 1 {
138 tokens.append_all(quote! {
139 T::one()
140 });
141 } else {
142 panic!("Scalar was not 0, -1 or 1");
143 }
144 } else {
145 if coef == 0 {
147 tokens.append_all(quote! { T::default() * });
148 } else if coef == 1 {
149 } else if coef == 2 {
151 tokens.append_all(quote! { (T::one() + T::one()) * });
152 } else {
153 panic!("No representation for large coefficient {}", coef);
154 }
155 for (sym_count, sym) in prod_terms.iter().enumerate() {
156 if sym_count > 0 {
157 tokens.append_all(quote! { * });
158 }
159 sym.to_tokens(tokens);
160 }
161 }
162 }
163 }
164 }
165}
166
167impl SymbolicSumExpr {
168 fn simplify(self) -> Self {
169 let SymbolicSumExpr(terms) = self;
170
171 let mut terms: Vec<_> = terms.into_iter().map(|prod| prod.simplify()).collect();
173
174 terms.sort();
176
177 let mut new_expression = vec![];
179 let mut prev_coef = 0;
180 let mut prev_symbols = vec![];
181 for SymbolicProdExpr(coef, symbols) in terms.into_iter() {
182 if prev_symbols == symbols {
183 prev_coef += coef;
184 } else {
185 new_expression.push(SymbolicProdExpr(prev_coef, prev_symbols));
186 prev_coef = coef;
187 prev_symbols = symbols;
188 }
189 }
190 new_expression.push(SymbolicProdExpr(prev_coef, prev_symbols));
191
192 let mut terms = new_expression;
193
194 terms.retain(|SymbolicProdExpr(coef, _)| *coef != 0);
196
197 SymbolicSumExpr(terms)
198 }
199}
200
201impl AddAssign<SymbolicProdExpr> for SymbolicSumExpr {
202 fn add_assign(&mut self, r_term: SymbolicProdExpr) {
203 let SymbolicSumExpr(l_terms) = self;
204 l_terms.push(r_term);
205 }
206}
207
208impl AddAssign<SymbolicSumExpr> for SymbolicSumExpr {
209 fn add_assign(&mut self, SymbolicSumExpr(mut r_terms): SymbolicSumExpr) {
210 let SymbolicSumExpr(l_terms) = self;
211 l_terms.append(&mut r_terms);
212 }
213}
214
215impl Mul<SymbolicProdExpr> for SymbolicSumExpr {
216 type Output = SymbolicSumExpr;
217 fn mul(self, r: SymbolicProdExpr) -> SymbolicSumExpr {
218 let SymbolicSumExpr(l) = self;
219 SymbolicSumExpr(l.into_iter().map(|lp| lp * r.clone()).collect())
220 }
221}
222
223impl Mul<isize> for SymbolicSumExpr {
224 type Output = SymbolicSumExpr;
225 fn mul(self, r: isize) -> SymbolicSumExpr {
226 let SymbolicSumExpr(l) = self;
227 SymbolicSumExpr(l.into_iter().map(|lp| lp * r).collect())
228 }
229}
230
/// Maps basis element `i` to the mirrored index `len - i - 1` and scales
/// `coef` by the sign stored for the *source* index `i`.
///
/// Takes a slice rather than `&Vec`, so both vectors and slices are accepted.
///
/// # Panics
///
/// Panics if `i` is not a valid index into `right_complement_signs`.
fn right_complement(right_complement_signs: &[isize], coef: isize, i: usize) -> (isize, usize) {
    let complement_ix = right_complement_signs.len() - i - 1;
    (coef * right_complement_signs[i], complement_ix)
}
237
/// Maps basis element `i` to the mirrored index `len - i - 1` and scales
/// `coef` by the sign stored for the *mirrored* index.
///
/// The sign table is the same one used by `right_complement`; only the index
/// used for the sign lookup differs. Takes a slice rather than `&Vec`, so both
/// vectors and slices are accepted.
///
/// # Panics
///
/// Panics if `i` is not a valid index into `right_complement_signs`.
fn left_complement(right_complement_signs: &[isize], coef: isize, i: usize) -> (isize, usize) {
    let complement_ix = right_complement_signs.len() - i - 1;
    (coef * right_complement_signs[complement_ix], complement_ix)
}
245
/// A value type the generated code operates on: either the bare scalar type
/// `T` or one of the multivector structs.
#[derive(PartialEq)]
enum Object {
    /// The scalar type `T`; it only has the component at basis index 0.
    Scalar,
    /// A multivector struct with a selected subset of basis components.
    Struct(StructObject),
}
256
/// Description of a multivector struct in terms of the basis elements.
#[derive(PartialEq)]
struct StructObject {
    /// Name of the struct type.
    name: Ident,
    /// One slot per basis element: `Some((field, sign))` when the struct stores
    /// that element in `field` (with `sign` accounting for the order in which
    /// the field name lists its basis vectors), `None` when it does not.
    select_components: Vec<Option<(Ident, isize)>>,
    /// True when the stored components do not all have the same grade.
    is_compound: bool,
}
263
264impl Object {
265 fn type_name(&self) -> TokenStream {
266 match self {
267 Object::Scalar => quote! { T },
268 Object::Struct(StructObject { name, .. }) => {
269 quote! { #name < T > }
270 }
271 }
272 }
273 fn type_name_colons(&self) -> TokenStream {
274 match self {
275 Object::Scalar => quote! { T },
276 Object::Struct(StructObject { name, .. }) => {
277 quote! { #name :: < T > }
278 }
279 }
280 }
281 fn has_component(&self, i: usize) -> bool {
282 match self {
283 Object::Scalar => i == 0,
284 Object::Struct(StructObject {
285 select_components, ..
286 }) => select_components[i].is_some(),
287 }
288 }
289 fn is_compound(&self) -> bool {
290 match self {
291 Object::Scalar => false,
292 Object::Struct(StructObject { is_compound, .. }) => *is_compound,
293 }
294 }
295 fn select_components(&self, var: Ident, len: usize) -> Vec<Option<(Symbol, isize)>> {
296 match self {
297 Object::Scalar => {
298 let mut result = vec![None; len];
299 result[0] = Some((Symbol::Scalar(var), 1));
300 result
301 }
302 Object::Struct(StructObject {
303 select_components, ..
304 }) => select_components
305 .iter()
306 .map(|select_component| {
307 select_component.as_ref().map(|(field, coef)| {
308 (Symbol::StructField(var.clone(), field.clone()), *coef)
309 })
310 })
311 .collect(),
312 }
313 }
314}
315
316fn generate_symbolic_rearrangement<F: Fn(isize, usize) -> (isize, usize)>(
322 select_components: &[Option<(Symbol, isize)>],
323 op: F,
324) -> Vec<SymbolicSumExpr> {
325 let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components.len()];
327
328 for (i, is_selected) in select_components.iter().enumerate() {
329 if let Some((symbol, coef)) = is_selected {
330 let (coef, result_basis_ix) = op(*coef, i);
331 result[result_basis_ix] += SymbolicProdExpr(coef, vec![symbol.clone()]);
332 }
333 }
334
335 result.into_iter().map(|expr| expr.simplify()).collect()
336}
337
338fn generate_symbolic_norm<F: Fn(isize, usize, isize, usize) -> (isize, usize)>(
339 select_components: &[Option<(Symbol, isize)>],
340 product: F,
341 sqrt: bool,
342) -> Vec<SymbolicSumExpr> {
343 let mut expressions: Vec<SymbolicSumExpr> = vec![Default::default(); select_components.len()];
345 for (i, (i_symbol, i_coef)) in select_components
346 .iter()
347 .enumerate()
348 .filter_map(|(i, selected)| selected.as_ref().map(|selected| (i, selected)))
349 {
350 for (j, (j_symbol, j_coef)) in select_components
351 .iter()
352 .enumerate()
353 .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
354 {
355 let (coef, ix) = product(*i_coef, i, *j_coef, j);
356
357 expressions[ix] += SymbolicProdExpr(coef, vec![i_symbol.clone(), j_symbol.clone()]);
358 }
359 }
360
361 let expressions: Vec<_> = expressions
362 .into_iter()
363 .map(|expr| expr.simplify())
364 .collect();
365
366 if sqrt {
367 {
370 let is_scalar = expressions
371 .iter()
372 .enumerate()
373 .all(|(i, expr)| i == 0 || expr.0.len() == 0);
374 let is_anti_scalar = expressions
375 .iter()
376 .enumerate()
377 .all(|(i, expr)| i == expressions.len() - 1 || expr.0.len() == 0);
378 let expression = if is_scalar {
379 Some(expressions[0].clone())
380 } else if is_anti_scalar {
381 Some(expressions[expressions.len() - 1].clone())
382 } else {
383 None
384 };
385 if let Some(expression) = expression {
386 let SymbolicSumExpr(terms) = &expression;
387 if terms.len() == 1 {
388 let SymbolicProdExpr(coef, terms) = &terms[0];
389 if *coef == 1 && terms.len() == 2 && terms[0] == terms[1] {
390 let sqrt_expression =
391 SymbolicSumExpr(vec![SymbolicProdExpr(1, vec![terms[0].clone()])]);
392 let target_ix = if is_scalar {
393 0
394 } else if is_anti_scalar {
395 expressions.len() - 1
396 } else {
397 panic!("Took sqrt of something that wasn't a scalar or antiscalar");
398 };
399 Some(
400 (0..select_components.len())
401 .map(|i| {
402 if i == target_ix {
403 sqrt_expression.clone()
404 } else {
405 Default::default()
406 }
407 })
408 .collect(),
409 )
410 } else {
411 None }
413 } else {
414 None }
416 } else {
417 None }
419 }
420 .unwrap_or(vec![Default::default(); select_components.len()])
421 } else {
422 expressions
424 }
425}
426
427fn generate_symbolic_sum(
431 select_components_a: &[Option<(Symbol, isize)>],
432 select_components_b: &[Option<(Symbol, isize)>],
433 coef_a: isize,
434 coef_b: isize,
435) -> Vec<SymbolicSumExpr> {
436 let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components_a.len()];
438
439 for (i, (a_selected, b_selected)) in select_components_a
440 .iter()
441 .zip(select_components_b.iter())
442 .enumerate()
443 {
444 if let Some((symbol_a, coef_symbol_a)) = a_selected {
445 result[i] += SymbolicProdExpr(coef_symbol_a * coef_a, vec![symbol_a.clone()]);
446 }
447 if let Some((symbol_b, coef_symbol_b)) = b_selected {
448 result[i] += SymbolicProdExpr(coef_symbol_b * coef_b, vec![symbol_b.clone()]);
449 }
450 }
451
452 result.into_iter().map(|expr| expr.simplify()).collect()
453}
454
455fn generate_symbolic_product<F: Fn(isize, usize, isize, usize) -> (isize, usize)>(
459 select_components_a: &[Option<(Symbol, isize)>],
460 select_components_b: &[Option<(Symbol, isize)>],
461 product: F,
462) -> Vec<SymbolicSumExpr> {
463 let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components_a.len()];
465 for (i, (i_symbol, i_coef)) in select_components_a
466 .iter()
467 .enumerate()
468 .filter_map(|(i, selected)| selected.as_ref().map(|selected| (i, selected)))
469 {
470 for (j, (j_symbol, j_coef)) in select_components_b
471 .iter()
472 .enumerate()
473 .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
474 {
475 let (coef, ix) = product(*i_coef, i, *j_coef, j);
476
477 result[ix] += SymbolicProdExpr(coef, vec![i_symbol.clone(), j_symbol.clone()]);
478 }
479 }
480
481 result.into_iter().map(|expr| expr.simplify()).collect()
482}
483
484fn generate_symbolic_double_product<
491 F1: Fn(isize, usize, isize, usize) -> (isize, usize),
492 F2: Fn(isize, usize, isize, usize) -> (isize, usize),
493>(
494 select_components_a: &[Option<(Symbol, isize)>],
495 select_components_b: &[Option<(Symbol, isize)>],
496 product_1: F1,
497 product_2: F2,
498) -> Vec<SymbolicSumExpr> {
499 let mut intermediate_result: Vec<SymbolicSumExpr> =
502 vec![Default::default(); select_components_a.len()];
503 for (i, (i_symbol, i_coef)) in select_components_b
504 .iter()
505 .enumerate()
506 .filter_map(|(i, selected)| selected.as_ref().map(|selected| (i, selected)))
507 {
508 for (j, (j_symbol, j_coef)) in select_components_a
509 .iter()
510 .enumerate()
511 .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
512 {
513 let (coef, ix) = product_1(*i_coef, i, *j_coef, j);
514 intermediate_result[ix] +=
515 SymbolicProdExpr(coef, vec![i_symbol.clone(), j_symbol.clone()]);
516 }
517 }
518 let intermediate_result: Vec<_> = intermediate_result
519 .into_iter()
520 .map(|expr| expr.simplify())
521 .collect();
522
523 let mut result: Vec<SymbolicSumExpr> = vec![Default::default(); select_components_a.len()];
527 for (i, intermediate_term) in intermediate_result.iter().enumerate() {
528 for (j, (j_symbol, j_coef)) in select_components_b
529 .iter()
530 .enumerate()
531 .filter_map(|(j, selected)| selected.as_ref().map(|selected| (j, selected)))
532 {
533 let (coef, ix) = product_2(1, i, *j_coef, j);
534 let new_term = SymbolicProdExpr(coef, vec![j_symbol.clone()]);
535 let result_term = intermediate_term.clone() * new_term;
536 result[ix] += result_term;
537 }
538 }
539
540 result.into_iter().map(|expr| expr.simplify()).collect()
541}
542
543fn find_output_object<'a>(
544 objects: &'a [Object],
545 output_expressions: &[SymbolicSumExpr],
546) -> Option<&'a Object> {
547 let select_output_components: Vec<_> = output_expressions
548 .iter()
549 .map(|SymbolicSumExpr(e)| e.len() != 0)
550 .collect();
551 objects.iter().find(|o| {
552 select_output_components
553 .iter()
554 .enumerate()
555 .find(|&(i, &out_c)| out_c && !o.has_component(i))
556 .is_none()
557 })
558}
559
560fn gen_unary_operator(
561 objects: &[Object],
562 op_trait: TokenStream,
563 op_fn: Ident,
564 obj: &Object,
565 expressions: &[SymbolicSumExpr],
566 alias: Option<(TokenStream, Ident)>,
567) -> TokenStream {
568 if matches!(obj, Object::Scalar) {
569 return quote! {};
573 };
574
575 let output_object = find_output_object(&objects, &expressions);
577
578 let Some(output_object) = output_object else {
579 return quote! {};
582 };
583
584 if matches!(output_object, Object::Scalar) && expressions[0].0.len() == 0 {
585 return quote! {};
589 }
590
591 let output_type_name = &output_object.type_name_colons();
592 let type_name = &obj.type_name();
593
594 let return_expr = match output_object {
595 Object::Scalar => {
596 let expr = &expressions[0];
597 quote! { #expr }
598 }
599 Object::Struct(output_struct_object) => {
600 let output_fields: TokenStream = output_struct_object
601 .select_components
602 .iter()
603 .zip(expressions.iter())
604 .map(|(select_component, expr)| {
605 if let Some((field, coef)) = select_component {
606 let expr = expr.clone() * *coef;
607 quote! { #field: #expr, }
608 } else {
609 return quote! {};
610 }
611 })
612 .collect();
613 quote! {
614 #output_type_name {
615 #output_fields
616 }
617 }
618 }
619 };
620
621 let associated_output_type = quote! { type Output = #output_type_name; };
622
623 let code = quote! {
624 impl < T: Ring > #op_trait for #type_name {
625 #associated_output_type
626
627 fn #op_fn (self) -> #output_type_name {
628 #return_expr
629 }
630 }
631 };
632
633 let alias_code = if let Some((alias_trait, alias_fn)) = alias {
634 quote! {
635 impl < T: Ring > #alias_trait for #type_name {
636 #associated_output_type
637
638 fn #alias_fn (self) -> #output_type_name {
639 self.#op_fn()
640 }
641 }
642 }
643 } else {
644 quote! {}
645 };
646
647 quote! {
648 #code
649 #alias_code
650 }
651}
652
653fn gen_binary_operator<
654 F: Fn(&[Option<(Symbol, isize)>], &[Option<(Symbol, isize)>]) -> Vec<SymbolicSumExpr>,
655>(
656 basis_element_count: usize,
657 objects: &[Object],
658 op_trait: TokenStream,
659 op_fn: Ident,
660 lhs_obj: &Object,
661 op: F,
662 implicit_promotion_to_compound: bool,
663 alias: Option<(TokenStream, Ident)>,
664) -> TokenStream {
665 objects
666 .iter()
667 .map(|rhs_obj| {
668 if matches!(lhs_obj, Object::Scalar) {
669 return quote! {};
676 };
677
678 let expressions = op(
679 &lhs_obj
680 .select_components(Ident::new("self", Span::call_site()), basis_element_count),
681 &rhs_obj.select_components(Ident::new("r", Span::call_site()), basis_element_count),
682 );
683
684 let output_object = find_output_object(&objects, &expressions);
686
687 let Some(output_object) = output_object else {
688 return quote! {};
691 };
692
693 let rhs_type_name = &rhs_obj.type_name();
694 let lhs_type_name = &lhs_obj.type_name();
695 let output_type_name = &output_object.type_name_colons();
696
697 if !implicit_promotion_to_compound
698 && output_object.is_compound()
699 && !(lhs_obj.is_compound() || rhs_obj.is_compound())
700 {
701 return quote! {};
705 }
706
707 if matches!(output_object, Object::Scalar) && expressions[0].0.len() == 0 {
708 return quote! {};
713 }
714
715 let return_expr = match output_object {
716 Object::Scalar => {
717 let expr = &expressions[0];
718 quote! { #expr }
719 }
720 Object::Struct(output_struct_object) => {
721 let output_fields: TokenStream = output_struct_object
722 .select_components
723 .iter()
724 .zip(expressions.iter())
725 .map(|(select_component, expr)| {
726 if let Some((field, coef)) = select_component {
727 let expr = expr.clone() * *coef;
728 quote! { #field: #expr, }
729 } else {
730 return quote! {};
731 }
732 })
733 .collect();
734 quote! {
735 #output_type_name {
736 #output_fields
737 }
738 }
739 }
740 };
741
742 let associated_output_type = quote! { type Output = #output_type_name; };
743
744 let code = quote! {
745 impl < T: Ring > #op_trait < #rhs_type_name > for #lhs_type_name {
746 #associated_output_type
747
748 fn #op_fn (self, r: #rhs_type_name) -> #output_type_name {
749 #return_expr
750 }
751 }
752 };
753
754 let alias_code = if let Some((alias_trait, alias_fn)) = &alias {
755 quote! {
756 impl < T: Ring > #alias_trait < #rhs_type_name > for #lhs_type_name {
757 #associated_output_type
758
759 fn #alias_fn (self, r: #rhs_type_name) -> #output_type_name {
760 self.#op_fn(r)
761 }
762 }
763 }
764 } else {
765 quote! {}
766 };
767
768 quote! {
769 #code
770 #alias_code
771 }
772 })
773 .collect()
774}
775
776fn gen_antiscalar_operator(
777 basis_element_count: usize,
778 op_trait: Ident,
779 anti_op_trait: Ident,
780 op_fns: &[(Ident, Ident)],
781 obj: &Object,
782) -> TokenStream {
783 let is_antiscalar = (0..basis_element_count).all(|i| {
784 if i == basis_element_count - 1 {
785 obj.has_component(i)
786 } else {
787 !obj.has_component(i)
788 }
789 });
790
791 if !is_antiscalar {
792 return quote! {};
794 }
795
796 let Object::Struct(struct_obj) = obj else {
797 return quote! {};
802 };
803
804 let (field, coef) = struct_obj.select_components[basis_element_count - 1]
805 .as_ref()
806 .unwrap();
807
808 if *coef != 1 {
809 return quote! {};
814 }
815
816 let field_expr = quote! {
817 self . #field
818 };
819
820 let type_name = &obj.type_name();
821 let struct_name = struct_obj.name.clone();
822
823 let functions_code: TokenStream = op_fns
824 .iter()
825 .map(|(fn_ident, anti_fn_ident)| {
826 quote! {
827 fn #anti_fn_ident (self) -> Self::Output {
828 #struct_name {
829 #field: #field_expr . #fn_ident ()
830 }
831 }
832 }
833 })
834 .collect();
835
836 quote! {
837 impl < T: #op_trait > #anti_op_trait for #type_name {
838 type Output = #struct_name < < T as #op_trait >::Output >;
839
840 #functions_code
841 }
842 }
843}
844
845fn implement_geometric_algebra(
846 basis_vector_idents: Vec<Ident>,
847 metric: Vec<isize>,
848 multivector_structs: Vec<MultivectorStruct>,
849) -> Result<TokenStream> {
850 if basis_vector_idents.len() == 0 {
852 return Err(Error::new(Span::call_site(), "Basis vector set is empty"));
853 }
854 if basis_vector_idents.len() != metric.len() {
855 return Err(Error::new(
856 Span::call_site(),
857 "Metric and basis are different sizes",
858 ));
859 }
860 if multivector_structs.len() == 0 {
861 return Err(Error::new(
862 Span::call_site(),
863 "No multivector structs defined",
864 ));
865 }
866
867 let dimension = metric.len();
870
871 let basis = {
872 let mut basis_km1_vectors = vec![vec![]];
883 let mut basis = vec![vec![]];
884
885 for _ in 1..=dimension {
886 let mut basis_k_vectors = vec![];
887 for b1 in basis_km1_vectors {
888 for be2 in 0..dimension {
889 match b1.binary_search(&be2) {
891 Ok(_) => {}
892 Err(pos) => {
893 let mut b = b1.clone();
894 b.insert(pos, be2);
895 if !basis_k_vectors.contains(&b) {
896 basis_k_vectors.push(b.clone());
897 basis.push(b);
898 }
899 }
900 }
901 }
902 }
903 basis_km1_vectors = basis_k_vectors;
904 }
905
906 basis
909 };
910
911 let basis_element_count = basis.len();
912
913 let (ident_prefix, ident_variants) = {
916 let first = basis_vector_idents
917 .first()
918 .map(|x| x.to_string())
919 .unwrap_or_default();
920
921 let len = first.chars().count();
922 assert!(len >= 1, "Identifier should be non-zero length");
923 let prefix = first.chars().take(len - 1).collect::<String>();
924
925 let variants = basis_vector_idents
926 .iter()
927 .map(|basis_ident| {
928 let basis_ident_str = basis_ident.to_string();
929 if basis_ident_str.chars().count() != len || !basis_ident_str.starts_with(&prefix) {
930 return Err(Error::new(
931 basis_ident.span(),
932 "Bad identifier name: must be common prefix + 1 char",
933 ));
934 }
935
936 Ok(basis_ident_str.chars().nth(len - 1).unwrap())
937 })
938 .collect::<Result<Vec<_>>>()?;
939 (prefix, variants)
940 };
941 let ident_prefix_len = ident_prefix.chars().count();
942
943 let mut objects = vec![Object::Scalar];
945 let struct_objects = multivector_structs
946 .iter()
947 .map(|multivector_struct| {
948 let name = multivector_struct.ident.clone();
949 let mut select_components = vec![None; basis_element_count];
950
951 for component_ident in multivector_struct.components.iter() {
953 let component_ident_str = component_ident.to_string();
954
955 if component_ident_str.chars().count() <= ident_prefix_len
956 || !component_ident_str.starts_with(&ident_prefix)
957 {
958 return Err(Error::new(
959 component_ident.span(),
960 "Bad identifier: must be common prefix + 1 or more chars (and there is already a scalar field)",
961 ));
962 }
963 let variant_product = component_ident_str
964 .chars()
965 .skip(ident_prefix_len)
966 .collect::<Vec<_>>();
967 let mut b = variant_product
968 .iter()
969 .map(|c| {
970 ident_variants
971 .iter()
972 .position(|c2| c2 == c)
973 .ok_or(Error::new(
974 component_ident.span(),
975 "Bad identifier: must be composed of basis vectors (and there is already a scalar field)",
976 ))
977 })
978 .collect::<Result<Vec<_>>>();
979
980 if let Err(e) = b {
983 if select_components[0].is_some() {
984 return Err(e);
985 }
986 b = Ok(vec![]);
987 }
988
989 let mut b = b?;
990
991 let sign = sign_from_parity(bubble_sort_count_swaps(&mut b));
992 let basis_index = basis.iter().position(|b2| b2 == &b).ok_or(Error::new(
993 component_ident.span(),
994 "Bad identifier: cannot repeat basis vectors",
995 ))?;
996
997 if select_components[basis_index].is_some() {
998 return Err(Error::new(
999 component_ident.span(),
1000 "Bad identifier: duplicate",
1001 ));
1002 }
1003 select_components[basis_index] = Some((component_ident.clone(), sign));
1004 }
1005
1006 let grades = select_components
1008 .iter()
1009 .enumerate()
1010 .filter_map(|(i, component)| component.as_ref().map(|_| basis[i].len()))
1011 .collect::<Vec<_>>();
1012 let is_compound = grades.windows(2).any(|g| g[0] != g[1]);
1013
1014 Ok(Object::Struct(StructObject {
1015 name,
1016 select_components,
1017 is_compound,
1018 }))
1019 })
1020 .collect::<Result<Vec<_>>>()?;
1021
1022 objects.extend(struct_objects);
1023
1024 let right_complement_signs: Vec<_> = (0..basis_element_count)
1025 .map(|i| {
1026 let dual_i = basis_element_count - i - 1;
1027
1028 let mut product: Vec<usize> = basis[i]
1032 .iter()
1033 .cloned()
1034 .chain(basis[dual_i].iter().cloned())
1035 .collect();
1036 sign_from_parity(bubble_sort_count_swaps(product.as_mut()))
1037 })
1038 .collect();
1039
1040 let dot_product_multiplication_table: Vec<Vec<isize>> = {
1043 let multiply_basis_vectors = |ei: usize, ej: usize| {
1044 if ei == ej {
1047 metric[ei]
1048 } else {
1049 0
1050 }
1051 };
1052
1053 let multiply_basis_elements = |i: usize, j: usize| {
1054 let bi = &basis[i];
1058 let bj = &basis[j];
1059 if bi.len() != bj.len() {
1060 return 0;
1061 }
1062 if bi.len() == 0 {
1063 return 1; }
1065
1066 let gram_matrix: Vec<Vec<isize>> = bi
1067 .iter()
1068 .map(|&ei| {
1069 bj.iter()
1070 .map(|&ej| multiply_basis_vectors(ei, ej))
1071 .collect()
1072 })
1073 .collect();
1074
1075 fn determinant(m: &Vec<Vec<isize>>) -> isize {
1076 if m.len() == 1 {
1077 m[0][0]
1078 } else {
1079 let n = m.len();
1080 (0..n)
1081 .map(move |j| {
1082 let i = 0;
1083 let sign = match (i + j) % 2 {
1084 0 => 1,
1085 1 => -1,
1086 _ => panic!("Expected parity to be 0 or 1"),
1087 };
1088
1089 let minor: Vec<Vec<_>> = (0..n)
1090 .flat_map(|i2| {
1091 if i2 == i {
1092 None
1093 } else {
1094 Some(
1095 (0..n)
1096 .flat_map(|j2| {
1097 if j2 == j {
1098 None
1099 } else {
1100 Some(m[i2][j2])
1101 }
1102 })
1103 .collect(),
1104 )
1105 }
1106 })
1107 .collect();
1108
1109 sign * m[i][j] * determinant(&minor)
1110 })
1111 .sum()
1112 }
1113 }
1114 determinant(&gram_matrix)
1115 };
1116
1117 (0..basis_element_count)
1118 .map(|i| {
1119 (0..basis_element_count)
1120 .map(move |j| multiply_basis_elements(i, j))
1121 .collect()
1122 })
1123 .collect()
1124 };
1125
1126 let dot_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1127 let coef_mul = dot_product_multiplication_table[i][j];
1128 (coef_i * coef_j * coef_mul, 0)
1129 };
1130
1131 let anti_dot_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1132 let (i_coef, i) = right_complement(&right_complement_signs, coef_i, i);
1133 let (j_coef, j) = right_complement(&right_complement_signs, coef_j, j);
1134 let (coef, ix) = dot_product_f(i_coef, i, j_coef, j);
1135 let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1136 (coef, ix)
1137 };
1138
1139 let geometric_product_multiplication_table: Vec<Vec<(isize, usize)>> = {
1143 let multiply_basis_elements = |i: usize, j: usize| {
1144 let mut product: Vec<_> = basis[i]
1145 .iter()
1146 .cloned()
1147 .chain(basis[j].iter().cloned())
1148 .collect();
1149 let swaps = bubble_sort_count_swaps(product.as_mut());
1150 let mut coef = match swaps % 2 {
1151 0 => 1,
1152 1 => -1,
1153 _ => panic!("Expected parity to be 0 or 1"),
1154 };
1155
1156 let mut new_product = vec![];
1158 let mut prev_e = None;
1159 for e in product.into_iter() {
1160 if Some(e) == prev_e {
1161 coef *= metric[e];
1162 prev_e = None;
1163 } else {
1164 if let Some(prev_e) = prev_e {
1165 new_product.push(prev_e);
1166 }
1167 prev_e = Some(e);
1168 }
1169 }
1170 if let Some(prev_e) = prev_e {
1171 new_product.push(prev_e);
1172 }
1173
1174 basis
1176 .iter()
1177 .enumerate()
1178 .find_map(|(i, b)| {
1179 let mut b_sorted = b.clone();
1180 let swaps = bubble_sort_count_swaps(b_sorted.as_mut());
1181 (new_product == b_sorted).then(|| {
1182 let coef = coef
1183 * match swaps % 2 {
1184 0 => 1,
1185 1 => -1,
1186 _ => panic!("Expected parity to be 0 or 1"),
1187 };
1188 (coef, i)
1189 })
1190 })
1191 .expect("Product of basis elements not found in basis set")
1192 };
1193
1194 (0..basis_element_count)
1195 .map(|i| {
1196 (0..basis_element_count)
1197 .map(move |j| multiply_basis_elements(i, j))
1198 .collect()
1199 })
1200 .collect()
1201 };
1202
1203 let geometric_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1204 let (coef_mul, ix) = geometric_product_multiplication_table[i][j];
1205 (coef_i * coef_j * coef_mul, ix)
1206 };
1207
1208 let geometric_antiproduct_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1209 let (coef_i, i) = right_complement(&right_complement_signs, coef_i, i);
1210 let (coef_j, j) = right_complement(&right_complement_signs, coef_j, j);
1211 let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1212 let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1213 (coef, ix)
1214 };
1215
1216 let impl_code: TokenStream = objects
1217 .iter()
1218 .map(|obj| {
1219 let from_code: TokenStream = {
1221 if matches!(obj, Object::Scalar) {
1222 return quote! {};
1226 }
1227
1228 objects.iter().map(|other_obj| {
1229 let is_subset = (0..basis_element_count).all(|i| obj.has_component(i) || !other_obj.has_component(i));
1233 if !is_subset { return quote! {}; }
1234
1235 let is_same_object = (0..basis_element_count).all(|i| obj.has_component(i) == other_obj.has_component(i));
1236 if is_same_object { return quote! {}; }
1237
1238 let is_not_empty = (0..basis_element_count).any(|i| obj.has_component(i) && other_obj.has_component(i));
1239 if !is_not_empty { return quote! {}; }
1240
1241 let my_type_name = obj.type_name();
1242 let my_type_name_colons = obj.type_name_colons();
1243 let other_type_name = other_obj.type_name();
1244
1245 let expressions: Vec<_> = other_obj.select_components(Ident::new("value", Span::call_site()), basis_element_count).iter().map(|select_component| {
1246 match select_component {
1247 Some((symbol, coef)) => SymbolicSumExpr(vec![SymbolicProdExpr(*coef, vec![symbol.clone()])]),
1248 None => Default::default(),
1249 }
1250 }).collect();
1251
1252 let return_expr = match obj {
1253 Object::Scalar => {
1254 let expr = &expressions[0];
1255 quote! { #expr }
1256 }
1257 Object::Struct(output_struct_object) => {
1258 let output_fields: TokenStream = output_struct_object
1259 .select_components
1260 .iter()
1261 .zip(expressions.iter())
1262 .map(|(select_component, expr)| {
1263 if let Some((field, coef)) = select_component {
1264 let expr = expr.clone() * *coef;
1265 quote! { #field: #expr, }
1266 } else {
1267 return quote! {};
1268 }
1269 })
1270 .collect();
1271 quote! {
1272 #my_type_name_colons {
1273 #output_fields
1274 }
1275 }
1276 }
1277 };
1278
1279 quote! {
1280 impl<T: core::default::Default> From<#other_type_name> for #my_type_name {
1281 fn from(value: #other_type_name) -> #my_type_name {
1282 #return_expr
1283 }
1284 }
1285 }
1286 }).collect()
1287 };
1288
1289 let obj_self_components = &obj.select_components(Ident::new("self", Span::call_site()), basis_element_count);
1290
1291 let anti_abs_code = gen_antiscalar_operator(
1293 basis_element_count,
1294 Ident::new("Abs", Span::call_site()),
1295 Ident::new("AntiAbs", Span::call_site()),
1296 &[(
1297 Ident::new("abs", Span::call_site()),
1298 Ident::new("anti_abs", Span::call_site()),
1299 )],
1300 obj,
1301 );
1302
1303 let anti_recip_code = gen_antiscalar_operator(
1305 basis_element_count,
1306 Ident::new("Recip", Span::call_site()),
1307 Ident::new("AntiRecip", Span::call_site()),
1308 &[(
1309 Ident::new("recip", Span::call_site()),
1310 Ident::new("anti_recip", Span::call_site()),
1311 )],
1312 obj,
1313 );
1314
1315 let anti_sqrt_code = gen_antiscalar_operator(
1317 basis_element_count,
1318 Ident::new("Sqrt", Span::call_site()),
1319 Ident::new("AntiSqrt", Span::call_site()),
1320 &[(
1321 Ident::new("sqrt", Span::call_site()),
1322 Ident::new("anti_sqrt", Span::call_site()),
1323 )],
1324 obj,
1325 );
1326
1327 let anti_trig_code = gen_antiscalar_operator(
1329 basis_element_count,
1330 Ident::new("Trig", Span::call_site()),
1331 Ident::new("AntiTrig", Span::call_site()),
1332 &[
1333 (
1334 Ident::new("cos", Span::call_site()),
1335 Ident::new("anti_cos", Span::call_site()),
1336 ),
1337 (
1338 Ident::new("sin", Span::call_site()),
1339 Ident::new("anti_sin", Span::call_site()),
1340 ),
1341 (
1342 Ident::new("sinc", Span::call_site()),
1343 Ident::new("anti_sinc", Span::call_site()),
1344 ),
1345 ],
1346 obj,
1347 );
1348
1349 let op_trait = quote! { core::ops::Neg };
1351 let op_fn = Ident::new("neg", Span::call_site());
1352 let neg_code = gen_unary_operator(
1353 &objects,
1354 op_trait,
1355 op_fn,
1356 &obj,
1357 &generate_symbolic_rearrangement(&obj_self_components, |coef: isize, i: usize| (-coef, i)),
1358 None,
1359 );
1360
1361 let op_trait = quote! { Reverse };
1363 let op_fn = Ident::new("reverse", Span::call_site());
1364 let reverse_f = |coef: isize, i: usize| {
1365 let coef_rev = sign_from_parity((basis[i].len() / 2) % 2);
1366 (coef * coef_rev, i)
1367 };
1368 let reverse_expressions = generate_symbolic_rearrangement(&obj_self_components, reverse_f);
1369 let reverse_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &reverse_expressions, None);
1370
1371 let op_trait = quote! { AntiReverse };
1373 let op_fn = Ident::new("anti_reverse", Span::call_site());
1374 let alias_trait = quote! { InverseTransformation };
1375 let alias_fn = Ident::new("inverse_transformation", Span::call_site());
1376 let anti_reverse_f = |coef: isize, i: usize| {
1377 let (coef, i) = right_complement(&right_complement_signs, coef, i);
1378 let (coef, i) = reverse_f(coef, i);
1379 let (coef, i) = left_complement(&right_complement_signs, coef, i);
1380 (coef, i)
1381 };
1382 let anti_reverse_expressions = generate_symbolic_rearrangement(&obj_self_components, anti_reverse_f);
1383 let anti_reverse_code = gen_unary_operator(
1384 &objects,
1385 op_trait,
1386 op_fn,
1387 &obj,
1388 &anti_reverse_expressions,
1389 Some((alias_trait, alias_fn)), );
1391
1392 let op_trait = quote! { RightComplement };
1394 let op_fn = Ident::new("right_complement", Span::call_site());
1395 let right_complement_expressions = generate_symbolic_rearrangement(&obj_self_components, |coef: isize, i: usize| right_complement(&right_complement_signs, coef, i));
1396 let right_complement_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &right_complement_expressions, None);
1397
1398 let op_trait = quote! { LeftComplement };
1400 let op_fn = Ident::new("left_complement", Span::call_site());
1401 let left_complement_expressions = generate_symbolic_rearrangement(&obj_self_components, |coef: isize, i: usize| left_complement(&right_complement_signs, coef, i));
1402 let left_complement_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &left_complement_expressions, None);
1403
1404 let op_trait = quote! { Bulk };
1406 let op_fn = Ident::new("bulk", Span::call_site());
1407 let bulk_f = |coef: isize, i: usize| {
1408 let (zero_or_one, _) = dot_product_f(1, i, 1, i);
1411 (coef * zero_or_one, i)
1412 };
1413 let bulk_expressions = generate_symbolic_rearrangement(&obj_self_components, bulk_f);
1414 let bulk_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_expressions, None);
1415
1416 let op_trait = quote! { Weight };
1418 let op_fn = Ident::new("weight", Span::call_site());
1419 let weight_f = |coef: isize, i: usize| {
1420 let (coef, i) = right_complement(&right_complement_signs, coef, i);
1421 let (coef, i) = bulk_f(coef, i);
1422 let (coef, i) = left_complement(&right_complement_signs, coef, i);
1423 (coef, i)
1424 };
1425 let weight_expressions = generate_symbolic_rearrangement(&obj_self_components, weight_f);
1426 let weight_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_expressions, None);
1427
1428 let op_trait = quote! { BulkDual };
1430 let op_fn = Ident::new("bulk_dual", Span::call_site());
1431 let bulk_dual_f = |coef: isize, i: usize| {
1432 let (coef, i) = bulk_f(coef, i);
1433 let (coef, i) = right_complement(&right_complement_signs, coef, i);
1434 (coef, i)
1435 };
1436 let bulk_dual_expressions = generate_symbolic_rearrangement(&obj_self_components, bulk_dual_f);
1437 let bulk_dual_code =
1438 gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_dual_expressions, None);
1439
1440 let op_trait = quote! { WeightDual };
1442 let op_fn = Ident::new("weight_dual", Span::call_site());
1443 let alias_trait = quote! { Normal };
1444 let alias_fn = Ident::new("normal", Span::call_site());
1445 let weight_dual_f = |coef: isize, i: usize| {
1446 let (coef, i) = weight_f(coef, i);
1447 let (coef, i) = right_complement(&right_complement_signs, coef, i);
1448 (coef, i)
1449 };
1450 let weight_dual_expressions = generate_symbolic_rearrangement(&obj_self_components, weight_dual_f);
1451 let weight_dual_code =
1452 gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_dual_expressions, Some((alias_trait, alias_fn)));
1453
1454
1455 let op_trait = quote! { BulkNormSquared };
1457 let op_fn = Ident::new("bulk_norm_squared", Span::call_site());
1458 let bulk_norm_squared_expressions = generate_symbolic_norm(&obj_self_components, dot_product_f, false);
1461 let bulk_norm_squared_code =
1462 gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_norm_squared_expressions, None);
1463
1464 let bulk_norm_code = if !bulk_norm_squared_code.is_empty() {
1465 let op_trait = quote! { BulkNorm };
1467 let op_fn = Ident::new("bulk_norm", Span::call_site());
1468 let bulk_norm_expressions = generate_symbolic_norm(&obj_self_components, dot_product_f, true);
1469 let bulk_norm_code = gen_unary_operator(&objects, op_trait, op_fn, &obj, &bulk_norm_expressions, None);
1470 if !bulk_norm_code.is_empty() {
1471 bulk_norm_code
1473 } else {
1474 let type_name = obj.type_name();
1476 quote! {
1477 impl < T: Ring > BulkNorm for #type_name
1478 where <Self as BulkNormSquared>::Output: Sqrt {
1479 type Output = <<Self as BulkNormSquared>::Output as Sqrt>::Output;
1480
1481 fn bulk_norm (self) -> Self::Output {
1482 self.bulk_norm_squared().sqrt()
1483 }
1484 }
1485 }
1486 }
1487 } else {
1488 quote! {} };
1490
1491 let op_trait = quote! { WeightNormSquared };
1492 let op_fn = Ident::new("weight_norm_squared", Span::call_site());
1493 let weight_norm_squared_expressions = generate_symbolic_norm(&obj_self_components, anti_dot_product_f, false);
1494
1495 let weight_norm_squared_code =
1496 gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_norm_squared_expressions, None);
1497
1498 let weight_norm_code = if !weight_norm_squared_code.is_empty() {
1499 let op_trait = quote! { WeightNorm };
1501 let op_fn = Ident::new("weight_norm", Span::call_site());
1502 let weight_norm_expressions = generate_symbolic_norm(&obj_self_components, anti_dot_product_f, true);
1503 let weight_norm_code =
1504 gen_unary_operator(&objects, op_trait, op_fn, &obj, &weight_norm_expressions, None);
1505 if !weight_norm_code.is_empty() {
1506 weight_norm_code
1508 } else {
1509 let type_name = obj.type_name();
1511 quote! {
1512 impl < T: Ring > WeightNorm for #type_name
1513 where <Self as WeightNormSquared>::Output: AntiSqrt {
1514 type Output = <<Self as WeightNormSquared>::Output as AntiSqrt>::Output;
1515
1516 fn weight_norm (self) -> Self::Output {
1517 self.weight_norm_squared().anti_sqrt()
1518 }
1519 }
1520 }
1521 }
1522 } else {
1523 quote! {} };
1525
1526 let hat_code = if !matches!(obj, Object::Scalar) {
1624 let type_name = obj.type_name();
1625 quote! {
1626 impl<T: Ring + Recip<Output=T>> Normalized for #type_name
1627 where
1628 Self: BulkNorm<Output=T>
1629 {
1630 type Output = Self;
1631 fn normalized(self) -> Self {
1632 self * self.bulk_norm().recip()
1633 }
1634 }
1635
1636 impl<T: Ring + Recip<Output=T>> Unitized for #type_name
1637 where
1638 Self: WeightNorm<Output=AntiScalar<T>> {
1641 type Output = Self;
1642 fn unitized(self) -> Self {
1643 self.anti_mul(self.weight_norm().anti_recip())
1644 }
1645 }
1646 }
1647 } else {
1648 quote! {}
1649 };
1650
1651 let op_trait = quote! { core::ops::Add };
1653 let op_fn = Ident::new("add", Span::call_site());
1654 let add_code = gen_binary_operator(
1655 basis_element_count,
1656 &objects,
1657 op_trait,
1658 op_fn,
1659 &obj,
1660 |a, b| generate_symbolic_sum(a, b, 1, 1),
1661 true, None, );
1664
1665 let op_trait = quote! { core::ops::Sub };
1667 let op_fn = Ident::new("sub", Span::call_site());
1668 let sub_code = gen_binary_operator(
1669 basis_element_count,
1670 &objects,
1671 op_trait,
1672 op_fn,
1673 &obj,
1674 |a, b| generate_symbolic_sum(a, b, 1, -1),
1675 true, None, );
1678
1679 let op_trait = quote! { Wedge };
1681 let op_fn = Ident::new("wedge", Span::call_site());
1682 let alias_trait = quote! { Join };
1683 let alias_fn = Ident::new("join", Span::call_site());
1684 let wedge_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1685 let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1686 let s = basis[i].len();
1688 let t = basis[j].len();
1689 let u = basis[ix].len();
1690 let coef = if s + t == u { coef } else { 0 };
1691 (coef, ix)
1692 };
1693 let wedge_product_code = gen_binary_operator(
1694 basis_element_count,
1695 &objects,
1696 op_trait,
1697 op_fn,
1698 &obj,
1699 |a, b| generate_symbolic_product(a, b, wedge_product_f),
1700 false, Some((alias_trait, alias_fn)), );
1703
1704 let op_trait = quote! { AntiWedge };
1706 let op_fn = Ident::new("anti_wedge", Span::call_site());
1707 let alias_trait = quote! { Meet };
1708 let alias_fn = Ident::new("meet", Span::call_site());
1709 let anti_wedge_product_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1710 let (coef_i, i) = right_complement(&right_complement_signs, coef_i, i);
1711 let (coef_j, j) = right_complement(&right_complement_signs, coef_j, j);
1712
1713 let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1714 let s = basis[i].len();
1716 let t = basis[j].len();
1717 let u = basis[ix].len();
1718 let coef = if s + t == u { coef } else { 0 };
1719
1720 let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1721 (coef, ix)
1722 };
1723 let anti_wedge_product_code = gen_binary_operator(
1724 basis_element_count,
1725 &objects,
1726 op_trait,
1727 op_fn,
1728 &obj,
1729 |a, b| generate_symbolic_product(a, b, anti_wedge_product_f),
1730 false, Some((alias_trait, alias_fn)), );
1733
1734 let op_trait = quote! { Dot };
1736 let op_fn = Ident::new("dot", Span::call_site());
1737
1738 let dot_product_code = gen_binary_operator(
1739 basis_element_count,
1740 &objects,
1741 op_trait,
1742 op_fn,
1743 &obj,
1744 |a, b| generate_symbolic_product(a, b, dot_product_f),
1745 false, None, );
1748
1749 let op_trait = quote! { AntiDot };
1751 let op_fn = Ident::new("anti_dot", Span::call_site());
1752
1753 let anti_dot_product_code = gen_binary_operator(
1754 basis_element_count,
1755 &objects,
1756 op_trait,
1757 op_fn,
1758 &obj,
1759 |a, b| generate_symbolic_product(a, b, anti_dot_product_f),
1760 false, None, );
1763
1764 let op_trait = quote! { WedgeDot };
1766 let op_fn = Ident::new("wedge_dot", Span::call_site());
1767
1768 let wedge_dot_product_code = gen_binary_operator(
1769 basis_element_count,
1770 &objects,
1771 op_trait,
1772 op_fn,
1773 &obj,
1774 |a, b| generate_symbolic_product(a, b, geometric_product_f),
1775 false, None, );
1778
1779 let op_trait = quote! { AntiWedgeDot };
1781 let op_fn = Ident::new("anti_wedge_dot", Span::call_site());
1782 let alias_trait = quote! { Compose };
1783 let alias_fn = Ident::new("compose", Span::call_site());
1784
1785 let anti_wedge_dot_product_code = gen_binary_operator(
1786 basis_element_count,
1787 &objects,
1788 op_trait,
1789 op_fn,
1790 &obj,
1791 |a, b| generate_symbolic_product(a, b, geometric_antiproduct_f),
1792 false, Some((alias_trait, alias_fn)), );
1795
1796 let op_trait = quote! { core::ops::Mul };
1798 let op_fn = Ident::new("mul", Span::call_site());
1799 let scalar_product = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1800 let (coef, ix) = geometric_product_f(coef_i, i, coef_j, j);
1801 let s = basis[i].len();
1803 let t = basis[j].len();
1804 let coef = if s == 0 || t == 0 { coef } else { 0 };
1805 (coef, ix)
1806 };
1807 let scalar_product_code = gen_binary_operator(
1808 basis_element_count,
1809 &objects,
1810 op_trait,
1811 op_fn,
1812 &obj,
1813 |a, b| generate_symbolic_product(a, b, scalar_product),
1814 true, None, );
1817
1818 let op_trait = quote! { AntiMul };
1820 let op_fn = Ident::new("anti_mul", Span::call_site());
1821 let anti_scalar_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1822 let (coef_i, i) = right_complement(&right_complement_signs, coef_i, i);
1823 let (coef_j, j) = right_complement(&right_complement_signs, coef_j, j);
1824
1825 let (coef, ix) = scalar_product(coef_i, i, coef_j, j);
1826
1827 let (coef, ix) = left_complement(&right_complement_signs, coef, ix);
1828 (coef, ix)
1829 };
1830 let anti_scalar_product_code = gen_binary_operator(
1831 basis_element_count,
1832 &objects,
1833 op_trait,
1834 op_fn,
1835 &obj,
1836 |a, b| generate_symbolic_product(a, b, anti_scalar_f),
1837 true, None, );
1840
1841 let op_trait = quote! { BulkExpansion };
1843 let op_fn = Ident::new("bulk_expansion", Span::call_site());
1844 let bulk_expansion_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1845 let (coef_j, j) = bulk_dual_f(coef_j, j);
1846 let (coef, ix) = wedge_product_f(coef_i, i, coef_j, j);
1847 (coef, ix)
1848 };
1849 let bulk_expansion_code = gen_binary_operator(
1850 basis_element_count,
1851 &objects,
1852 op_trait,
1853 op_fn,
1854 &obj,
1855 |a, b| generate_symbolic_product(a, b, bulk_expansion_f),
1856 false, None, );
1859
1860 let op_trait = quote! { WeightExpansion };
1862 let op_fn = Ident::new("weight_expansion", Span::call_site());
1863 let alias_trait = quote! { SupersetOrthogonalTo };
1864 let alias_fn = Ident::new("superset_orthogonal_to", Span::call_site());
1865 let weight_expansion_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1866 let (coef_j, j) = weight_dual_f(coef_j, j);
1867 let (coef, ix) = wedge_product_f(coef_i, i, coef_j, j);
1868 (coef, ix)
1869 };
1870 let weight_expansion_code = gen_binary_operator(
1871 basis_element_count,
1872 &objects,
1873 op_trait,
1874 op_fn,
1875 &obj,
1876 |a, b| generate_symbolic_product(a, b, weight_expansion_f),
1877 false, Some((alias_trait, alias_fn)), );
1880
1881 let op_trait = quote! { BulkContraction };
1883 let op_fn = Ident::new("bulk_contraction", Span::call_site());
1884 let bulk_contraction_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1885 let (coef_j, j) = bulk_dual_f(coef_j, j);
1886 let (coef, ix) = anti_wedge_product_f(coef_i, i, coef_j, j);
1887 (coef, ix)
1888 };
1889 let bulk_contraction_code = gen_binary_operator(
1890 basis_element_count,
1891 &objects,
1892 op_trait,
1893 op_fn,
1894 &obj,
1895 |a, b| generate_symbolic_product(a, b, bulk_contraction_f),
1896 false, None, );
1899
1900 let op_trait = quote! { WeightContraction };
1902 let op_fn = Ident::new("weight_contraction", Span::call_site());
1903 let alias_trait = quote! { SubsetOrthogonalTo };
1904 let alias_fn = Ident::new("subset_orthogonal_to", Span::call_site());
1905 let weight_contraction_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1906 let (coef_j, j) = weight_dual_f(coef_j, j);
1907 let (coef, ix) = anti_wedge_product_f(coef_i, i, coef_j, j);
1908 (coef, ix)
1909 };
1910 let weight_contraction_code = gen_binary_operator(
1911 basis_element_count,
1912 &objects,
1913 op_trait,
1914 op_fn,
1915 &obj,
1916 |a, b| generate_symbolic_product(a, b, weight_contraction_f),
1917 false, Some((alias_trait, alias_fn)), );
1920
1921 let op_trait = quote! { AntiCommutator };
1923 let op_fn = Ident::new("anti_commutator", Span::call_site());
1924 let commutator_product = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1925 let (coef_1, ix) = geometric_antiproduct_f(coef_i, i, coef_j, j);
1926 let (coef_2, ix_2) = geometric_antiproduct_f(coef_j, j, coef_i, i);
1927
1928 let coef = coef_1 - coef_2;
1929 assert!(ix == ix_2);
1930 assert!(coef % 2 == 0);
1931 let coef = coef / 2;
1932 (coef, ix)
1933 };
1934 let commutator_product_code = gen_binary_operator(
1935 basis_element_count,
1936 &objects,
1937 op_trait,
1938 op_fn,
1939 &obj,
1940 |a, b| generate_symbolic_product(a, b, commutator_product),
1941 false, None, );
1944
1945 let op_trait = quote! { Transform };
1947 let op_fn = Ident::new("transform", Span::call_site());
1948 let transform_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1949 let (coef_i, i) = anti_reverse_f(coef_i, i);
1955 geometric_antiproduct_f(coef_i, i, coef_j, j)
1956 };
1957 let transform_2 = geometric_antiproduct_f;
1961
1962 let transform_code = gen_binary_operator(
1963 basis_element_count,
1964 &objects,
1965 op_trait,
1966 op_fn,
1967 &obj,
1968 |a, b| {
1969 generate_symbolic_double_product(
1970 a,
1971 b,
1972 transform_1,
1973 transform_2,
1974 )
1975 },
1976 false, None, );
1979
1980 let op_trait = quote! { TransformInverse };
1982 let op_fn = Ident::new("transform_inverse", Span::call_site());
1983 let reverse_transform_1 = geometric_antiproduct_f;
1988 let reverse_transform_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
1989 let (coef_j, j) = anti_reverse_f(coef_j, j);
1994 geometric_antiproduct_f(coef_i, i, coef_j, j)
1995 };
1996 let reverse_transform_code = gen_binary_operator(
1997 basis_element_count,
1998 &objects,
1999 op_trait,
2000 op_fn,
2001 &obj,
2002 |a, b| {
2003 generate_symbolic_double_product(
2004 a,
2005 b,
2006 reverse_transform_1,
2007 reverse_transform_2,
2008 )
2009 },
2010 false, None, );
2013
2014 let op_trait = quote! { Projection };
2016 let op_fn = Ident::new("projection", Span::call_site());
2017 let projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2018 weight_expansion_f(coef_j, j, coef_i, i)
2022 };
2023 let projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2024 anti_wedge_product_f(coef_j, j, coef_i, i)
2028 };
2029 let projection_code = gen_binary_operator(
2030 basis_element_count,
2031 &objects,
2032 op_trait,
2033 op_fn,
2034 &obj,
2035 |a, b| {
2036 generate_symbolic_double_product(
2037 a,
2038 b,
2039 projection_product_1,
2040 projection_product_2,
2041 )
2042 },
2043 false, None, );
2046
2047 let op_trait = quote! { AntiProjection };
2049 let op_fn = Ident::new("anti_projection", Span::call_site());
2050 let anti_projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2051 weight_contraction_f(coef_j, j, coef_i, i)
2055 };
2056 let anti_projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2057 wedge_product_f(coef_j, j, coef_i, i)
2061 };
2062 let anti_projection_code = gen_binary_operator(
2063 basis_element_count,
2064 &objects,
2065 op_trait,
2066 op_fn,
2067 &obj,
2068 |a, b| {
2069 generate_symbolic_double_product(
2070 a,
2071 b,
2072 anti_projection_product_1,
2073 anti_projection_product_2,
2074 )
2075 },
2076 false, None, );
2079
2080 let op_trait = quote! { CentralProjection };
2082 let op_fn = Ident::new("central_projection", Span::call_site());
2083 let central_projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2084 bulk_expansion_f(coef_j, j, coef_i, i)
2088 };
2089 let central_projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2090 anti_wedge_product_f(coef_j, j, coef_i, i)
2094 };
2095 let central_projection_code = gen_binary_operator(
2096 basis_element_count,
2097 &objects,
2098 op_trait,
2099 op_fn,
2100 &obj,
2101 |a, b| {
2102 generate_symbolic_double_product(
2103 a,
2104 b,
2105 central_projection_product_1,
2106 central_projection_product_2,
2107 )
2108 },
2109 false, None, );
2112
2113 let op_trait = quote! { CentralAntiProjection };
2115 let op_fn = Ident::new("central_anti_projection", Span::call_site());
2116 let central_anti_projection_product_1 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2117 bulk_contraction_f(coef_j, j, coef_i, i)
2121 };
2122 let central_anti_projection_product_2 = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2123 wedge_product_f(coef_j, j, coef_i, i)
2127 };
2128 let central_anti_projection_code = gen_binary_operator(
2129 basis_element_count,
2130 &objects,
2131 op_trait,
2132 op_fn,
2133 &obj,
2134 |a, b| {
2135 generate_symbolic_double_product(
2136 a,
2137 b,
2138 central_anti_projection_product_1,
2139 central_anti_projection_product_2,
2140 )
2141 },
2142 false, None, );
2145
2146 let op_trait = quote! { MotorTo };
2148 let op_fn = Ident::new("motor_to", Span::call_site());
2149 let motor_to_f = |coef_i: isize, i: usize, coef_j: isize, j: usize| {
2150 let (coef_i, i) = anti_reverse_f(coef_i, i);
2151 geometric_antiproduct_f(coef_i, i, coef_j, j)
2152 };
2153
2154 let motor_to_code = gen_binary_operator(
2155 basis_element_count,
2156 &objects,
2157 op_trait,
2158 op_fn,
2159 &obj,
2160 |a, b| generate_symbolic_product(a, b, motor_to_f),
2161 true, None, );
2164
2165 quote! {
2166 #from_code
2171 #anti_abs_code
2172 #anti_recip_code
2173 #anti_sqrt_code
2174 #anti_trig_code
2175 #neg_code
2176 #reverse_code
2177 #anti_reverse_code
2178 #bulk_code
2179 #weight_code
2180 #bulk_dual_code
2181 #weight_dual_code
2182 #right_complement_code
2183 #left_complement_code
2184 #bulk_norm_squared_code
2185 #bulk_norm_code
2186 #weight_norm_squared_code
2187 #weight_norm_code
2188 #hat_code
2189 #add_code
2190 #sub_code
2191 #wedge_product_code
2192 #anti_wedge_product_code
2193 #dot_product_code
2194 #anti_dot_product_code
2195 #wedge_dot_product_code
2196 #anti_wedge_dot_product_code
2197 #scalar_product_code
2198 #anti_scalar_product_code
2199 #commutator_product_code
2200 #bulk_expansion_code
2201 #weight_expansion_code
2202 #bulk_contraction_code
2203 #weight_contraction_code
2204 #projection_code
2205 #anti_projection_code
2206 #central_projection_code
2207 #central_anti_projection_code
2208 #transform_code
2209 #reverse_transform_code
2210 #motor_to_code
2211 }
2212 })
2213 .collect();
2214
2215 Ok(impl_code)
2216}
2217
2218struct VecItem(Vec<Item>);
2220
2221impl Parse for VecItem {
2222 fn parse(input: ParseStream) -> Result<Self> {
2223 let mut items = Vec::<Item>::new();
2224 while !input.is_empty() {
2225 items.push(input.parse()?);
2226 }
2227 Ok(VecItem(items))
2228 }
2229}
2230
2231struct BasisVectorIdents(Vec<Ident>);
2232
2233impl Parse for BasisVectorIdents {
2234 fn parse(input: ParseStream) -> Result<Self> {
2235 let ident_list = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;
2236 Ok(BasisVectorIdents(ident_list.into_iter().collect()))
2237 }
2238}
2239
2240struct Metric(Vec<isize>);
2241
2242impl Parse for Metric {
2243 fn parse(input: ParseStream) -> Result<Self> {
2244 let lit_list = Punctuated::<LitInt, Token![,]>::parse_terminated(input)?;
2245 Ok(Metric(
2246 lit_list
2247 .into_iter()
2248 .map(|lit| lit.base10_parse::<isize>())
2249 .collect::<Result<Vec<_>>>()?,
2250 ))
2251 }
2252}
2253
2254fn geometric_algebra2(code: TokenStream) -> Result<TokenStream> {
2255 let VecItem(mut items) = parse2(code)?;
2256
2257 let mut metric = None;
2258 let mut basis = None;
2259 let mut multivector_structs = Vec::<MultivectorStruct>::new();
2260
2261 let multivector_ident = Ident::new("multivector", Span::call_site());
2263 let metric_ident = Ident::new("metric", Span::call_site());
2264 let basis_ident = Ident::new("basis", Span::call_site());
2265
2266 let mut err: Result<()> = Ok(());
2267
2268 let mut append_err = |new_e: Error| {
2269 err = match &mut err {
2270 Ok(()) => Err(new_e),
2271 Err(old_e) => {
2272 old_e.combine(new_e);
2273 Err(old_e.clone())
2274 }
2275 };
2276 };
2277
2278 items.retain_mut(|item| {
2279 match item {
2280 Item::Macro(ItemMacro {
2281 mac: item_macro, ..
2282 }) => {
2283 let macro_ident = item_macro.path.get_ident();
2284 if macro_ident == Some(&basis_ident) {
2285 if basis.is_some() {
2286 append_err(Error::new(
2287 macro_ident.unwrap().span(),
2288 "Duplicate basis definition",
2289 ));
2290 } else {
2291 let parsed_basis: Result<BasisVectorIdents> = item_macro.parse_body();
2293 match parsed_basis {
2294 Ok(BasisVectorIdents(parsed_basis)) => {
2295 basis = Some(parsed_basis);
2296 }
2297 Err(e) => {
2298 append_err(e);
2299 }
2300 }
2301 }
2302 false } else if macro_ident == Some(&metric_ident) {
2304 if metric.is_some() {
2305 append_err(Error::new(
2306 macro_ident.unwrap().span(),
2307 "Duplicate metric definition",
2308 ));
2309 } else {
2310 let parsed_metric: Result<Metric> = item_macro.parse_body();
2312 match parsed_metric {
2313 Ok(Metric(parsed_metric)) => {
2314 metric = Some(parsed_metric);
2315 }
2316 Err(e) => {
2317 append_err(e);
2318 }
2319 }
2320 }
2321 false } else {
2323 true }
2325 }
2326 Item::Struct(item_struct) => {
2327 let mut has_multivector_attribute = false;
2328
2329 item_struct.attrs.retain(|attr| {
2330 if let Attribute {
2331 style: AttrStyle::Outer,
2332 meta: Meta::Path(Path { segments, .. }),
2333 ..
2334 } = attr
2335 {
2336 if segments.len() == 1
2339 && segments.first().unwrap().ident == multivector_ident
2340 {
2341 has_multivector_attribute = true;
2342 false
2343 } else {
2344 true
2345 }
2346 } else {
2347 true
2348 }
2349 });
2350
2351 if has_multivector_attribute {
2352 if let Fields::Named(FieldsNamed { named: fields, .. }) = &item_struct.fields {
2355 match fields
2357 .iter()
2358 .map(|field| Ok(field.ident.as_ref().unwrap().clone()))
2359 .collect::<Result<Vec<_>>>()
2360 {
2361 Ok(components) => {
2362 multivector_structs.push(MultivectorStruct {
2363 ident: item_struct.ident.clone(),
2364 components,
2365 });
2366 }
2367 Err(e) => {
2368 append_err(e);
2369 }
2370 }
2371 } else {
2372 append_err(Error::new(
2373 item_struct.ident.span(),
2374 "Multivector must have named fields",
2375 ));
2376 }
2377 }
2378
2379 true }
2381 _ => {
2382 true }
2384 }
2385 });
2386 err?;
2387
2388 let Some(basis) = basis else {
2389 return Err(Error::new(
2390 Span::call_site(),
2391 "Missing basis![..] definition",
2392 ));
2393 };
2394
2395 let Some(metric) = metric else {
2396 return Err(Error::new(
2397 Span::call_site(),
2398 "Missing metric![..] definition",
2399 ));
2400 };
2401
2402 let generated_code = implement_geometric_algebra(basis, metric, multivector_structs)?;
2404
2405 let mut code = TokenStream::new();
2407 code.append_all(items);
2408 generated_code.to_tokens(&mut code);
2409 Ok(code)
2410}
2411
2412#[proc_macro]
2413pub fn geometric_algebra(code: proc_macro::TokenStream) -> proc_macro::TokenStream {
2414 geometric_algebra2(code.into())
2416 .unwrap_or_else(Error::into_compile_error)
2417 .into()
2418}
2419
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: expanding an algebra with basis `w, x, y`, metric
    /// `0, 1, 1` and six `#[multivector]` structs must not return an error.
    #[test]
    fn derive_geometric_algebra() {
        let input = quote! {
            basis![w, x, y];
            metric![0, 1, 1];
            #[multivector]
            struct Vector<T> {
                x: T,
                y: T,
                w: T,
            }
            #[multivector]
            struct Bivector<T> {
                wx: T,
                wy: T,
                xy: T,
            }
            #[multivector]
            struct AntiScalar<T> {
                wxy: T,
            }
            #[multivector]
            struct AntiEven<T> {
                a: T,
                wx: T,
                wy: T,
                xy: T,
            }
            #[multivector]
            struct AntiOdd<T> {
                x: T,
                y: T,
                w: T,
                wxy: T,
            }
            #[multivector]
            struct Multivector<T> {
                a: T,
                x: T,
                y: T,
                w: T,
                wx: T,
                wy: T,
                xy: T,
                wxy: T,
            }
        };

        let expansion = geometric_algebra2(input);
        let _expanded = expansion.unwrap();
    }
}