// clifford_codegen/symbolic/to_rust.rs
use std::collections::{HashMap, HashSet};

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use symbolica::atom::{Atom, AtomCore, AtomView};
use symbolica::coefficient::CoefficientView;

use crate::spec::TypeSpec;
/// Translates Symbolica expressions into Rust token streams.
///
/// Symbols named `{prefix}_{field}` that were registered at construction time are
/// rendered as accessor calls (`prefix.field()`, or `prefix.as_inner().field()` for
/// wrapped prefixes). Any other symbol is rendered as a bare identifier.
#[derive(Debug, Clone)]
pub struct AtomToRust {
    /// Maps a symbol name (`"{prefix}_{field}"`) to its `(prefix, field)` parts.
    symbol_map: HashMap<String, (String, String)>,
    /// Prefixes whose field accessors are reached through `.as_inner()`.
    wrapped_prefixes: HashSet<String>,
}

40impl AtomToRust {
41 pub fn new(types: &[&TypeSpec], prefixes: &[&str]) -> Self {
54 Self::new_with_wrappers(types, prefixes, &[])
55 }
56
57 pub fn new_with_wrappers(
76 types: &[&TypeSpec],
77 prefixes: &[&str],
78 wrapper_prefixes: &[&str],
79 ) -> Self {
80 let mut symbol_map = HashMap::new();
81
82 for (ty, prefix) in types.iter().zip(prefixes.iter()) {
83 for field in &ty.fields {
84 let symbol_name = format!("{}_{}", prefix, field.name);
85 symbol_map.insert(symbol_name, ((*prefix).to_string(), field.name.clone()));
86 }
87 }
88
89 let wrapped_prefixes = wrapper_prefixes.iter().map(|s| (*s).to_string()).collect();
90
91 Self {
92 symbol_map,
93 wrapped_prefixes,
94 }
95 }
96
97 pub fn convert(&self, atom: &Atom) -> TokenStream {
99 self.convert_view(atom.as_atom_view())
100 }
101
102 fn convert_view(&self, view: AtomView<'_>) -> TokenStream {
104 match view {
105 AtomView::Num(n) => self.convert_num(n),
106 AtomView::Var(v) => self.convert_var(v),
107 AtomView::Add(a) => self.convert_add(a),
108 AtomView::Mul(m) => self.convert_mul(m),
109 AtomView::Pow(p) => self.convert_pow(p),
110 AtomView::Fun(_) => {
111 quote! { T::zero() }
113 }
114 }
115 }
116
117 fn convert_num(&self, num: symbolica::atom::NumView<'_>) -> TokenStream {
119 let coeff = num.get_coeff_view();
120
121 match coeff {
122 CoefficientView::Natural(n_re, d_re, n_im, d_im) => {
123 if n_im != 0 || d_im != 1 {
125 return quote! { T::zero() };
127 }
128
129 if d_re == 1 {
130 self.convert_integer(n_re)
132 } else {
133 let val = n_re as f64 / d_re as f64;
135 quote! { T::from_f64(#val) }
136 }
137 }
138 CoefficientView::Float(_, _) => {
139 quote! { T::zero() }
142 }
143 CoefficientView::Large(_, _) => {
144 quote! { T::zero() }
147 }
148 CoefficientView::FiniteField(_, _) => {
149 quote! { T::zero() }
151 }
152 CoefficientView::RationalPolynomial(_) => {
153 quote! { T::zero() }
155 }
156 CoefficientView::Indeterminate | CoefficientView::Infinity(_) => {
157 quote! { T::zero() }
159 }
160 }
161 }
162
163 fn convert_integer(&self, n: i64) -> TokenStream {
165 match n {
166 0 => quote! { T::zero() },
167 1 => quote! { T::one() },
168 2 => quote! { T::TWO },
169 -1 => quote! { -T::one() },
170 -2 => quote! { -T::TWO },
171 _ if n >= i8::MIN as i64 && n <= i8::MAX as i64 => {
172 let n_i8 = n as i8;
173 quote! { T::from_i8(#n_i8) }
174 }
175 _ => {
176 let n_f64 = n as f64;
177 quote! { T::from_f64(#n_f64) }
178 }
179 }
180 }
181
182 fn convert_var(&self, var: symbolica::atom::VarView<'_>) -> TokenStream {
184 let symbol = var.get_symbol();
185 let name = symbol.to_string();
186
187 if let Some((prefix, field)) = self.symbol_map.get(&name) {
188 let prefix_ident = format_ident!("{}", prefix);
190 let field_ident = format_ident!("{}", field);
191
192 if self.wrapped_prefixes.contains(prefix) {
194 quote! { #prefix_ident.as_inner().#field_ident() }
195 } else {
196 quote! { #prefix_ident.#field_ident() }
197 }
198 } else {
199 let ident = format_ident!("{}", name.replace("clifford_codegen::", ""));
201 quote! { #ident }
202 }
203 }
204
205 fn convert_add(&self, add: symbolica::atom::AddView<'_>) -> TokenStream {
207 let terms: Vec<AtomView<'_>> = add.iter().collect();
208
209 if terms.is_empty() {
210 return quote! { T::zero() };
211 }
212
213 let mut converted: Vec<(bool, TokenStream)> = terms
215 .iter()
216 .map(|term| self.extract_negation(*term))
217 .collect();
218
219 converted.sort_by_cached_key(|(_, tokens)| tokens.to_string());
221
222 let (first_neg, first_expr) = converted.remove(0);
224 let mut result = if first_neg {
225 quote! { -(#first_expr) }
226 } else {
227 first_expr
228 };
229
230 for (is_neg, term_expr) in converted {
231 if is_neg {
232 result = quote! { #result - #term_expr };
233 } else {
234 result = quote! { #result + #term_expr };
235 }
236 }
237
238 result
239 }
240
241 fn extract_negation(&self, view: AtomView<'_>) -> (bool, TokenStream) {
245 match view {
246 AtomView::Num(n) => {
247 let coeff = n.get_coeff_view();
248 if self.is_negative_coefficient(&coeff) {
249 let atom = view.to_owned();
251 let negated = -&atom;
252 (true, self.convert(&negated))
253 } else {
254 (false, self.convert_num(n))
255 }
256 }
257 AtomView::Mul(m) => {
258 let factors: Vec<AtomView<'_>> = m.iter().collect();
260 if let Some(AtomView::Num(n)) = factors.first() {
261 let coeff = n.get_coeff_view();
262 if self.is_negative_coefficient(&coeff) {
263 let atom = view.to_owned();
265 let negated = -&atom;
266 let expanded = negated.expand();
267 return (true, self.convert(&expanded));
268 }
269 }
270 (false, self.convert_mul(m))
271 }
272 _ => (false, self.convert_view(view)),
273 }
274 }
275
276 fn is_negative_coefficient(&self, coeff: &CoefficientView) -> bool {
278 match coeff {
279 CoefficientView::Natural(n_re, _, _, _) => *n_re < 0,
280 CoefficientView::Float(_, _) | CoefficientView::Large(_, _) => false,
283 _ => false,
284 }
285 }
286
287 fn convert_mul(&self, mul: symbolica::atom::MulView<'_>) -> TokenStream {
289 let factors: Vec<AtomView<'_>> = mul.iter().collect();
290
291 if factors.is_empty() {
292 return quote! { T::one() };
293 }
294
295 let (coeff, remaining) = self.split_coefficient(&factors);
297
298 if remaining.is_empty() {
299 return coeff.unwrap_or_else(|| quote! { T::one() });
301 }
302
303 let mut factor_exprs: Vec<TokenStream> =
305 remaining.iter().map(|f| self.convert_view(*f)).collect();
306
307 factor_exprs.sort_by_cached_key(|tokens| tokens.to_string());
309
310 let product = if factor_exprs.len() == 1 {
312 factor_exprs[0].clone()
313 } else {
314 let first = &factor_exprs[0];
315 let rest = &factor_exprs[1..];
316 quote! { #first #(* #rest)* }
317 };
318
319 match coeff {
321 None => product,
322 Some(c) => {
323 let c_str = c.to_string();
325 if c_str.contains("T :: one ()") && !c_str.contains('-') {
326 product
327 } else if c_str == "- T :: one ()" {
328 quote! { -(#product) }
329 } else {
330 quote! { #c * #product }
331 }
332 }
333 }
334 }
335
336 fn split_coefficient<'b>(
341 &self,
342 factors: &[AtomView<'b>],
343 ) -> (Option<TokenStream>, Vec<AtomView<'b>>) {
344 let mut coeff = None;
345 let mut remaining = Vec::new();
346
347 for factor in factors {
348 if let AtomView::Num(n) = factor {
349 if coeff.is_none() {
350 coeff = Some(self.convert_num(*n));
351 continue;
352 }
353 }
354 remaining.push(*factor);
355 }
356
357 (coeff, remaining)
358 }
359
360 fn convert_pow(&self, pow: symbolica::atom::PowView<'_>) -> TokenStream {
362 let (base, exp) = pow.get_base_exp();
363
364 if let AtomView::Num(n) = exp {
366 let coeff = n.get_coeff_view();
367 if let CoefficientView::Natural(exp_val, 1, 0, 1) = coeff {
368 match exp_val {
369 0 => return quote! { T::one() },
370 1 => return self.convert_view(base),
371 2 => {
372 let base_expr = self.convert_view(base);
373 return quote! { #base_expr * #base_expr };
374 }
375 3 => {
376 let base_expr = self.convert_view(base);
377 return quote! { #base_expr * #base_expr * #base_expr };
378 }
379 _ => {}
380 }
381 }
382 }
383
384 let base_expr = self.convert_view(base);
386 let exp_expr = self.convert_view(exp);
387 quote! { #base_expr.powf(#exp_expr) }
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::FieldSpec;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use symbolica::atom::Atom;

    // Serializes tests that use Symbolica.
    // NOTE(review): presumably needed because Symbolica keeps global state; confirm.
    static SYMBOLICA_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `SYMBOLICA_LOCK`, recovering from poisoning so that one failing test
    /// does not cascade into failures of every other test.
    fn lock_symbolica() -> MutexGuard<'static, ()> {
        SYMBOLICA_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    fn make_test_type(name: &str, fields: &[(&str, usize)]) -> TypeSpec {
        TypeSpec {
            name: name.to_string(),
            grades: vec![1],
            description: None,
            fields: fields
                .iter()
                .map(|(n, idx)| FieldSpec {
                    name: n.to_string(),
                    blade_index: *idx,
                    grade: 1,
                    sign: 1,
                })
                .collect(),
            alias_of: None,
            versor: None,
            is_sparse: false,
            inverse_sandwich_targets: vec![],
        }
    }

    #[test]
    fn symbolica_convert_integer_constants() {
        let _guard = lock_symbolica();
        let type_a = make_test_type("A", &[("x", 1)]);
        let converter = AtomToRust::new(&[&type_a], &["a"]);

        assert_eq!(
            converter.convert(&Atom::num(0)).to_string(),
            quote!(T::zero()).to_string()
        );
        assert_eq!(
            converter.convert(&Atom::num(1)).to_string(),
            quote!(T::one()).to_string()
        );
        assert_eq!(
            converter.convert(&Atom::num(2)).to_string(),
            quote!(T::TWO).to_string()
        );
    }

    #[test]
    fn symbolica_convert_negative_integers() {
        let _guard = lock_symbolica();
        let type_a = make_test_type("A", &[("x", 1)]);
        let converter = AtomToRust::new(&[&type_a], &["a"]);

        assert_eq!(
            converter.convert(&Atom::num(-1)).to_string(),
            quote!(-T::one()).to_string()
        );
        assert_eq!(
            converter.convert(&Atom::num(-2)).to_string(),
            quote!(-T::TWO).to_string()
        );
    }

    #[test]
    fn symbolica_convert_other_integers() {
        let _guard = lock_symbolica();
        let type_a = make_test_type("A", &[("x", 1)]);
        let converter = AtomToRust::new(&[&type_a], &["a"]);

        // Build expectations through the same literal interpolation the converter uses,
        // so negative literals tokenize identically.
        let small: i8 = -5;
        assert_eq!(
            converter.convert(&Atom::num(-5)).to_string(),
            quote!(T::from_i8(#small)).to_string()
        );

        let large: f64 = 1000.0;
        assert_eq!(
            converter.convert(&Atom::num(1000)).to_string(),
            quote!(T::from_f64(#large)).to_string()
        );
    }
}