Skip to main content

clifford_codegen/symbolic/
to_rust.rs

//! Convert Symbolica expressions to Rust code.
//!
//! This module provides utilities for converting Symbolica `Atom` expressions
//! into Rust `TokenStream` code that can be compiled.
5
6use std::collections::HashMap;
7
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote};
10use symbolica::atom::{Atom, AtomCore, AtomView};
11use symbolica::coefficient::CoefficientView;
12
13use crate::spec::TypeSpec;
14
/// Converts Symbolica `Atom` expressions to Rust `TokenStream`.
///
/// This converter handles the translation of symbolic mathematical expressions
/// into Rust code, handling:
/// - Field accessors (e.g., `a_x` → `a.x()`, or `a.as_inner().x()` for wrapped prefixes)
/// - Numeric constants (0, 1, 2, etc.)
/// - Arithmetic operations (+, -, *, powers)
///
/// # Example
///
/// ```ignore
/// use clifford_codegen::symbolic::AtomToRust;
///
/// let converter = AtomToRust::new(&[&type_a, &type_b], &["a", "b"]);
/// let atom = Atom::parse("a_x * b_y + a_y * b_x").unwrap();
/// let code = converter.convert(&atom);
/// // Produces: a.x() * b.y() + a.y() * b.x()
/// ```
#[derive(Debug, Clone)]
pub struct AtomToRust {
    /// Map from symbol name (e.g., "a_x") to (prefix, field) for accessor generation.
    symbol_map: HashMap<String, (String, String)>,
    /// Set of prefixes that require `.as_inner()` for field access (wrapper types).
    wrapped_prefixes: std::collections::HashSet<String>,
}
39
40impl AtomToRust {
41    /// Creates a new converter for the given types and prefixes.
42    ///
43    /// # Arguments
44    ///
45    /// * `types` - The types whose fields will appear in expressions
46    /// * `prefixes` - The variable prefixes used (e.g., "a", "b")
47    ///
48    /// # Example
49    ///
50    /// ```ignore
51    /// let converter = AtomToRust::new(&[&vector_type, &bivector_type], &["a", "b"]);
52    /// ```
53    pub fn new(types: &[&TypeSpec], prefixes: &[&str]) -> Self {
54        Self::new_with_wrappers(types, prefixes, &[])
55    }
56
57    /// Creates a converter with wrapper type support.
58    ///
59    /// # Arguments
60    ///
61    /// * `types` - The types whose fields will appear in expressions
62    /// * `prefixes` - The variable prefixes used (e.g., "self", "rhs")
63    /// * `wrapped_prefixes` - Prefixes that use wrapper types (need `.as_inner()`)
64    ///
65    /// # Example
66    ///
67    /// ```ignore
68    /// // For Wedge<B> for Unit<A>: self is wrapped, rhs is not
69    /// let converter = AtomToRust::new_with_wrappers(
70    ///     &[&type_a, &type_b],
71    ///     &["self", "rhs"],
72    ///     &["self"],  // self needs .as_inner()
73    /// );
74    /// ```
75    pub fn new_with_wrappers(
76        types: &[&TypeSpec],
77        prefixes: &[&str],
78        wrapper_prefixes: &[&str],
79    ) -> Self {
80        let mut symbol_map = HashMap::new();
81
82        for (ty, prefix) in types.iter().zip(prefixes.iter()) {
83            for field in &ty.fields {
84                let symbol_name = format!("{}_{}", prefix, field.name);
85                symbol_map.insert(symbol_name, ((*prefix).to_string(), field.name.clone()));
86            }
87        }
88
89        let wrapped_prefixes = wrapper_prefixes.iter().map(|s| (*s).to_string()).collect();
90
91        Self {
92            symbol_map,
93            wrapped_prefixes,
94        }
95    }
96
97    /// Converts a Symbolica `Atom` to a Rust `TokenStream`.
98    pub fn convert(&self, atom: &Atom) -> TokenStream {
99        self.convert_view(atom.as_atom_view())
100    }
101
102    /// Converts an `AtomView` to a Rust `TokenStream`.
103    fn convert_view(&self, view: AtomView<'_>) -> TokenStream {
104        match view {
105            AtomView::Num(n) => self.convert_num(n),
106            AtomView::Var(v) => self.convert_var(v),
107            AtomView::Add(a) => self.convert_add(a),
108            AtomView::Mul(m) => self.convert_mul(m),
109            AtomView::Pow(p) => self.convert_pow(p),
110            AtomView::Fun(_) => {
111                // Functions are not expected in our expressions
112                quote! { T::zero() }
113            }
114        }
115    }
116
117    /// Converts a numeric value to Rust.
118    fn convert_num(&self, num: symbolica::atom::NumView<'_>) -> TokenStream {
119        let coeff = num.get_coeff_view();
120
121        match coeff {
122            CoefficientView::Natural(n_re, d_re, n_im, d_im) => {
123                // We only handle real numbers
124                if n_im != 0 || d_im != 1 {
125                    // Complex number - shouldn't happen in our use case
126                    return quote! { T::zero() };
127                }
128
129                if d_re == 1 {
130                    // Integer
131                    self.convert_integer(n_re)
132                } else {
133                    // Rational - convert to float
134                    let val = n_re as f64 / d_re as f64;
135                    quote! { T::from_f64(#val) }
136                }
137            }
138            CoefficientView::Float(_, _) => {
139                // Float coefficients are rare in our use case
140                // Fall back to zero - this shouldn't happen in practice
141                quote! { T::zero() }
142            }
143            CoefficientView::Large(_, _) => {
144                // Large coefficients are rare in our use case
145                // Fall back to zero - this shouldn't happen in practice
146                quote! { T::zero() }
147            }
148            CoefficientView::FiniteField(_, _) => {
149                // Finite field - shouldn't happen
150                quote! { T::zero() }
151            }
152            CoefficientView::RationalPolynomial(_) => {
153                // Rational polynomial - shouldn't happen
154                quote! { T::zero() }
155            }
156            CoefficientView::Indeterminate | CoefficientView::Infinity(_) => {
157                // Indeterminate or infinity - shouldn't happen in our use case
158                quote! { T::zero() }
159            }
160        }
161    }
162
163    /// Converts an integer to the appropriate Rust expression.
164    fn convert_integer(&self, n: i64) -> TokenStream {
165        match n {
166            0 => quote! { T::zero() },
167            1 => quote! { T::one() },
168            2 => quote! { T::TWO },
169            -1 => quote! { -T::one() },
170            -2 => quote! { -T::TWO },
171            _ if n >= i8::MIN as i64 && n <= i8::MAX as i64 => {
172                let n_i8 = n as i8;
173                quote! { T::from_i8(#n_i8) }
174            }
175            _ => {
176                let n_f64 = n as f64;
177                quote! { T::from_f64(#n_f64) }
178            }
179        }
180    }
181
182    /// Converts a variable reference to Rust.
183    fn convert_var(&self, var: symbolica::atom::VarView<'_>) -> TokenStream {
184        let symbol = var.get_symbol();
185        let name = symbol.to_string();
186
187        if let Some((prefix, field)) = self.symbol_map.get(&name) {
188            // It's a field reference like "a_x"
189            let prefix_ident = format_ident!("{}", prefix);
190            let field_ident = format_ident!("{}", field);
191
192            // For wrapper types, access via .as_inner()
193            if self.wrapped_prefixes.contains(prefix) {
194                quote! { #prefix_ident.as_inner().#field_ident() }
195            } else {
196                quote! { #prefix_ident.#field_ident() }
197            }
198        } else {
199            // Unknown variable - emit as identifier (for constants or errors)
200            let ident = format_ident!("{}", name.replace("clifford_codegen::", ""));
201            quote! { #ident }
202        }
203    }
204
205    /// Converts an addition to Rust, properly handling signs.
206    fn convert_add(&self, add: symbolica::atom::AddView<'_>) -> TokenStream {
207        let terms: Vec<AtomView<'_>> = add.iter().collect();
208
209        if terms.is_empty() {
210            return quote! { T::zero() };
211        }
212
213        // Convert all terms to (is_negative, token_stream) pairs first
214        let mut converted: Vec<(bool, TokenStream)> = terms
215            .iter()
216            .map(|term| self.extract_negation(*term))
217            .collect();
218
219        // Sort by the final Rust code string for deterministic output
220        converted.sort_by_cached_key(|(_, tokens)| tokens.to_string());
221
222        // Build the result from sorted terms
223        let (first_neg, first_expr) = converted.remove(0);
224        let mut result = if first_neg {
225            quote! { -(#first_expr) }
226        } else {
227            first_expr
228        };
229
230        for (is_neg, term_expr) in converted {
231            if is_neg {
232                result = quote! { #result - #term_expr };
233            } else {
234                result = quote! { #result + #term_expr };
235            }
236        }
237
238        result
239    }
240
241    /// Extracts negation from an expression.
242    ///
243    /// Returns (is_negative, positive_expression).
244    fn extract_negation(&self, view: AtomView<'_>) -> (bool, TokenStream) {
245        match view {
246            AtomView::Num(n) => {
247                let coeff = n.get_coeff_view();
248                if self.is_negative_coefficient(&coeff) {
249                    // Negate and return positive version
250                    let atom = view.to_owned();
251                    let negated = -&atom;
252                    (true, self.convert(&negated))
253                } else {
254                    (false, self.convert_num(n))
255                }
256            }
257            AtomView::Mul(m) => {
258                // Check if first factor is a negative number
259                let factors: Vec<AtomView<'_>> = m.iter().collect();
260                if let Some(AtomView::Num(n)) = factors.first() {
261                    let coeff = n.get_coeff_view();
262                    if self.is_negative_coefficient(&coeff) {
263                        // Negate the multiplication
264                        let atom = view.to_owned();
265                        let negated = -&atom;
266                        let expanded = negated.expand();
267                        return (true, self.convert(&expanded));
268                    }
269                }
270                (false, self.convert_mul(m))
271            }
272            _ => (false, self.convert_view(view)),
273        }
274    }
275
276    /// Checks if a coefficient is negative.
277    fn is_negative_coefficient(&self, coeff: &CoefficientView) -> bool {
278        match coeff {
279            CoefficientView::Natural(n_re, _, _, _) => *n_re < 0,
280            // Float and Large coefficients are rare in our use case
281            // Fall back to false - this shouldn't happen in practice
282            CoefficientView::Float(_, _) | CoefficientView::Large(_, _) => false,
283            _ => false,
284        }
285    }
286
287    /// Converts a multiplication to Rust.
288    fn convert_mul(&self, mul: symbolica::atom::MulView<'_>) -> TokenStream {
289        let factors: Vec<AtomView<'_>> = mul.iter().collect();
290
291        if factors.is_empty() {
292            return quote! { T::one() };
293        }
294
295        // Split out numeric coefficient if present
296        let (coeff, remaining) = self.split_coefficient(&factors);
297
298        if remaining.is_empty() {
299            // Just a number
300            return coeff.unwrap_or_else(|| quote! { T::one() });
301        }
302
303        // Convert remaining factors to TokenStreams
304        let mut factor_exprs: Vec<TokenStream> =
305            remaining.iter().map(|f| self.convert_view(*f)).collect();
306
307        // Sort by final Rust code string for deterministic output
308        factor_exprs.sort_by_cached_key(|tokens| tokens.to_string());
309
310        // Build product
311        let product = if factor_exprs.len() == 1 {
312            factor_exprs[0].clone()
313        } else {
314            let first = &factor_exprs[0];
315            let rest = &factor_exprs[1..];
316            quote! { #first #(* #rest)* }
317        };
318
319        // Apply coefficient
320        match coeff {
321            None => product,
322            Some(c) => {
323                // Check if it's just T::one() or -T::one()
324                let c_str = c.to_string();
325                if c_str.contains("T :: one ()") && !c_str.contains('-') {
326                    product
327                } else if c_str == "- T :: one ()" {
328                    quote! { -(#product) }
329                } else {
330                    quote! { #c * #product }
331                }
332            }
333        }
334    }
335
336    /// Splits out the numeric coefficient from factors.
337    ///
338    /// Returns (coefficient_tokens, remaining_factors) where remaining_factors
339    /// contains copies of non-numeric AtomViews.
340    fn split_coefficient<'b>(
341        &self,
342        factors: &[AtomView<'b>],
343    ) -> (Option<TokenStream>, Vec<AtomView<'b>>) {
344        let mut coeff = None;
345        let mut remaining = Vec::new();
346
347        for factor in factors {
348            if let AtomView::Num(n) = factor {
349                if coeff.is_none() {
350                    coeff = Some(self.convert_num(*n));
351                    continue;
352                }
353            }
354            remaining.push(*factor);
355        }
356
357        (coeff, remaining)
358    }
359
360    /// Converts a power expression to Rust.
361    fn convert_pow(&self, pow: symbolica::atom::PowView<'_>) -> TokenStream {
362        let (base, exp) = pow.get_base_exp();
363
364        // Check for simple integer exponents
365        if let AtomView::Num(n) = exp {
366            let coeff = n.get_coeff_view();
367            if let CoefficientView::Natural(exp_val, 1, 0, 1) = coeff {
368                match exp_val {
369                    0 => return quote! { T::one() },
370                    1 => return self.convert_view(base),
371                    2 => {
372                        let base_expr = self.convert_view(base);
373                        return quote! { #base_expr * #base_expr };
374                    }
375                    3 => {
376                        let base_expr = self.convert_view(base);
377                        return quote! { #base_expr * #base_expr * #base_expr };
378                    }
379                    _ => {}
380                }
381            }
382        }
383
384        // General case: use powi for integers, powf otherwise
385        let base_expr = self.convert_view(base);
386        let exp_expr = self.convert_view(exp);
387        quote! { #base_expr.powf(#exp_expr) }
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::FieldSpec;
    use std::sync::{Mutex, MutexGuard};
    use symbolica::atom::Atom;

    // Symbolica uses global state that conflicts when tests run in parallel.
    // Tests prefixed with `symbolica_` are configured to run serially via nextest.
    // The mutex provides a fallback for `cargo test` users.
    static SYMBOLICA_LOCK: Mutex<()> = Mutex::new(());

    /// Serializes access to Symbolica for the duration of a test.
    ///
    /// A test that panics while holding the lock poisons it. Recovering the
    /// guard here stops that one failure from turning every later test into a
    /// `PoisonError` panic that hides the original cause.
    fn symbolica_guard() -> MutexGuard<'static, ()> {
        SYMBOLICA_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Builds a grade-1 `TypeSpec` named `name`.
    ///
    /// Its fields are the given `(field_name, blade_index)` pairs, each with
    /// grade 1 and sign `1`.
    fn make_test_type(name: &str, fields: &[(&str, usize)]) -> TypeSpec {
        TypeSpec {
            name: name.to_string(),
            grades: vec![1],
            description: None,
            fields: fields
                .iter()
                .map(|(n, idx)| FieldSpec {
                    name: n.to_string(),
                    blade_index: *idx,
                    grade: 1,
                    sign: 1,
                })
                .collect(),
            alias_of: None,
            versor: None,
            is_sparse: false,
            inverse_sandwich_targets: vec![],
        }
    }

    #[test]
    fn symbolica_convert_integer_constants() {
        let _guard = symbolica_guard();
        let type_a = make_test_type("A", &[("x", 1)]);
        let converter = AtomToRust::new(&[&type_a], &["a"]);

        // Test zero
        let zero = Atom::num(0);
        assert!(converter.convert(&zero).to_string().contains("zero"));

        // Test one
        let one = Atom::num(1);
        assert!(converter.convert(&one).to_string().contains("one"));

        // Test two
        let two = Atom::num(2);
        assert!(converter.convert(&two).to_string().contains("TWO"));
    }

    #[test]
    fn symbolica_convert_negative_integers() {
        let _guard = symbolica_guard();
        let type_a = make_test_type("A", &[("x", 1)]);
        let converter = AtomToRust::new(&[&type_a], &["a"]);

        // Test -1
        let neg_one = Atom::num(-1);
        let result = converter.convert(&neg_one).to_string();
        assert!(result.contains("one") && result.contains("-"));
    }

    #[test]
    fn symbolica_convert_other_integers() {
        let _guard = symbolica_guard();
        let type_a = make_test_type("A", &[("x", 1)]);
        let converter = AtomToRust::new(&[&type_a], &["a"]);

        // Values without a dedicated constant that fit in `i8` use `from_i8`.
        let five = converter.convert(&Atom::num(5)).to_string();
        assert!(five.contains("from_i8"), "got {five}");

        // Values outside the `i8` range fall back to `from_f64`.
        let big = converter.convert(&Atom::num(1000)).to_string();
        assert!(big.contains("from_f64"), "got {big}");
    }
}