clifford_codegen/symbolic/
solver.rs1use thiserror::Error;
15
/// Errors reported by the constraint solver.
///
/// Every variant carries a `String` payload that is interpolated into the
/// `Display` message generated by `thiserror`.
#[derive(Debug, Error)]
pub enum SolveError {
    /// The constraint text could not be parsed; the payload describes what
    /// was wrong (for example a right-hand side that is not an integer).
    #[error("failed to parse constraint: {0}")]
    ParseError(String),

    /// The constraint did not split into exactly one left-hand side and one
    /// right-hand side around `=`; the payload is the offending constraint.
    #[error("constraint must contain '=' operator: {0}")]
    MissingEquality(String),

    /// The variable to solve for does not occur in the constraint (also
    /// reported when its squared terms cancel out); the payload is the
    /// variable name.
    #[error("variable '{0}' does not appear in constraint")]
    VariableNotFound(String),

    /// The variable occurs in a shape the solver cannot isolate, such as a
    /// squared occurrence mixed with other factors; the payload is the
    /// variable name.
    #[error("variable '{0}' appears in a form that cannot be algebraically solved")]
    UnsolvableForm(String),
}
35
/// Which algebraic shape a [`SolveResult`] was derived from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolutionType {
    /// The variable occurred at most once per term; the result is expressed
    /// through `numerator` and an optional `divisor`.
    Linear,
    /// The variable occurred squared; `numerator` holds the expression the
    /// solver built as the square-root argument.
    Quadratic,
}
44
/// Outcome of solving a constraint for one variable.
///
/// The expression fields are Rust source snippets (they reference names such
/// as `T::from_i8` and `T::zero()`), not evaluated numbers.
#[derive(Debug, Clone)]
pub struct SolveResult {
    /// Name of the variable that was solved for.
    pub variable: String,
    /// Generated expression: the numerator for a linear solution, or the
    /// square-root argument for a quadratic one.
    pub numerator: String,
    /// Product of the variables that multiplied the solved variable, joined
    /// with `" * "`; `None` when there were none (always `None` for
    /// quadratic solutions).
    pub divisor: Option<String>,
    /// Whether the linear or the quadratic path produced this result.
    pub solution_type: SolutionType,
    /// Root selection requested by the caller for quadratic solutions;
    /// always `true` for linear ones.
    // NOTE(review): presumably `true` means "+sqrt" in the consumer — confirm there.
    pub positive_root: bool,
    /// The original constraint text, kept verbatim.
    pub constraint: String,
}
61
/// One additive term of a parsed expression: an integer coefficient times a
/// product of named variables.
#[derive(Debug, Clone)]
struct Term {
    /// Signed integer coefficient; the sign of the term and every numeric
    /// factor are folded into it.
    coefficient: i32,
    /// Names of the non-numeric factors in source order. A name that appears
    /// twice represents a squared variable.
    variables: Vec<String>,
}
70
/// Stateless solver that rearranges a textual constraint of the form
/// `<sum of products> = <integer>` for one of its variables.
#[derive(Debug, Default)]
pub struct ConstraintSolver;
74
75impl ConstraintSolver {
76 pub fn new() -> Self {
78 Self
79 }
80
81 pub fn solve(&self, constraint: &str, solve_for: &str) -> Result<SolveResult, SolveError> {
101 self.solve_with_sign(constraint, solve_for, true)
102 }
103
104 pub fn solve_with_sign(
106 &self,
107 constraint: &str,
108 solve_for: &str,
109 positive_root: bool,
110 ) -> Result<SolveResult, SolveError> {
111 let (terms, rhs_constant) = self.parse_constraint(constraint)?;
113
114 let is_quadratic = self.is_quadratic_in_variable(&terms, solve_for);
116
117 if is_quadratic {
118 self.solve_quadratic(&terms, rhs_constant, solve_for, positive_root, constraint)
119 } else {
120 self.solve_linear(&terms, solve_for, constraint)
121 }
122 }
123
124 fn is_quadratic_in_variable(&self, terms: &[Term], var: &str) -> bool {
126 for term in terms {
127 let count = term.variables.iter().filter(|v| *v == var).count();
128 if count >= 2 {
129 return true;
130 }
131 }
132 false
133 }
134
135 fn solve_linear(
137 &self,
138 terms: &[Term],
139 solve_for: &str,
140 constraint: &str,
141 ) -> Result<SolveResult, SolveError> {
142 let mut solve_for_terms: Vec<Term> = Vec::new();
144 let mut other_terms: Vec<Term> = Vec::new();
145
146 for term in terms {
147 if term.variables.contains(&solve_for.to_string()) {
148 solve_for_terms.push(term.clone());
149 } else {
150 other_terms.push(term.clone());
151 }
152 }
153
154 if solve_for_terms.is_empty() {
155 return Err(SolveError::VariableNotFound(solve_for.to_string()));
156 }
157
158 let (divisor_parts, total_coeff) = self.extract_coefficient(&solve_for_terms, solve_for);
160
161 let negated_other: Vec<Term> = other_terms
163 .into_iter()
164 .map(|mut t| {
165 t.coefficient = -t.coefficient;
166 t
167 })
168 .collect();
169
170 let numerator = self.build_rust_expression(&negated_other, total_coeff);
172
173 let divisor = if divisor_parts.is_empty() {
175 None
176 } else {
177 Some(divisor_parts.join(" * "))
178 };
179
180 Ok(SolveResult {
181 variable: solve_for.to_string(),
182 numerator,
183 divisor,
184 solution_type: SolutionType::Linear,
185 positive_root: true,
186 constraint: constraint.to_string(),
187 })
188 }
189
190 fn solve_quadratic(
192 &self,
193 terms: &[Term],
194 rhs_constant: i32,
195 solve_for: &str,
196 positive_root: bool,
197 constraint: &str,
198 ) -> Result<SolveResult, SolveError> {
199 let mut squared_coeff = 0i32;
201 let mut other_terms: Vec<Term> = Vec::new();
202
203 for term in terms {
204 let var_count = term.variables.iter().filter(|v| *v == solve_for).count();
205
206 if var_count == 2 && term.variables.len() == 2 {
207 squared_coeff += term.coefficient;
209 } else if var_count == 0 {
210 other_terms.push(term.clone());
211 } else {
212 return Err(SolveError::UnsolvableForm(solve_for.to_string()));
214 }
215 }
216
217 if squared_coeff == 0 {
218 return Err(SolveError::VariableNotFound(solve_for.to_string()));
219 }
220
221 let negated_other: Vec<Term> = other_terms
228 .into_iter()
229 .map(|mut t| {
230 t.coefficient = -t.coefficient;
231 t
232 })
233 .collect();
234
235 let mut sqrt_arg = if rhs_constant != 0 {
237 if squared_coeff == 1 {
238 format!("T::from_i8({})", rhs_constant)
239 } else {
240 format!(
241 "T::from_i8({}) / T::from_i8({})",
242 rhs_constant, squared_coeff
243 )
244 }
245 } else {
246 String::new()
247 };
248
249 for term in &negated_other {
251 let term_expr = self.term_to_rust_expression(term, squared_coeff);
252 if sqrt_arg.is_empty() {
253 sqrt_arg = term_expr;
254 } else if let Some(stripped) = term_expr.strip_prefix('-') {
255 sqrt_arg = format!("{} - {}", sqrt_arg, stripped);
256 } else {
257 sqrt_arg = format!("{} + {}", sqrt_arg, term_expr);
258 }
259 }
260
261 if sqrt_arg.is_empty() {
262 sqrt_arg = "T::zero()".to_string();
263 }
264
265 Ok(SolveResult {
266 variable: solve_for.to_string(),
267 numerator: sqrt_arg,
268 divisor: None,
269 solution_type: SolutionType::Quadratic,
270 positive_root,
271 constraint: constraint.to_string(),
272 })
273 }
274
275 fn term_to_rust_expression(&self, term: &Term, divisor_coeff: i32) -> String {
277 let simplified_coeff = if divisor_coeff != 0 && term.coefficient % divisor_coeff == 0 {
278 term.coefficient / divisor_coeff
279 } else {
280 term.coefficient
281 };
282
283 let vars_expr = if term.variables.is_empty() {
284 format!("T::from_i8({})", simplified_coeff)
285 } else {
286 term.variables.join(" * ")
287 };
288
289 match simplified_coeff {
290 1 => vars_expr,
291 -1 => format!("-{}", vars_expr),
292 _ if term.variables.is_empty() => format!("T::from_i8({})", simplified_coeff),
293 _ => format!("T::from_i8({}) * {}", simplified_coeff, vars_expr),
294 }
295 }
296
297 fn parse_constraint(&self, constraint: &str) -> Result<(Vec<Term>, i32), SolveError> {
302 let parts: Vec<&str> = constraint.split('=').collect();
304 if parts.len() != 2 {
305 return Err(SolveError::MissingEquality(constraint.to_string()));
306 }
307
308 let lhs = parts[0].trim();
309 let rhs = parts[1].trim();
310
311 let rhs_constant: i32 = rhs.parse().map_err(|_| {
313 SolveError::ParseError(format!("RHS must be an integer constant, got '{}'", rhs))
314 })?;
315
316 let terms = self.parse_expression(lhs)?;
317 Ok((terms, rhs_constant))
318 }
319
320 fn parse_expression(&self, expr: &str) -> Result<Vec<Term>, SolveError> {
322 let mut terms = Vec::new();
323 let mut current_term = String::new();
324 let mut sign = 1;
325
326 let normalized = expr.replace(" ", "");
328
329 let chars: Vec<char> = normalized.chars().collect();
330 let mut i = 0;
331
332 while i < chars.len() {
333 let c = chars[i];
334
335 match c {
336 '+' => {
337 if !current_term.is_empty() {
338 terms.push(self.parse_term(¤t_term, sign)?);
339 current_term.clear();
340 }
341 sign = 1;
342 }
343 '-' => {
344 if !current_term.is_empty() {
345 terms.push(self.parse_term(¤t_term, sign)?);
346 current_term.clear();
347 }
348 sign = -1;
349 }
350 _ => {
351 current_term.push(c);
352 }
353 }
354 i += 1;
355 }
356
357 if !current_term.is_empty() {
359 terms.push(self.parse_term(¤t_term, sign)?);
360 }
361
362 Ok(terms)
363 }
364
365 fn parse_term(&self, term: &str, sign: i32) -> Result<Term, SolveError> {
367 let factors: Vec<&str> = term.split('*').collect();
368
369 let mut coefficient = sign;
370 let mut variables = Vec::new();
371
372 for factor in factors {
373 let factor = factor.trim();
374 if factor.is_empty() {
375 continue;
376 }
377
378 if let Ok(num) = factor.parse::<i32>() {
380 coefficient *= num;
381 } else {
382 variables.push(factor.to_string());
383 }
384 }
385
386 Ok(Term {
387 coefficient,
388 variables,
389 })
390 }
391
392 fn extract_coefficient(&self, terms: &[Term], solve_for: &str) -> (Vec<String>, i32) {
398 if let Some(term) = terms.first() {
401 let other_vars: Vec<String> = term
402 .variables
403 .iter()
404 .filter(|v| *v != solve_for)
405 .cloned()
406 .collect();
407
408 (other_vars, term.coefficient)
409 } else {
410 (Vec::new(), 1)
411 }
412 }
413
414 fn build_rust_expression(&self, terms: &[Term], divisor_coeff: i32) -> String {
418 if terms.is_empty() {
419 return "T::zero()".to_string();
420 }
421
422 let parts: Vec<String> = terms
423 .iter()
424 .map(|term| {
425 let simplified_coeff =
427 if divisor_coeff != 0 && term.coefficient % divisor_coeff == 0 {
428 term.coefficient / divisor_coeff
429 } else {
430 term.coefficient
431 };
432
433 let vars_expr = if term.variables.is_empty() {
434 format!("T::from_i8({})", simplified_coeff)
435 } else {
436 term.variables.join(" * ")
437 };
438
439 match simplified_coeff {
440 1 => vars_expr,
441 -1 => format!("-{}", vars_expr),
442 _ if term.variables.is_empty() => format!("T::from_i8({})", simplified_coeff),
443 _ => format!("T::from_i8({}) * {}", simplified_coeff, vars_expr),
444 }
445 })
446 .collect();
447
448 let mut result = String::new();
450 for (i, part) in parts.iter().enumerate() {
451 if i == 0 {
452 result.push_str(part);
453 } else if let Some(stripped) = part.strip_prefix('-') {
454 result.push_str(" - ");
455 result.push_str(stripped);
456 } else {
457 result.push_str(" + ");
458 result.push_str(part);
459 }
460 }
461
462 result
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Numeric factors fold into the coefficient; the rest become variables.
    #[test]
    fn test_parse_term() {
        let solver = ConstraintSolver::new();

        let with_number = solver.parse_term("2*s*e0123", 1).unwrap();
        assert_eq!(with_number.coefficient, 2);
        assert_eq!(with_number.variables, vec!["s", "e0123"]);

        let negated = solver.parse_term("e12*e03", -1).unwrap();
        assert_eq!(negated.coefficient, -1);
        assert_eq!(negated.variables, vec!["e12", "e03"]);
    }

    /// `+` / `-` split the expression into signed terms.
    #[test]
    fn test_parse_expression() {
        let coefficients: Vec<i32> = ConstraintSolver::new()
            .parse_expression("2*s*e0123 - 2*e12*e03 + 2*e13*e02")
            .unwrap()
            .iter()
            .map(|term| term.coefficient)
            .collect();

        assert_eq!(coefficients, vec![2, -2, 2]);
    }

    /// Solving the motor constraint for `e0123` divides by `s`.
    #[test]
    fn test_solve_motor_constraint() {
        let solution = ConstraintSolver::new()
            .solve("2*s*e0123 - 2*e12*e03 + 2*e13*e02 - 2*e23*e01 = 0", "e0123")
            .unwrap();

        assert_eq!(solution.variable, "e0123");
        assert_eq!(solution.divisor.as_deref(), Some("s"));
        assert_eq!(solution.solution_type, SolutionType::Linear);
        for product in ["e12 * e03", "e13 * e02", "e23 * e01"].iter() {
            assert!(solution.numerator.contains(*product));
        }
    }

    /// Solving the bivector constraint for `e03` divides by `e12`.
    #[test]
    fn test_solve_bivector_constraint() {
        let solution = ConstraintSolver::new()
            .solve("-2*e12*e03 + 2*e13*e02 - 2*e23*e01 = 0", "e03")
            .unwrap();

        assert_eq!(solution.variable, "e03");
        assert_eq!(solution.divisor.as_deref(), Some("e12"));
        assert_eq!(solution.solution_type, SolutionType::Linear);
    }

    /// A unit-norm constraint yields a quadratic solution with the positive root.
    #[test]
    fn test_solve_quadratic_unit_norm() {
        let solution = ConstraintSolver::new()
            .solve("s*s + e12*e12 + e13*e13 + e23*e23 = 1", "s")
            .unwrap();

        assert_eq!(solution.variable, "s");
        assert_eq!(solution.solution_type, SolutionType::Quadratic);
        assert!(solution.positive_root);
        assert!(solution.divisor.is_none());
        for fragment in ["T::from_i8(1)", "e12 * e12", "e13 * e13", "e23 * e23"].iter() {
            assert!(solution.numerator.contains(*fragment));
        }
    }

    /// The requested root sign is carried through to the result.
    #[test]
    fn test_solve_quadratic_negative_root() {
        let solution = ConstraintSolver::new()
            .solve_with_sign("a*a + b*b = 1", "a", false)
            .unwrap();

        assert_eq!(solution.variable, "a");
        assert_eq!(solution.solution_type, SolutionType::Quadratic);
        assert!(!solution.positive_root);
    }

    /// With no other terms the square-root argument is just the constant.
    #[test]
    fn test_solve_simple_quadratic() {
        let solution = ConstraintSolver::new().solve("x*x = 1", "x").unwrap();

        assert_eq!(solution.variable, "x");
        assert_eq!(solution.solution_type, SolutionType::Quadratic);
        assert_eq!(solution.numerator, "T::from_i8(1)");
    }

    /// Asking for a variable that never occurs is an error.
    #[test]
    fn test_variable_not_found() {
        let outcome = ConstraintSolver::new().solve("2*s*e0123 = 0", "nonexistent");

        assert!(matches!(outcome, Err(SolveError::VariableNotFound(_))));
    }
}