Skip to main content

clifford_codegen/symbolic/
solver.rs

1//! Constraint solver for computing dependent coefficients.
2//!
3//! This module solves constraint equations for a specified variable,
4//! generating the Rust expression to compute that variable from the others.
5//!
6//! Supports two types of constraints:
7//!
8//! 1. **Linear**: `coeff1*var1*var2 + coeff2*var3*var4 + ... = 0`
9//!    Solution is direct division.
10//!
11//! 2. **Quadratic**: `var*var + other_terms = constant`
12//!    Solution requires square root, with domain restrictions.
13
14use thiserror::Error;
15
16/// Errors that can occur during constraint solving.
17#[derive(Debug, Error)]
18pub enum SolveError {
19    /// Failed to parse the constraint expression.
20    #[error("failed to parse constraint: {0}")]
21    ParseError(String),
22
23    /// Missing equality in constraint.
24    #[error("constraint must contain '=' operator: {0}")]
25    MissingEquality(String),
26
27    /// The solve_for variable doesn't appear in the constraint.
28    #[error("variable '{0}' does not appear in constraint")]
29    VariableNotFound(String),
30
31    /// Variable appears in a form that can't be solved.
32    #[error("variable '{0}' appears in a form that cannot be algebraically solved")]
33    UnsolvableForm(String),
34}
35
/// How the solved variable is obtained from the generated expressions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SolutionType {
    /// Direct division: `var = numerator / divisor`.
    Linear,
    /// Square root of the numerator: `var = ±sqrt(numerator)`; only real where
    /// the argument is non-negative.
    Quadratic,
}
44
45/// Result of solving a constraint for a variable.
46#[derive(Debug, Clone)]
47pub struct SolveResult {
48    /// The variable being solved for.
49    pub variable: String,
50    /// The Rust expression for the numerator (or sqrt argument for quadratic).
51    pub numerator: String,
52    /// The Rust expression for the divisor (if division is needed).
53    pub divisor: Option<String>,
54    /// The type of solution (affects code generation).
55    pub solution_type: SolutionType,
56    /// Whether to use positive or negative root (for quadratic).
57    pub positive_root: bool,
58    /// The original constraint.
59    pub constraint: String,
60}
61
/// One signed product in a constraint's left-hand side, e.g. `2*s*e0123`.
#[derive(Clone, Debug)]
struct Term {
    /// Integer factor of the product, including its sign (e.g. `2`, `-2`).
    coefficient: i32,
    /// Variable factors in order of appearance (e.g. `["s", "e0123"]`);
    /// empty for a pure constant.
    variables: Vec<String>,
}
70
/// Stateless solver that isolates one variable of a linear or quadratic
/// integer-coefficient constraint and emits Rust expressions computing it.
#[derive(Default, Debug)]
pub struct ConstraintSolver;
74
75impl ConstraintSolver {
76    /// Creates a new constraint solver.
77    pub fn new() -> Self {
78        Self
79    }
80
81    /// Solves a constraint for the specified variable.
82    ///
83    /// Handles both linear and quadratic constraints:
84    ///
85    /// **Linear**: `2*s*e0123 - 2*e12*e03 + 2*e13*e02 - 2*e23*e01 = 0`
86    /// solving for `e0123` yields: `e0123 = (e12*e03 - e13*e02 + e23*e01) / s`
87    ///
88    /// **Quadratic**: `s*s + e12*e12 + e13*e13 + e23*e23 = 1`
89    /// solving for `s` yields: `s = sqrt(1 - e12*e12 - e13*e13 - e23*e23)`
90    ///
91    /// # Arguments
92    ///
93    /// * `constraint` - The constraint expression string
94    /// * `solve_for` - The variable to solve for
95    /// * `positive_root` - Whether to use positive root for quadratic (default true)
96    ///
97    /// # Returns
98    ///
99    /// A `SolveResult` containing the solution expression and type.
100    pub fn solve(&self, constraint: &str, solve_for: &str) -> Result<SolveResult, SolveError> {
101        self.solve_with_sign(constraint, solve_for, true)
102    }
103
104    /// Solves a constraint with explicit sign convention for quadratic solutions.
105    pub fn solve_with_sign(
106        &self,
107        constraint: &str,
108        solve_for: &str,
109        positive_root: bool,
110    ) -> Result<SolveResult, SolveError> {
111        // Parse constraint into terms and RHS constant
112        let (terms, rhs_constant) = self.parse_constraint(constraint)?;
113
114        // Check if variable appears squared (quadratic)
115        let is_quadratic = self.is_quadratic_in_variable(&terms, solve_for);
116
117        if is_quadratic {
118            self.solve_quadratic(&terms, rhs_constant, solve_for, positive_root, constraint)
119        } else {
120            self.solve_linear(&terms, solve_for, constraint)
121        }
122    }
123
124    /// Checks if a variable appears squared in any term.
125    fn is_quadratic_in_variable(&self, terms: &[Term], var: &str) -> bool {
126        for term in terms {
127            let count = term.variables.iter().filter(|v| *v == var).count();
128            if count >= 2 {
129                return true;
130            }
131        }
132        false
133    }
134
135    /// Solves a linear constraint.
136    fn solve_linear(
137        &self,
138        terms: &[Term],
139        solve_for: &str,
140        constraint: &str,
141    ) -> Result<SolveResult, SolveError> {
142        // Separate terms containing solve_for from the rest
143        let mut solve_for_terms: Vec<Term> = Vec::new();
144        let mut other_terms: Vec<Term> = Vec::new();
145
146        for term in terms {
147            if term.variables.contains(&solve_for.to_string()) {
148                solve_for_terms.push(term.clone());
149            } else {
150                other_terms.push(term.clone());
151            }
152        }
153
154        if solve_for_terms.is_empty() {
155            return Err(SolveError::VariableNotFound(solve_for.to_string()));
156        }
157
158        // Extract the coefficient of solve_for
159        let (divisor_parts, total_coeff) = self.extract_coefficient(&solve_for_terms, solve_for);
160
161        // The solution is: solve_for = -other_terms / coefficient
162        let negated_other: Vec<Term> = other_terms
163            .into_iter()
164            .map(|mut t| {
165                t.coefficient = -t.coefficient;
166                t
167            })
168            .collect();
169
170        // Build numerator expression
171        let numerator = self.build_rust_expression(&negated_other, total_coeff);
172
173        // Build divisor expression (if not a constant)
174        let divisor = if divisor_parts.is_empty() {
175            None
176        } else {
177            Some(divisor_parts.join(" * "))
178        };
179
180        Ok(SolveResult {
181            variable: solve_for.to_string(),
182            numerator,
183            divisor,
184            solution_type: SolutionType::Linear,
185            positive_root: true,
186            constraint: constraint.to_string(),
187        })
188    }
189
190    /// Solves a quadratic constraint of form: coeff*var*var + other_terms = constant
191    fn solve_quadratic(
192        &self,
193        terms: &[Term],
194        rhs_constant: i32,
195        solve_for: &str,
196        positive_root: bool,
197        constraint: &str,
198    ) -> Result<SolveResult, SolveError> {
199        // Find the term with var*var
200        let mut squared_coeff = 0i32;
201        let mut other_terms: Vec<Term> = Vec::new();
202
203        for term in terms {
204            let var_count = term.variables.iter().filter(|v| *v == solve_for).count();
205
206            if var_count == 2 && term.variables.len() == 2 {
207                // Pure var*var term
208                squared_coeff += term.coefficient;
209            } else if var_count == 0 {
210                other_terms.push(term.clone());
211            } else {
212                // Mixed term (var appears but not squared alone) - can't solve simply
213                return Err(SolveError::UnsolvableForm(solve_for.to_string()));
214            }
215        }
216
217        if squared_coeff == 0 {
218            return Err(SolveError::VariableNotFound(solve_for.to_string()));
219        }
220
221        // Solution: var = sqrt((constant - other_terms) / squared_coeff)
222        // For typical unit norm constraints where squared_coeff = 1:
223        // var = sqrt(constant - other_terms)
224
225        // Build the sqrt argument: (constant - other_terms) / squared_coeff
226        // First negate other_terms
227        let negated_other: Vec<Term> = other_terms
228            .into_iter()
229            .map(|mut t| {
230                t.coefficient = -t.coefficient;
231                t
232            })
233            .collect();
234
235        // Build expression for constant + negated_other
236        let mut sqrt_arg = if rhs_constant != 0 {
237            if squared_coeff == 1 {
238                format!("T::from_i8({})", rhs_constant)
239            } else {
240                format!(
241                    "T::from_i8({}) / T::from_i8({})",
242                    rhs_constant, squared_coeff
243                )
244            }
245        } else {
246            String::new()
247        };
248
249        // Add negated other terms
250        for term in &negated_other {
251            let term_expr = self.term_to_rust_expression(term, squared_coeff);
252            if sqrt_arg.is_empty() {
253                sqrt_arg = term_expr;
254            } else if let Some(stripped) = term_expr.strip_prefix('-') {
255                sqrt_arg = format!("{} - {}", sqrt_arg, stripped);
256            } else {
257                sqrt_arg = format!("{} + {}", sqrt_arg, term_expr);
258            }
259        }
260
261        if sqrt_arg.is_empty() {
262            sqrt_arg = "T::zero()".to_string();
263        }
264
265        Ok(SolveResult {
266            variable: solve_for.to_string(),
267            numerator: sqrt_arg,
268            divisor: None,
269            solution_type: SolutionType::Quadratic,
270            positive_root,
271            constraint: constraint.to_string(),
272        })
273    }
274
275    /// Converts a single term to a Rust expression.
276    fn term_to_rust_expression(&self, term: &Term, divisor_coeff: i32) -> String {
277        let simplified_coeff = if divisor_coeff != 0 && term.coefficient % divisor_coeff == 0 {
278            term.coefficient / divisor_coeff
279        } else {
280            term.coefficient
281        };
282
283        let vars_expr = if term.variables.is_empty() {
284            format!("T::from_i8({})", simplified_coeff)
285        } else {
286            term.variables.join(" * ")
287        };
288
289        match simplified_coeff {
290            1 => vars_expr,
291            -1 => format!("-{}", vars_expr),
292            _ if term.variables.is_empty() => format!("T::from_i8({})", simplified_coeff),
293            _ => format!("T::from_i8({}) * {}", simplified_coeff, vars_expr),
294        }
295    }
296
297    /// Parse a constraint string into terms and RHS constant.
298    ///
299    /// Input: "2*s*e0123 - 2*e12*e03 = 0" or "s*s + b*b = 1"
300    /// Output: (Vec of Terms, RHS constant)
301    fn parse_constraint(&self, constraint: &str) -> Result<(Vec<Term>, i32), SolveError> {
302        // Split on '='
303        let parts: Vec<&str> = constraint.split('=').collect();
304        if parts.len() != 2 {
305            return Err(SolveError::MissingEquality(constraint.to_string()));
306        }
307
308        let lhs = parts[0].trim();
309        let rhs = parts[1].trim();
310
311        // Parse RHS as a constant
312        let rhs_constant: i32 = rhs.parse().map_err(|_| {
313            SolveError::ParseError(format!("RHS must be an integer constant, got '{}'", rhs))
314        })?;
315
316        let terms = self.parse_expression(lhs)?;
317        Ok((terms, rhs_constant))
318    }
319
320    /// Parse an expression into terms.
321    fn parse_expression(&self, expr: &str) -> Result<Vec<Term>, SolveError> {
322        let mut terms = Vec::new();
323        let mut current_term = String::new();
324        let mut sign = 1;
325
326        // Normalize: remove spaces around operators
327        let normalized = expr.replace(" ", "");
328
329        let chars: Vec<char> = normalized.chars().collect();
330        let mut i = 0;
331
332        while i < chars.len() {
333            let c = chars[i];
334
335            match c {
336                '+' => {
337                    if !current_term.is_empty() {
338                        terms.push(self.parse_term(&current_term, sign)?);
339                        current_term.clear();
340                    }
341                    sign = 1;
342                }
343                '-' => {
344                    if !current_term.is_empty() {
345                        terms.push(self.parse_term(&current_term, sign)?);
346                        current_term.clear();
347                    }
348                    sign = -1;
349                }
350                _ => {
351                    current_term.push(c);
352                }
353            }
354            i += 1;
355        }
356
357        // Don't forget the last term
358        if !current_term.is_empty() {
359            terms.push(self.parse_term(&current_term, sign)?);
360        }
361
362        Ok(terms)
363    }
364
365    /// Parse a single term like "2*s*e0123" into a Term.
366    fn parse_term(&self, term: &str, sign: i32) -> Result<Term, SolveError> {
367        let factors: Vec<&str> = term.split('*').collect();
368
369        let mut coefficient = sign;
370        let mut variables = Vec::new();
371
372        for factor in factors {
373            let factor = factor.trim();
374            if factor.is_empty() {
375                continue;
376            }
377
378            // Check if it's a number
379            if let Ok(num) = factor.parse::<i32>() {
380                coefficient *= num;
381            } else {
382                variables.push(factor.to_string());
383            }
384        }
385
386        Ok(Term {
387            coefficient,
388            variables,
389        })
390    }
391
392    /// Extract the coefficient (other variables) of the solve_for variable.
393    ///
394    /// For terms like [2*s*e0123], extracting e0123's coefficient gives:
395    /// - divisor_parts: ["s"]
396    /// - total_coeff: 2
397    fn extract_coefficient(&self, terms: &[Term], solve_for: &str) -> (Vec<String>, i32) {
398        // For now, assume there's only one term containing solve_for
399        // (our constraints are simple enough for this)
400        if let Some(term) = terms.first() {
401            let other_vars: Vec<String> = term
402                .variables
403                .iter()
404                .filter(|v| *v != solve_for)
405                .cloned()
406                .collect();
407
408            (other_vars, term.coefficient)
409        } else {
410            (Vec::new(), 1)
411        }
412    }
413
414    /// Build a Rust expression from terms.
415    ///
416    /// Takes into account the coefficient divisor to simplify expressions.
417    fn build_rust_expression(&self, terms: &[Term], divisor_coeff: i32) -> String {
418        if terms.is_empty() {
419            return "T::zero()".to_string();
420        }
421
422        let parts: Vec<String> = terms
423            .iter()
424            .map(|term| {
425                // Simplify coefficient if divisor matches
426                let simplified_coeff =
427                    if divisor_coeff != 0 && term.coefficient % divisor_coeff == 0 {
428                        term.coefficient / divisor_coeff
429                    } else {
430                        term.coefficient
431                    };
432
433                let vars_expr = if term.variables.is_empty() {
434                    format!("T::from_i8({})", simplified_coeff)
435                } else {
436                    term.variables.join(" * ")
437                };
438
439                match simplified_coeff {
440                    1 => vars_expr,
441                    -1 => format!("-{}", vars_expr),
442                    _ if term.variables.is_empty() => format!("T::from_i8({})", simplified_coeff),
443                    _ => format!("T::from_i8({}) * {}", simplified_coeff, vars_expr),
444                }
445            })
446            .collect();
447
448        // Join with proper operators
449        let mut result = String::new();
450        for (i, part) in parts.iter().enumerate() {
451            if i == 0 {
452                result.push_str(part);
453            } else if let Some(stripped) = part.strip_prefix('-') {
454                result.push_str(" - ");
455                result.push_str(stripped);
456            } else {
457                result.push_str(" + ");
458                result.push_str(part);
459            }
460        }
461
462        result
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh solver for each test; `ConstraintSolver` carries no state.
    fn solver() -> ConstraintSolver {
        ConstraintSolver::new()
    }

    #[test]
    fn test_parse_term() {
        let with_coefficient = solver().parse_term("2*s*e0123", 1).unwrap();
        assert_eq!(with_coefficient.coefficient, 2);
        assert_eq!(with_coefficient.variables, vec!["s", "e0123"]);

        let sign_only = solver().parse_term("e12*e03", -1).unwrap();
        assert_eq!(sign_only.coefficient, -1);
        assert_eq!(sign_only.variables, vec!["e12", "e03"]);
    }

    #[test]
    fn test_parse_expression() {
        let terms = solver()
            .parse_expression("2*s*e0123 - 2*e12*e03 + 2*e13*e02")
            .unwrap();
        let coefficients: Vec<i32> = terms.iter().map(|term| term.coefficient).collect();
        assert_eq!(coefficients, [2, -2, 2]);
    }

    #[test]
    fn test_solve_motor_constraint() {
        let result = solver()
            .solve("2*s*e0123 - 2*e12*e03 + 2*e13*e02 - 2*e23*e01 = 0", "e0123")
            .unwrap();

        assert_eq!(result.variable, "e0123");
        assert_eq!(result.divisor.as_deref(), Some("s"));
        assert_eq!(result.solution_type, SolutionType::Linear);
        // Expected numerator: e12*e03 - e13*e02 + e23*e01
        for product in ["e12 * e03", "e13 * e02", "e23 * e01"] {
            assert!(result.numerator.contains(product));
        }
    }

    #[test]
    fn test_solve_bivector_constraint() {
        let result = solver()
            .solve("-2*e12*e03 + 2*e13*e02 - 2*e23*e01 = 0", "e03")
            .unwrap();

        assert_eq!(result.variable, "e03");
        assert_eq!(result.divisor.as_deref(), Some("e12"));
        assert_eq!(result.solution_type, SolutionType::Linear);
    }

    #[test]
    fn test_solve_quadratic_unit_norm() {
        // Unit norm constraint: s*s + e12*e12 + e13*e13 + e23*e23 = 1
        let result = solver()
            .solve("s*s + e12*e12 + e13*e13 + e23*e23 = 1", "s")
            .unwrap();

        assert_eq!(result.variable, "s");
        assert_eq!(result.solution_type, SolutionType::Quadratic);
        assert!(result.positive_root);
        assert_eq!(result.divisor, None);
        // Expected sqrt argument: 1 - e12*e12 - e13*e13 - e23*e23
        for fragment in ["T::from_i8(1)", "e12 * e12", "e13 * e13", "e23 * e23"] {
            assert!(result.numerator.contains(fragment));
        }
    }

    #[test]
    fn test_solve_quadratic_negative_root() {
        let result = solver().solve_with_sign("a*a + b*b = 1", "a", false).unwrap();

        assert_eq!(result.variable, "a");
        assert_eq!(result.solution_type, SolutionType::Quadratic);
        assert!(!result.positive_root);
    }

    #[test]
    fn test_solve_simple_quadratic() {
        // Simplest case: x*x = 1
        let result = solver().solve("x*x = 1", "x").unwrap();

        assert_eq!(result.variable, "x");
        assert_eq!(result.solution_type, SolutionType::Quadratic);
        assert_eq!(result.numerator, "T::from_i8(1)");
    }

    #[test]
    fn test_variable_not_found() {
        let outcome = solver().solve("2*s*e0123 = 0", "nonexistent");
        assert!(matches!(outcome, Err(SolveError::VariableNotFound(_))));
    }
}