clifford_codegen/symbolic/constraint_simplify.rs
1//! Constraint-based expression simplification.
2//!
3//! This module applies type constraints to simplify symbolic expressions.
4//! For example, if a Rotor satisfies `s*s + xy*xy + xz*xz + yz*yz = 1`,
5//! any occurrence of that pattern in an expression can be replaced with `1`.
6
7use symbolica::atom::{Atom, AtomCore};
8
9use crate::spec::TypeSpec;
10
11/// Simplifies expressions using type constraints.
12///
13/// This simplifier recognizes patterns from type constraints and substitutes
14/// them with the constraint's constant value.
15///
16/// # Example
17///
18/// For a Rotor with constraint `s*s + xy*xy + xz*xz + yz*yz = 1`:
19/// - Pattern: `a_s*a_s + a_xy*a_xy + a_xz*a_xz + a_yz*a_yz`
20/// - Substituted with: `1`
21pub struct ConstraintSimplifier {
22 /// Map from pattern atom to replacement value.
23 /// Key: the expanded constraint LHS as Atom
24 /// Value: the RHS constant as Atom
25 substitutions: Vec<(Atom, Atom)>,
26}
27
28impl ConstraintSimplifier {
29 /// Creates a new constraint simplifier for the given types.
30 ///
31 /// # Arguments
32 ///
33 /// * `types` - The types being operated on
34 /// * `prefixes` - The variable prefixes used (e.g., "a", "b")
35 ///
36 /// Note: Constraints are now auto-derived during code generation rather than
37 /// stored in TypeSpec. This simplifier currently returns an empty instance.
38 /// Constraint-based simplification can be reimplemented using ConstraintDeriver
39 /// if needed.
40 #[allow(unused_variables)]
41 pub fn new(types: &[&TypeSpec], prefixes: &[&str]) -> Self {
42 // Constraints are now auto-derived, not stored in TypeSpec.
43 // Return empty simplifier for now.
44 Self {
45 substitutions: Vec::new(),
46 }
47 }
48
49 /// Applies constraint substitutions to an expression.
50 ///
51 /// Looks for subexpressions matching constraint patterns and replaces them.
52 pub fn apply(&self, expr: &Atom) -> Atom {
53 if self.substitutions.is_empty() {
54 return expr.clone();
55 }
56
57 // Expand the expression to canonical form for matching
58 let mut result = expr.expand();
59
60 // Try each substitution
61 for (pattern, value) in &self.substitutions {
62 result = Self::substitute_pattern(&result, pattern, value);
63 }
64
65 result
66 }
67
68 /// Substitutes a pattern in an expression.
69 ///
70 /// This is a simple approach that checks if the expression contains
71 /// the pattern as a subexpression.
72 fn substitute_pattern(expr: &Atom, pattern: &Atom, value: &Atom) -> Atom {
73 // Check if expr equals pattern directly
74 if Self::atoms_equal(expr, pattern) {
75 return value.clone();
76 }
77
78 // Check if expr contains pattern as a subexpression in a sum
79 // expr = pattern + rest => expr = value + rest
80 if let Some(result) = Self::try_substitute_in_sum(expr, pattern, value) {
81 return result;
82 }
83
84 // For more complex cases, we could use Symbolica's pattern matching
85 // but for now, return expr unchanged
86 expr.clone()
87 }
88
89 /// Checks if two atoms are equal.
90 ///
91 /// Uses Symbolica's native `PartialEq` implementation on expanded forms.
92 fn atoms_equal(a: &Atom, b: &Atom) -> bool {
93 // Expand both to canonical form for comparison
94 let a_expanded = a.expand();
95 let b_expanded = b.expand();
96
97 // Use Symbolica's PartialEq implementation directly
98 a_expanded == b_expanded
99 }
100
101 /// Tries to substitute a pattern within a sum expression.
102 ///
103 /// If expr = pattern + rest, returns value + rest.
104 fn try_substitute_in_sum(expr: &Atom, pattern: &Atom, value: &Atom) -> Option<Atom> {
105 // Compute expr - pattern
106 let difference = expr - pattern;
107 let simplified_diff = difference.expand();
108 let expanded_expr = expr.expand();
109
110 // Count terms in the expanded expressions
111 let expr_terms = Self::count_terms(&expanded_expr);
112 let diff_terms = Self::count_terms(&simplified_diff);
113 let pattern_terms = Self::count_terms(&pattern.expand());
114
115 // If the pattern was present, subtracting it should reduce terms
116 // The difference should have fewer terms than the original
117 // (approximately: expr_terms - pattern_terms + possible cancellations)
118 if diff_terms < expr_terms || Self::is_simpler(&simplified_diff, &expanded_expr) {
119 // Check that we removed approximately the right number of terms
120 // Allow for some cancellation effects
121 if expr_terms.saturating_sub(diff_terms) > 0 || diff_terms + pattern_terms > expr_terms
122 {
123 // Pattern was found and removed, add value back
124 let result = &simplified_diff + value;
125 return Some(result.expand());
126 }
127 }
128
129 None
130 }
131
132 /// Counts the number of terms in an atom.
133 ///
134 /// For Add expressions, returns the number of addends.
135 /// For other atoms, returns 1.
136 fn count_terms(atom: &Atom) -> usize {
137 atom.as_add_view().map(|add| add.get_nargs()).unwrap_or(1)
138 }
139
140 /// Checks if expr_a is simpler than expr_b.
141 ///
142 /// Uses term count as the complexity metric.
143 fn is_simpler(a: &Atom, b: &Atom) -> bool {
144 Self::count_terms(a) < Self::count_terms(b)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::borrow::Cow;
    use std::sync::Mutex;
    use symbolica::atom::DefaultNamespace;
    use symbolica::parser::ParseSettings;

    // Symbolica uses global state that conflicts when tests run in parallel.
    // Tests prefixed with `symbolica_` are configured to run serially via nextest.
    // The mutex provides a fallback for `cargo test` users.
    static SYMBOLICA_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the serialization lock, recovering from poisoning so that one
    /// failing test does not cascade into failures of every later test.
    fn lock_symbolica() -> std::sync::MutexGuard<'static, ()> {
        SYMBOLICA_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    fn parse_atom(s: &str) -> Atom {
        let input = DefaultNamespace {
            namespace: Cow::Borrowed(env!("CARGO_CRATE_NAME")),
            data: s,
            file: Cow::Borrowed(file!()),
            line: line!() as usize,
        };
        Atom::parse(input, ParseSettings::symbolica()).unwrap()
    }

    /// Builds a simplifier with a single `pattern -> value` substitution,
    /// bypassing `new`, which currently never registers any.
    fn simplifier_with(pattern: &str, value: &str) -> ConstraintSimplifier {
        ConstraintSimplifier {
            substitutions: vec![(parse_atom(pattern), parse_atom(value))],
        }
    }

    #[test]
    fn symbolica_empty_simplifier_returns_unchanged() {
        let _guard = lock_symbolica();
        let simplifier = ConstraintSimplifier::new(&[], &[]);
        let expr = parse_atom("a + b");
        let result = simplifier.apply(&expr);
        assert_eq!(result, expr);
    }

    #[test]
    fn symbolica_exact_pattern_is_replaced() {
        let _guard = lock_symbolica();
        let simplifier = simplifier_with("x^2 + y^2", "1");
        let result = simplifier.apply(&parse_atom("x^2 + y^2"));
        assert_eq!(result, parse_atom("1"));
    }

    #[test]
    fn symbolica_pattern_inside_sum_is_replaced() {
        let _guard = lock_symbolica();
        let simplifier = simplifier_with("x^2 + y^2", "1");
        let result = simplifier.apply(&parse_atom("x^2 + y^2 + z"));
        assert_eq!(result, parse_atom("z + 1"));
    }

    #[test]
    fn symbolica_unrelated_expression_is_unchanged() {
        let _guard = lock_symbolica();
        let simplifier = simplifier_with("x^2 + y^2", "1");
        let result = simplifier.apply(&parse_atom("a + b"));
        assert_eq!(result, parse_atom("a + b"));
    }
}