clifford_codegen/codegen/
unary.rs1use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8use crate::algebra::{Algebra, ProductTable};
9use crate::spec::{AlgebraSpec, TypeSpec};
10
11pub struct UnaryGenerator<'a> {
13 spec: &'a AlgebraSpec,
15 table: ProductTable,
17}
18
19impl<'a> UnaryGenerator<'a> {
20 pub fn new(spec: &'a AlgebraSpec) -> Self {
22 let algebra = Algebra::from_metrics(spec.signature.metrics_by_index());
23 let table = ProductTable::new(&algebra);
24 Self { spec, table }
25 }
26
27 pub fn generate_all(&self) -> TokenStream {
29 let reverse = self.generate_all_reverse();
30 let antireverse = self.generate_all_antireverse();
31 let complement = self.generate_all_complement();
32
33 quote! {
34 #reverse
38
39 #antireverse
43
44 #complement
48 }
49 }
50
51 fn generate_all_reverse(&self) -> TokenStream {
53 let ops: Vec<TokenStream> = self
54 .spec
55 .types
56 .iter()
57 .filter(|t| t.alias_of.is_none())
58 .filter_map(|ty| self.generate_reverse(ty))
59 .collect();
60
61 quote! { #(#ops)* }
62 }
63
64 fn generate_reverse(&self, ty: &TypeSpec) -> Option<TokenStream> {
66 let type_name = format_ident!("{}", ty.name);
67 let fn_name = format_ident!("reverse_{}", ty.name.to_lowercase());
68
69 let field_exprs: Vec<TokenStream> = ty
71 .fields
72 .iter()
73 .map(|field| {
74 let field_name = format_ident!("{}", field.name);
75 let grade = field.grade;
76 if (grade * grade.saturating_sub(1) / 2).is_multiple_of(2) {
78 quote! { a.#field_name() }
79 } else {
80 quote! { -a.#field_name() }
81 }
82 })
83 .collect();
84
85 let constructor = quote! { #type_name::new_unchecked(#(#field_exprs),*) };
86
87 let doc = format!(
88 "Reverses the {} (negates grades where k(k-1)/2 is odd).",
89 ty.name
90 );
91
92 Some(quote! {
93 #[doc = #doc]
94 #[inline]
95 pub fn #fn_name<T: Float>(a: &#type_name<T>) -> #type_name<T> {
96 #constructor
97 }
98 })
99 }
100
101 fn generate_all_antireverse(&self) -> TokenStream {
103 let ops: Vec<TokenStream> = self
104 .spec
105 .types
106 .iter()
107 .filter(|t| t.alias_of.is_none())
108 .filter_map(|ty| self.generate_antireverse(ty))
109 .collect();
110
111 quote! { #(#ops)* }
112 }
113
114 fn generate_antireverse(&self, ty: &TypeSpec) -> Option<TokenStream> {
116 let type_name = format_ident!("{}", ty.name);
117 let fn_name = format_ident!("antireverse_{}", ty.name.to_lowercase());
118 let dim = self.spec.signature.dim();
119
120 let field_exprs: Vec<TokenStream> = ty
122 .fields
123 .iter()
124 .map(|field| {
125 let field_name = format_ident!("{}", field.name);
126 let grade = field.grade;
127 let antigrade = dim - grade;
128 if (antigrade * antigrade.saturating_sub(1) / 2).is_multiple_of(2) {
130 quote! { a.#field_name() }
131 } else {
132 quote! { -a.#field_name() }
133 }
134 })
135 .collect();
136
137 let constructor = quote! { #type_name::new_unchecked(#(#field_exprs),*) };
138
139 let doc = format!(
140 "Antireverses the {} (negates grades where (n-k)(n-k-1)/2 is odd).",
141 ty.name
142 );
143
144 Some(quote! {
145 #[doc = #doc]
146 #[inline]
147 pub fn #fn_name<T: Float>(a: &#type_name<T>) -> #type_name<T> {
148 #constructor
149 }
150 })
151 }
152
153 fn generate_all_complement(&self) -> TokenStream {
155 let ops: Vec<TokenStream> = self
156 .spec
157 .types
158 .iter()
159 .filter(|t| t.alias_of.is_none())
160 .filter_map(|ty| self.generate_complement(ty))
161 .collect();
162
163 quote! { #(#ops)* }
164 }
165
166 fn generate_complement(&self, ty: &TypeSpec) -> Option<TokenStream> {
170 let dim = self.spec.signature.dim();
171
172 let mut output_grades: Vec<usize> = ty.grades.iter().map(|&g| dim - g).collect();
174 output_grades.sort();
175
176 let output_type = self.spec.types.iter().find(|t| {
178 if t.alias_of.is_some() {
179 return false;
180 }
181 let mut t_grades = t.grades.clone();
182 t_grades.sort();
183 t_grades == output_grades
184 })?;
185
186 let type_name = format_ident!("{}", ty.name);
187 let output_name = format_ident!("{}", output_type.name);
188 let fn_name = format_ident!("complement_{}", ty.name.to_lowercase());
189
190 let field_exprs: Vec<TokenStream> = output_type
192 .fields
193 .iter()
194 .map(|out_field| {
195 let out_blade = out_field.blade_index;
197
198 let mut expr = quote! { T::zero() };
200 for in_field in &ty.fields {
201 let (sign, comp_blade) = self.table.complement(in_field.blade_index);
202 if comp_blade == out_blade && sign != 0 {
203 let in_name = format_ident!("{}", in_field.name);
204 let total_sign = sign * in_field.sign * out_field.sign;
208 if total_sign > 0 {
209 expr = quote! { a.#in_name() };
210 } else {
211 expr = quote! { -a.#in_name() };
212 }
213 break;
214 }
215 }
216 expr
217 })
218 .collect();
219
220 let constructor = quote! { #output_name::new_unchecked(#(#field_exprs),*) };
221
222 let doc = format!(
223 "Computes the right complement of {} -> {}.",
224 ty.name, output_type.name
225 );
226
227 Some(quote! {
228 #[doc = #doc]
229 #[inline]
230 pub fn #fn_name<T: Float>(a: &#type_name<T>) -> #output_name<T> {
231 #constructor
232 }
233 })
234 }
235}