Skip to main content

clifford_codegen/symbolic/
product.rs

1//! Symbolic product computation.
2//!
3//! This module computes the symbolic output of product operations,
4//! representing each output field as a symbolic expression of input fields.
5
6use std::borrow::Cow;
7use std::collections::HashMap;
8
9use symbolica::atom::{Atom, DefaultNamespace};
10use symbolica::parser::ParseSettings;
11
12use crate::algebra::{Algebra, Blade, ProductTable};
13use crate::spec::TypeSpec;
14
/// Selects which product a symbolic computation evaluates.
///
/// Names follow the conventions of
/// [Rigid Geometric Algebra](https://rigidgeometricalgebra.org/).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProductKind {
    /// The geometric product, i.e. the full product.
    Geometric,
    /// The geometric antiproduct: complement(complement(a) × complement(b)).
    ///
    /// Composes versors whose transformations are antisandwich-based.
    Antigeometric,
    /// The wedge product (∧), also called the exterior product; grade-raising.
    Wedge,
    /// The left contraction (a ⌋ b); result grade is gb - ga when ga <= gb.
    LeftContraction,
    /// The right contraction (a ⌊ b); result grade is ga - gb when gb <= ga.
    RightContraction,
    /// The antiwedge product (∨), also called the regressive product or meet.
    Antiwedge,
    /// The bulk contraction (a ∨ b★).
    BulkContraction,
    /// The weight contraction (a ∨ b☆).
    WeightContraction,
    /// The bulk expansion (a ∧ b★).
    BulkExpansion,
    /// The weight expansion (a ∧ b☆).
    WeightExpansion,
    /// The dot product (•): metric inner product between operands of the same
    /// grade, yielding a scalar.
    Dot,
    /// The antidot product (⊚): metric inner antiproduct between operands of the
    /// same antigrade, yielding a scalar.
    Antidot,
    /// The scalar product: the grade-0 part of the geometric product.
    Scalar,
    /// Projection: target ∨ (self ∧ target☆).
    Project,
    /// Antiprojection: target ∧ (self ∨ target☆).
    Antiproject,
}
52
53/// A symbolic field expression.
54#[derive(Debug, Clone)]
55pub struct SymbolicField {
56    /// Field name in the output type.
57    pub name: String,
58    /// Blade index this field represents.
59    pub blade_index: usize,
60    /// Symbolic expression for this field's value.
61    pub expression: Atom,
62}
63
64/// Computes symbolic product outputs.
65///
66/// Given input types with symbolic field values, computes the symbolic
67/// expressions for each output field.
68pub struct SymbolicProduct {
69    /// The product table.
70    table: ProductTable,
71}
72
73impl SymbolicProduct {
74    /// Creates a new symbolic product computer.
75    pub fn new(algebra: &Algebra) -> Self {
76        let table = ProductTable::new(algebra);
77        Self { table }
78    }
79
80    /// Creates symbolic variables for a type's fields.
81    ///
82    /// Returns a map from field name to symbolic atom.
83    pub fn create_field_symbols(&self, ty: &TypeSpec, prefix: &str) -> HashMap<String, Atom> {
84        ty.fields
85            .iter()
86            .map(|f| {
87                let symbol_name = format!("{}_{}", prefix, f.name);
88                let input = DefaultNamespace {
89                    namespace: Cow::Borrowed(env!("CARGO_CRATE_NAME")),
90                    data: symbol_name.as_str(),
91                    file: Cow::Borrowed(file!()),
92                    line: line!() as usize,
93                };
94                let atom = Atom::parse(input, ParseSettings::symbolica()).unwrap();
95                (f.name.clone(), atom)
96            })
97            .collect()
98    }
99
100    /// Computes the symbolic output of a product.
101    ///
102    /// # Arguments
103    ///
104    /// * `type_a` - Left operand type
105    /// * `type_b` - Right operand type
106    /// * `output_type` - Output type
107    /// * `kind` - Product kind (geometric, outer, etc.)
108    /// * `a_symbols` - Symbolic values for type_a fields
109    /// * `b_symbols` - Symbolic values for type_b fields
110    ///
111    /// # Returns
112    ///
113    /// Symbolic expressions for each output field.
114    pub fn compute(
115        &self,
116        type_a: &TypeSpec,
117        type_b: &TypeSpec,
118        output_type: &TypeSpec,
119        kind: ProductKind,
120        a_symbols: &HashMap<String, Atom>,
121        b_symbols: &HashMap<String, Atom>,
122    ) -> Vec<SymbolicField> {
123        output_type
124            .fields
125            .iter()
126            .map(|output_field| {
127                let expr = self.compute_field(
128                    type_a,
129                    type_b,
130                    output_field.blade_index,
131                    kind,
132                    a_symbols,
133                    b_symbols,
134                );
135                // Apply output field sign for non-canonical blade ordering
136                // (e.g., e31 = -e13, so sign = -1 and we negate the expression)
137                let signed_expr = if output_field.sign < 0 { -expr } else { expr };
138                SymbolicField {
139                    name: output_field.name.clone(),
140                    blade_index: output_field.blade_index,
141                    expression: signed_expr,
142                }
143            })
144            .collect()
145    }
146
147    /// Computes the symbolic output of a sandwich product: v × x × rev(v).
148    ///
149    /// # Arguments
150    ///
151    /// * `versor_type` - Versor type (Motor, Rotor, Flector)
152    /// * `operand_type` - Operand type being transformed
153    /// * `versor_symbols` - Symbolic values for versor fields
154    /// * `operand_symbols` - Symbolic values for operand fields
155    /// * `use_antiproduct` - If true, uses antiproduct and antireverse (for antisandwich)
156    ///
157    /// # Returns
158    ///
159    /// Symbolic expressions for each output field (output has same type as operand).
160    pub fn compute_sandwich(
161        &self,
162        versor_type: &TypeSpec,
163        operand_type: &TypeSpec,
164        versor_symbols: &HashMap<String, Atom>,
165        operand_symbols: &HashMap<String, Atom>,
166        use_antiproduct: bool,
167    ) -> Vec<SymbolicField> {
168        operand_type
169            .fields
170            .iter()
171            .map(|output_field| {
172                let expr = self.compute_sandwich_field(
173                    versor_type,
174                    operand_type,
175                    output_field,
176                    versor_symbols,
177                    operand_symbols,
178                    use_antiproduct,
179                );
180                // Apply output field sign for non-canonical blade ordering
181                let signed_expr = if output_field.sign < 0 { -expr } else { expr };
182                SymbolicField {
183                    name: output_field.name.clone(),
184                    blade_index: output_field.blade_index,
185                    expression: signed_expr,
186                }
187            })
188            .collect()
189    }
190
191    /// Computes the symbolic expression for a single sandwich output field.
192    ///
193    /// Computes: Σ_{i,j,k} v_i × x_j × rev(v_k) for the target blade.
194    fn compute_sandwich_field(
195        &self,
196        versor_type: &TypeSpec,
197        operand_type: &TypeSpec,
198        output_field: &crate::spec::FieldSpec,
199        versor_symbols: &HashMap<String, Atom>,
200        operand_symbols: &HashMap<String, Atom>,
201        use_antiproduct: bool,
202    ) -> Atom {
203        let dim = self.table.dim();
204        let result_blade = output_field.blade_index;
205        let mut terms: Vec<Atom> = Vec::new();
206
207        for field_v1 in &versor_type.fields {
208            for field_x in &operand_type.fields {
209                for field_v2 in &versor_type.fields {
210                    let v1_blade = field_v1.blade_index;
211                    let x_blade = field_x.blade_index;
212                    let v2_blade = field_v2.blade_index;
213
214                    // Compute v_i × x_j (or v_i ⊛ x_j)
215                    let (sign_vx, vx) = if use_antiproduct {
216                        self.table.antiproduct(v1_blade, x_blade)
217                    } else {
218                        self.table.geometric(v1_blade, x_blade)
219                    };
220                    if sign_vx == 0 {
221                        continue;
222                    }
223
224                    // Compute the reverse/antireverse sign for v2
225                    let v2_grade = Blade::from_index(v2_blade).grade();
226                    let rev_sign: i8 = if use_antiproduct {
227                        // Antireverse sign: (-1)^((n-k)(n-k-1)/2)
228                        let antigrade = dim - v2_grade;
229                        if (antigrade * antigrade.saturating_sub(1) / 2).is_multiple_of(2) {
230                            1
231                        } else {
232                            -1
233                        }
234                    } else {
235                        // Reverse sign: (-1)^(k(k-1)/2)
236                        if (v2_grade * v2_grade.saturating_sub(1) / 2).is_multiple_of(2) {
237                            1
238                        } else {
239                            -1
240                        }
241                    };
242
243                    // Compute (v_i × x_j) × rev(v_k) (or (v_i ⊛ x_j) ⊛ antirev(v_k))
244                    let (sign_vxr, result) = if use_antiproduct {
245                        self.table.antiproduct(vx, v2_blade)
246                    } else {
247                        self.table.geometric(vx, v2_blade)
248                    };
249                    if sign_vxr == 0 {
250                        continue;
251                    }
252
253                    if result != result_blade {
254                        continue;
255                    }
256
257                    // Apply input field signs for non-canonical blade orderings.
258                    // v1 and v2 are both from versor_type, x is from operand_type.
259                    let input_sign = field_v1.sign * field_x.sign * field_v2.sign;
260                    let final_sign = sign_vx * sign_vxr * rev_sign * input_sign;
261                    if final_sign == 0 {
262                        continue;
263                    }
264
265                    let v1_sym = versor_symbols.get(&field_v1.name).unwrap();
266                    let x_sym = operand_symbols.get(&field_x.name).unwrap();
267                    let v2_sym = versor_symbols.get(&field_v2.name).unwrap();
268
269                    // Create term: sign * v1 * x * v2
270                    let product = v1_sym * x_sym * v2_sym;
271                    let term = if final_sign > 0 { product } else { -product };
272                    terms.push(term);
273                }
274            }
275        }
276
277        if terms.is_empty() {
278            Atom::num(0)
279        } else {
280            // Sort terms by string representation for deterministic output
281            terms.sort_by_cached_key(|t| t.to_string());
282            terms.into_iter().reduce(|acc, t| acc + t).unwrap()
283        }
284    }
285
286    /// Computes the symbolic expression for a single output field.
287    fn compute_field(
288        &self,
289        type_a: &TypeSpec,
290        type_b: &TypeSpec,
291        result_blade: usize,
292        kind: ProductKind,
293        a_symbols: &HashMap<String, Atom>,
294        b_symbols: &HashMap<String, Atom>,
295    ) -> Atom {
296        let mut terms: Vec<Atom> = Vec::new();
297
298        for field_a in &type_a.fields {
299            for field_b in &type_b.fields {
300                let a_blade = field_a.blade_index;
301                let b_blade = field_b.blade_index;
302
303                // Compute the product based on kind using ProductTable methods
304                let (sign, result) = match kind {
305                    ProductKind::Geometric => self.table.geometric(a_blade, b_blade),
306                    ProductKind::Antigeometric => self.table.antiproduct(a_blade, b_blade),
307                    ProductKind::Wedge => self.table.exterior(a_blade, b_blade),
308                    ProductKind::LeftContraction => self.table.left_contraction(a_blade, b_blade),
309                    ProductKind::RightContraction => self.table.right_contraction(a_blade, b_blade),
310                    ProductKind::Antiwedge => self.table.regressive(a_blade, b_blade),
311                    ProductKind::BulkContraction => self.table.bulk_contraction(a_blade, b_blade),
312                    ProductKind::WeightContraction => {
313                        self.table.weight_contraction(a_blade, b_blade)
314                    }
315                    ProductKind::BulkExpansion => self.table.bulk_expansion(a_blade, b_blade),
316                    ProductKind::WeightExpansion => self.table.weight_expansion(a_blade, b_blade),
317                    ProductKind::Dot => self.table.dot(a_blade, b_blade),
318                    ProductKind::Antidot => self.table.antidot(a_blade, b_blade),
319                    ProductKind::Scalar => {
320                        // Scalar product: grade-0 projection of geometric product
321                        let (s, r) = self.table.geometric(a_blade, b_blade);
322                        // Only include if result is grade 0 (scalar blade index = 0)
323                        let result_grade = Blade::from_index(r).grade();
324                        if result_grade == 0 { (s, r) } else { (0, 0) }
325                    }
326                    ProductKind::Project => self.table.project(a_blade, b_blade),
327                    ProductKind::Antiproject => self.table.antiproject(a_blade, b_blade),
328                };
329
330                if result != result_blade || sign == 0 {
331                    continue;
332                }
333
334                let a_sym = a_symbols.get(&field_a.name).unwrap();
335                let b_sym = b_symbols.get(&field_b.name).unwrap();
336
337                // Apply input field signs for non-canonical blade orderings.
338                // If a field is specified as e.g., e31 (sign = -1), the stored value
339                // represents the negated canonical blade value.
340                let input_sign = sign * field_a.sign * field_b.sign;
341
342                // Create term: total_sign * a_field * b_field
343                let product = a_sym * b_sym;
344                let term = if input_sign > 0 { product } else { -product };
345                terms.push(term);
346            }
347        }
348
349        if terms.is_empty() {
350            // Return symbolic zero
351            Atom::num(0)
352        } else {
353            // Sort terms by string representation for deterministic output
354            terms.sort_by_cached_key(|t| t.to_string());
355            // Sum all terms
356            terms.into_iter().reduce(|acc, t| acc + t).unwrap()
357        }
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::parse_spec;
    use std::sync::Mutex;

    // Symbolica keeps global state, so tests touching it must not overlap.
    // Under nextest, tests whose names start with `symbolica_` are configured
    // to run serially; this mutex gives plain `cargo test` the same guarantee.
    static SYMBOLICA_LOCK: Mutex<()> = Mutex::new(());

    /// Every field of a three-component vector type gets a symbol keyed by
    /// its field name.
    #[test]
    fn symbolica_create_symbols_for_vector() {
        let _guard = SYMBOLICA_LOCK.lock().unwrap();

        let spec = parse_spec(
            r#"
            [algebra]
            name = "test"
            complete = false

            [signature]
            positive = ["e1", "e2", "e3"]

            [types.Vector]
            grades = [1]
            field_map = [
                { name = "x", blade = "e1" },
                { name = "y", blade = "e2" },
                { name = "z", blade = "e3" }
            ]
            "#,
        )
        .unwrap();

        let algebra = Algebra::euclidean(3);
        let product = SymbolicProduct::new(&algebra);
        let vector = spec
            .types
            .iter()
            .find(|candidate| candidate.name == "Vector")
            .unwrap();

        let symbols = product.create_field_symbols(vector, "a");

        for field_name in ["x", "y", "z"] {
            assert!(symbols.contains_key(field_name));
        }
    }
}
405}