clifford_codegen/symbolic/product.rs
1//! Symbolic product computation.
2//!
3//! This module computes the symbolic output of product operations,
4//! representing each output field as a symbolic expression of input fields.
5
6use std::borrow::Cow;
7use std::collections::HashMap;
8
9use symbolica::atom::{Atom, DefaultNamespace};
10use symbolica::parser::ParseSettings;
11
12use crate::algebra::{Algebra, Blade, ProductTable};
13use crate::spec::TypeSpec;
14
/// Selects which product [`SymbolicProduct`] evaluates.
///
/// Names and notation follow the conventions of
/// [Rigid Geometric Algebra](https://rigidgeometricalgebra.org/).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProductKind {
    /// The full geometric product `a × b`.
    Geometric,
    /// The geometric antiproduct, i.e. complement(complement(a) × complement(b)).
    /// Used when composing versors for antisandwich-based transformations.
    Antigeometric,
    /// The wedge (exterior, grade-raising) product `a ∧ b`.
    Wedge,
    /// Left contraction `a ⌋ b`, of grade gb - ga when ga <= gb.
    LeftContraction,
    /// Right contraction `a ⌊ b`, of grade ga - gb when gb <= ga.
    RightContraction,
    /// The antiwedge (regressive / meet) product `a ∨ b`.
    Antiwedge,
    /// Bulk contraction `a ∨ b★`.
    BulkContraction,
    /// Weight contraction `a ∨ b☆`.
    WeightContraction,
    /// Bulk expansion `a ∧ b★`.
    BulkExpansion,
    /// Weight expansion `a ∧ b☆`.
    WeightExpansion,
    /// Dot product `•`: metric inner product of equal-grade blades, yielding a scalar.
    Dot,
    /// Antidot product `⊚`: metric antiproduct inner product of equal-antigrade
    /// blades, yielding a scalar.
    Antidot,
    /// Scalar product: the grade-0 part of the geometric product.
    Scalar,
    /// Projection `target ∨ (self ∧ target☆)`.
    Project,
    /// Antiprojection `target ∧ (self ∨ target☆)`.
    Antiproject,
}
52
53/// A symbolic field expression.
54#[derive(Debug, Clone)]
55pub struct SymbolicField {
56 /// Field name in the output type.
57 pub name: String,
58 /// Blade index this field represents.
59 pub blade_index: usize,
60 /// Symbolic expression for this field's value.
61 pub expression: Atom,
62}
63
64/// Computes symbolic product outputs.
65///
66/// Given input types with symbolic field values, computes the symbolic
67/// expressions for each output field.
68pub struct SymbolicProduct {
69 /// The product table.
70 table: ProductTable,
71}
72
73impl SymbolicProduct {
74 /// Creates a new symbolic product computer.
75 pub fn new(algebra: &Algebra) -> Self {
76 let table = ProductTable::new(algebra);
77 Self { table }
78 }
79
80 /// Creates symbolic variables for a type's fields.
81 ///
82 /// Returns a map from field name to symbolic atom.
83 pub fn create_field_symbols(&self, ty: &TypeSpec, prefix: &str) -> HashMap<String, Atom> {
84 ty.fields
85 .iter()
86 .map(|f| {
87 let symbol_name = format!("{}_{}", prefix, f.name);
88 let input = DefaultNamespace {
89 namespace: Cow::Borrowed(env!("CARGO_CRATE_NAME")),
90 data: symbol_name.as_str(),
91 file: Cow::Borrowed(file!()),
92 line: line!() as usize,
93 };
94 let atom = Atom::parse(input, ParseSettings::symbolica()).unwrap();
95 (f.name.clone(), atom)
96 })
97 .collect()
98 }
99
100 /// Computes the symbolic output of a product.
101 ///
102 /// # Arguments
103 ///
104 /// * `type_a` - Left operand type
105 /// * `type_b` - Right operand type
106 /// * `output_type` - Output type
107 /// * `kind` - Product kind (geometric, outer, etc.)
108 /// * `a_symbols` - Symbolic values for type_a fields
109 /// * `b_symbols` - Symbolic values for type_b fields
110 ///
111 /// # Returns
112 ///
113 /// Symbolic expressions for each output field.
114 pub fn compute(
115 &self,
116 type_a: &TypeSpec,
117 type_b: &TypeSpec,
118 output_type: &TypeSpec,
119 kind: ProductKind,
120 a_symbols: &HashMap<String, Atom>,
121 b_symbols: &HashMap<String, Atom>,
122 ) -> Vec<SymbolicField> {
123 output_type
124 .fields
125 .iter()
126 .map(|output_field| {
127 let expr = self.compute_field(
128 type_a,
129 type_b,
130 output_field.blade_index,
131 kind,
132 a_symbols,
133 b_symbols,
134 );
135 // Apply output field sign for non-canonical blade ordering
136 // (e.g., e31 = -e13, so sign = -1 and we negate the expression)
137 let signed_expr = if output_field.sign < 0 { -expr } else { expr };
138 SymbolicField {
139 name: output_field.name.clone(),
140 blade_index: output_field.blade_index,
141 expression: signed_expr,
142 }
143 })
144 .collect()
145 }
146
147 /// Computes the symbolic output of a sandwich product: v × x × rev(v).
148 ///
149 /// # Arguments
150 ///
151 /// * `versor_type` - Versor type (Motor, Rotor, Flector)
152 /// * `operand_type` - Operand type being transformed
153 /// * `versor_symbols` - Symbolic values for versor fields
154 /// * `operand_symbols` - Symbolic values for operand fields
155 /// * `use_antiproduct` - If true, uses antiproduct and antireverse (for antisandwich)
156 ///
157 /// # Returns
158 ///
159 /// Symbolic expressions for each output field (output has same type as operand).
160 pub fn compute_sandwich(
161 &self,
162 versor_type: &TypeSpec,
163 operand_type: &TypeSpec,
164 versor_symbols: &HashMap<String, Atom>,
165 operand_symbols: &HashMap<String, Atom>,
166 use_antiproduct: bool,
167 ) -> Vec<SymbolicField> {
168 operand_type
169 .fields
170 .iter()
171 .map(|output_field| {
172 let expr = self.compute_sandwich_field(
173 versor_type,
174 operand_type,
175 output_field,
176 versor_symbols,
177 operand_symbols,
178 use_antiproduct,
179 );
180 // Apply output field sign for non-canonical blade ordering
181 let signed_expr = if output_field.sign < 0 { -expr } else { expr };
182 SymbolicField {
183 name: output_field.name.clone(),
184 blade_index: output_field.blade_index,
185 expression: signed_expr,
186 }
187 })
188 .collect()
189 }
190
191 /// Computes the symbolic expression for a single sandwich output field.
192 ///
193 /// Computes: Σ_{i,j,k} v_i × x_j × rev(v_k) for the target blade.
194 fn compute_sandwich_field(
195 &self,
196 versor_type: &TypeSpec,
197 operand_type: &TypeSpec,
198 output_field: &crate::spec::FieldSpec,
199 versor_symbols: &HashMap<String, Atom>,
200 operand_symbols: &HashMap<String, Atom>,
201 use_antiproduct: bool,
202 ) -> Atom {
203 let dim = self.table.dim();
204 let result_blade = output_field.blade_index;
205 let mut terms: Vec<Atom> = Vec::new();
206
207 for field_v1 in &versor_type.fields {
208 for field_x in &operand_type.fields {
209 for field_v2 in &versor_type.fields {
210 let v1_blade = field_v1.blade_index;
211 let x_blade = field_x.blade_index;
212 let v2_blade = field_v2.blade_index;
213
214 // Compute v_i × x_j (or v_i ⊛ x_j)
215 let (sign_vx, vx) = if use_antiproduct {
216 self.table.antiproduct(v1_blade, x_blade)
217 } else {
218 self.table.geometric(v1_blade, x_blade)
219 };
220 if sign_vx == 0 {
221 continue;
222 }
223
224 // Compute the reverse/antireverse sign for v2
225 let v2_grade = Blade::from_index(v2_blade).grade();
226 let rev_sign: i8 = if use_antiproduct {
227 // Antireverse sign: (-1)^((n-k)(n-k-1)/2)
228 let antigrade = dim - v2_grade;
229 if (antigrade * antigrade.saturating_sub(1) / 2).is_multiple_of(2) {
230 1
231 } else {
232 -1
233 }
234 } else {
235 // Reverse sign: (-1)^(k(k-1)/2)
236 if (v2_grade * v2_grade.saturating_sub(1) / 2).is_multiple_of(2) {
237 1
238 } else {
239 -1
240 }
241 };
242
243 // Compute (v_i × x_j) × rev(v_k) (or (v_i ⊛ x_j) ⊛ antirev(v_k))
244 let (sign_vxr, result) = if use_antiproduct {
245 self.table.antiproduct(vx, v2_blade)
246 } else {
247 self.table.geometric(vx, v2_blade)
248 };
249 if sign_vxr == 0 {
250 continue;
251 }
252
253 if result != result_blade {
254 continue;
255 }
256
257 // Apply input field signs for non-canonical blade orderings.
258 // v1 and v2 are both from versor_type, x is from operand_type.
259 let input_sign = field_v1.sign * field_x.sign * field_v2.sign;
260 let final_sign = sign_vx * sign_vxr * rev_sign * input_sign;
261 if final_sign == 0 {
262 continue;
263 }
264
265 let v1_sym = versor_symbols.get(&field_v1.name).unwrap();
266 let x_sym = operand_symbols.get(&field_x.name).unwrap();
267 let v2_sym = versor_symbols.get(&field_v2.name).unwrap();
268
269 // Create term: sign * v1 * x * v2
270 let product = v1_sym * x_sym * v2_sym;
271 let term = if final_sign > 0 { product } else { -product };
272 terms.push(term);
273 }
274 }
275 }
276
277 if terms.is_empty() {
278 Atom::num(0)
279 } else {
280 // Sort terms by string representation for deterministic output
281 terms.sort_by_cached_key(|t| t.to_string());
282 terms.into_iter().reduce(|acc, t| acc + t).unwrap()
283 }
284 }
285
286 /// Computes the symbolic expression for a single output field.
287 fn compute_field(
288 &self,
289 type_a: &TypeSpec,
290 type_b: &TypeSpec,
291 result_blade: usize,
292 kind: ProductKind,
293 a_symbols: &HashMap<String, Atom>,
294 b_symbols: &HashMap<String, Atom>,
295 ) -> Atom {
296 let mut terms: Vec<Atom> = Vec::new();
297
298 for field_a in &type_a.fields {
299 for field_b in &type_b.fields {
300 let a_blade = field_a.blade_index;
301 let b_blade = field_b.blade_index;
302
303 // Compute the product based on kind using ProductTable methods
304 let (sign, result) = match kind {
305 ProductKind::Geometric => self.table.geometric(a_blade, b_blade),
306 ProductKind::Antigeometric => self.table.antiproduct(a_blade, b_blade),
307 ProductKind::Wedge => self.table.exterior(a_blade, b_blade),
308 ProductKind::LeftContraction => self.table.left_contraction(a_blade, b_blade),
309 ProductKind::RightContraction => self.table.right_contraction(a_blade, b_blade),
310 ProductKind::Antiwedge => self.table.regressive(a_blade, b_blade),
311 ProductKind::BulkContraction => self.table.bulk_contraction(a_blade, b_blade),
312 ProductKind::WeightContraction => {
313 self.table.weight_contraction(a_blade, b_blade)
314 }
315 ProductKind::BulkExpansion => self.table.bulk_expansion(a_blade, b_blade),
316 ProductKind::WeightExpansion => self.table.weight_expansion(a_blade, b_blade),
317 ProductKind::Dot => self.table.dot(a_blade, b_blade),
318 ProductKind::Antidot => self.table.antidot(a_blade, b_blade),
319 ProductKind::Scalar => {
320 // Scalar product: grade-0 projection of geometric product
321 let (s, r) = self.table.geometric(a_blade, b_blade);
322 // Only include if result is grade 0 (scalar blade index = 0)
323 let result_grade = Blade::from_index(r).grade();
324 if result_grade == 0 { (s, r) } else { (0, 0) }
325 }
326 ProductKind::Project => self.table.project(a_blade, b_blade),
327 ProductKind::Antiproject => self.table.antiproject(a_blade, b_blade),
328 };
329
330 if result != result_blade || sign == 0 {
331 continue;
332 }
333
334 let a_sym = a_symbols.get(&field_a.name).unwrap();
335 let b_sym = b_symbols.get(&field_b.name).unwrap();
336
337 // Apply input field signs for non-canonical blade orderings.
338 // If a field is specified as e.g., e31 (sign = -1), the stored value
339 // represents the negated canonical blade value.
340 let input_sign = sign * field_a.sign * field_b.sign;
341
342 // Create term: total_sign * a_field * b_field
343 let product = a_sym * b_sym;
344 let term = if input_sign > 0 { product } else { -product };
345 terms.push(term);
346 }
347 }
348
349 if terms.is_empty() {
350 // Return symbolic zero
351 Atom::num(0)
352 } else {
353 // Sort terms by string representation for deterministic output
354 terms.sort_by_cached_key(|t| t.to_string());
355 // Sum all terms
356 terms.into_iter().reduce(|acc, t| acc + t).unwrap()
357 }
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::parse_spec;
    use std::sync::{Mutex, MutexGuard};

    // Symbolica uses global state that conflicts when tests run in parallel.
    // Tests prefixed with `symbolica_` are configured to run serially via nextest.
    // The mutex provides a fallback for `cargo test` users.
    static SYMBOLICA_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the Symbolica serialization lock.
    ///
    /// The mutex only serializes tests and guards no data, so a poisoned lock
    /// (left behind by an earlier failing test) is recovered instead of
    /// cascading a `PoisonError` panic into every later test.
    fn symbolica_guard() -> MutexGuard<'static, ()> {
        SYMBOLICA_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn symbolica_create_symbols_for_vector() {
        let _guard = symbolica_guard();
        let spec = parse_spec(
            r#"
            [algebra]
            name = "test"
            complete = false

            [signature]
            positive = ["e1", "e2", "e3"]

            [types.Vector]
            grades = [1]
            field_map = [
                { name = "x", blade = "e1" },
                { name = "y", blade = "e2" },
                { name = "z", blade = "e3" }
            ]
            "#,
        )
        .unwrap();

        let algebra = Algebra::euclidean(3);
        let product = SymbolicProduct::new(&algebra);
        let vector = spec.types.iter().find(|t| t.name == "Vector").unwrap();

        let symbols = product.create_field_symbols(vector, "a");

        // Exactly one symbol per field, keyed by the unprefixed field name.
        assert_eq!(symbols.len(), 3);
        for field in ["x", "y", "z"] {
            assert!(
                symbols.contains_key(field),
                "missing symbol for field `{field}`"
            );
        }
    }
}
405}