//! Damped Newton-Raphson solver for systems of compiled expressions, with
//! adaptive damping, an optional backtracking line search, and a regularized
//! QR least-squares fallback for non-square or ill-conditioned systems.

use crate::compiler::{CompiledExp, CompiledSystem, VarId, VarTable};
use crate::exp::MissingVarError;
use crate::jacobian::Jacobian;
use crate::matrix::{LeastSquaresQrInfo, Matrix, MatrixError};
use std::collections::HashMap;

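/// Preallocated scratch space for regularized least-squares Newton steps.
///
/// Rather than forming the normal equations, the solver stacks a scaled
/// identity block under the Jacobian and solves the tall augmented system
/// with QR:
///
/// ```text
///     [      J        ]       [ -f ]
///     [ sqrt(reg) * I ] * d ~ [  0 ]
/// ```
///
/// which minimizes ||J d + f||^2 + reg * ||d||^2 (Tikhonov / damped least
/// squares). Keeping the augmented matrices here means each iteration only
/// rewrites their entries instead of reallocating them.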
struct LeastSquaresWorkspace {
    augmented_j: Matrix,
    augmented_b: Matrix,
    num_equations: usize,
    num_variables: usize,
}

impl LeastSquaresWorkspace {
    fn new(num_equations: usize, num_variables: usize, sqrt_reg: f64) -> Self {
        let mut augmented_j = Matrix::new(num_equations + num_variables, num_variables);
        for i in 0..num_variables {
            augmented_j[(num_equations + i, i)] = sqrt_reg;
        }

        LeastSquaresWorkspace {
            augmented_j,
            augmented_b: Matrix::new(num_equations + num_variables, 1),
            num_equations,
            num_variables,
        }
    }
}

enum LeastSquaresSolveError {
    InvalidInput(MatrixError),
    Singular { qr_err: String, normal_eq_err: String },
}

/// Damped Newton-Raphson solver over a [`CompiledSystem`].
pub struct NewtonRaphsonSolver {
    /// Compiled residual expressions, one per equation.
    equations: Vec<CompiledExp>,
    /// Variables the solver is allowed to adjust.
    variables: Vec<VarId>,
    /// Variable name/id table taken from the compiled system.
    var_table: VarTable,
    /// Iteration budget before giving up.
    max_iterations: usize,
    /// Convergence threshold on the residual norm.
    tolerance: f64,
    /// Initial damping (step-scaling) factor.
    damping_factor: f64,
    /// Smallest damping factor before the solver reports stagnation.
    min_damping: f64,
    /// Tikhonov regularization strength used for ill-conditioned steps.
    regularization: f64,
    /// Optional per-equation metadata, retrievable via `trace_for_equation`.
    equation_traces: Vec<Option<EquationTrace>>,
}

/// Result of a solver run.
#[derive(Debug, Clone)]
pub struct Solution {
    /// Final value for every variable, keyed by name.
    pub values: HashMap<String, f64>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Residual norm at termination.
    pub error: f64,
    /// Whether the convergence criteria were met.
    pub converged: bool,
    /// Residual norm recorded at each iteration.
    pub convergence_history: Vec<f64>,
}

/// Metadata that links an equation back to the constraint it came from.
#[derive(Debug, Clone)]
pub struct EquationTrace {
    pub constraint_id: usize,
    pub description: String,
}

/// Snapshot of the solver state attached to failures for diagnostics.
#[derive(Debug, Clone)]
pub struct SolverRunDiagnostic {
    pub message: String,
    pub iterations: usize,
    pub error: f64,
    pub values: HashMap<String, f64>,
    pub residuals: Vec<f64>,
}

/// Errors reported by [`NewtonRaphsonSolver`].
#[derive(Debug, Clone)]
pub enum SolverError {
    /// The Jacobian (or least-squares system) could not be solved, even with regularization.
    SingularMatrix(SolverRunDiagnostic),
    /// The iteration budget ran out or the solver stagnated.
    NoConvergence(SolverRunDiagnostic),
    /// Evaluation failed because a variable had no value.
    MissingVariable(MissingVarError),
    /// The caller supplied inconsistent input (unknown variables, bad trace length, ...).
    InvalidInput(String),
}

impl From<MissingVarError> for SolverError {
    fn from(value: MissingVarError) -> Self {
        SolverError::MissingVariable(value)
    }
}

impl From<MatrixError> for SolverError {
    fn from(value: MatrixError) -> Self {
        SolverError::InvalidInput(value.to_string())
    }
}

impl NewtonRaphsonSolver {
    /// Creates a solver that treats every variable in the compiled system as unknown.
    pub fn new(compiled: CompiledSystem) -> Self {
        let variables = compiled.var_table.all_var_ids();
        Self::build(compiled, variables)
    }

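    /// Builds a solver that only adjusts the named `variables`; every other
    /// variable in the compiled system is treated as a fixed parameter and
    /// must still be given a value in the initial guess.
    ///
    /// A minimal sketch of selecting a subset of variables (mirroring the
    /// tests at the bottom of this file):
    ///
    /// ```ignore
    /// let eq = Exp::sub(Exp::var("x"), Exp::var("a"));
    /// let equations = vec![eq];
    /// let compiled = Compiler::compile(&equations)?;
    /// // Solve for `x` only; `a` stays at whatever the initial guess provides.
    /// let solver = NewtonRaphsonSolver::new_with_variables(compiled, &["x"])?;
    /// ```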
    pub fn new_with_variables(
        compiled: CompiledSystem,
        variables: &[&str],
    ) -> Result<Self, SolverError> {
        let mut ids = Vec::with_capacity(variables.len());
        let mut seen = std::collections::HashSet::new();
        for name in variables {
            let id = compiled
                .var_table
                .get_id(name)
                .ok_or_else(|| SolverError::InvalidInput(format!("Unknown variable '{name}'")))?;
            if seen.insert(id) {
                ids.push(id);
            }
        }
        Ok(Self::build(compiled, ids))
    }

    fn build(compiled: CompiledSystem, variables: Vec<VarId>) -> Self {
        // Over-constrained systems (more equations than unknowns) get a larger
        // iteration budget, looser tolerance, stronger damping, and stronger
        // regularization by default.
        let is_over_constrained = compiled.equations.len() > variables.len();

        NewtonRaphsonSolver {
            equations: compiled.equations,
            variables,
            var_table: compiled.var_table,
            max_iterations: if is_over_constrained { 200 } else { 100 },
            tolerance: if is_over_constrained { 1e-8 } else { 1e-10 },
            damping_factor: if is_over_constrained { 0.7 } else { 1.0 },
            min_damping: 0.001,
            regularization: if is_over_constrained { 1e-4 } else { 1e-8 },
            equation_traces: Vec::new(),
        }
    }

    /// Attaches per-equation trace metadata; unless the system has no
    /// equations, the trace list must match the number of equations.
    pub fn try_with_equation_traces(
        mut self,
        traces: Vec<Option<EquationTrace>>,
    ) -> Result<Self, SolverError> {
        if !self.equations.is_empty() && traces.len() != self.equations.len() {
            return Err(SolverError::InvalidInput(format!(
                "Equation trace length ({}) does not match number of equations ({})",
                traces.len(),
                self.equations.len()
            )));
        }
        self.equation_traces = traces;
        Ok(self)
    }

    /// Panicking variant of [`Self::try_with_equation_traces`].
    pub fn with_equation_traces(self, traces: Vec<Option<EquationTrace>>) -> Self {
        self.try_with_equation_traces(traces)
            .unwrap_or_else(|err| match err {
                SolverError::InvalidInput(msg) => panic!("{}", msg),
                other => panic!("{other:?}"),
            })
    }

    /// Returns the trace registered for `equation_index`, if any.
    pub fn trace_for_equation(&self, equation_index: usize) -> Option<&EquationTrace> {
        self.equation_traces
            .get(equation_index)
            .and_then(|entry| entry.as_ref())
    }

    /// Overrides the iteration budget.
    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Overrides the convergence tolerance on the residual norm.
    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Overrides the damping factor, clamped to `[0.1, 1.0]`.
    pub fn with_damping(mut self, damping_factor: f64) -> Self {
        self.damping_factor = damping_factor.clamp(0.1, 1.0);
        self
    }

    /// Overrides the regularization strength; negative values are treated as zero.
    pub fn with_regularization(mut self, regularization: f64) -> Self {
        self.regularization = regularization.max(0.0);
        self
    }

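    /// Runs damped Newton-Raphson from `initial_guess` until the residual
    /// norm (or the step size) drops below the tolerance, or the iteration
    /// budget is exhausted.
    ///
    /// A minimal usage sketch (mirroring the unit tests below):
    ///
    /// ```ignore
    /// let x = Exp::var("x");
    /// let eq = Exp::sub(Exp::power(x.clone(), 2.0), Exp::val(2.0)); // x^2 = 2
    /// let equations = vec![eq];
    /// let solver = NewtonRaphsonSolver::new(Compiler::compile(&equations)?);
    ///
    /// let mut guess = HashMap::new();
    /// guess.insert("x".to_string(), 1.0);
    /// let solution = solver.solve(guess)?;
    /// assert!(solution.converged);
    /// ```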
    pub fn solve(&self, initial_guess: HashMap<String, f64>) -> Result<Solution, SolverError> {
        let vars = self.map_initial_guess(initial_guess)?;
        self.solve_modified_internal(vars, false)
    }

    fn solve_modified_internal(
        &self,
        initial_guess: HashMap<VarId, f64>,
        use_line_search: bool,
    ) -> Result<Solution, SolverError> {
        self.validate_initial_guess(&initial_guess)?;

        let mut vars = initial_guess.clone();
        let jacobian = Jacobian::new(self.equations.clone(), self.variables.clone());
        let mut jacobian_workspace = jacobian.workspace();
        let mut convergence_history = Vec::with_capacity(self.max_iterations);
        let mut damping = self.damping_factor;
        let mut last_error = f64::INFINITY;

        let num_equations = self.equations.len();
        let num_variables = self.variables.len();

        let mut f_vals = Matrix::new(num_equations, 1);
        let mut f_neg = Matrix::new(num_equations, 1);
        jacobian_workspace.replace_jacobian(jacobian.evaluate_checked(&vars)?);

        let mut line_search_f_vals = Matrix::new(num_equations, 1);

        // A reusable augmented-system workspace is only needed for non-square
        // systems solved with regularization.
        let mut least_squares_workspace =
            if self.regularization > 0.0 && num_equations != num_variables {
                Some(LeastSquaresWorkspace::new(
                    num_equations,
                    num_variables,
                    self.regularization.sqrt(),
                ))
            } else {
                None
            };

        for iter in 0..self.max_iterations {
            jacobian.evaluate_functions_checked_into(&vars, &mut f_vals)?;
            let error = f_vals.norm();
            convergence_history.push(error);

            if error < self.tolerance {
                return Ok(Solution {
                    values: self.values_by_name(&vars),
                    iterations: iter + 1,
                    error,
                    converged: true,
                    convergence_history,
                });
            }

            // Stagnation check: if the error barely moves, shrink the damping
            // factor and give up once it drops below the configured minimum.
            if iter > 0 && (error - last_error).abs() < self.tolerance * 0.1 {
                damping *= 0.5;
                if damping < self.min_damping {
                    let diag = self.build_diagnostic(
                        format!(
                            "Solver stagnated at iteration {} with error {:.2e}",
                            iter + 1,
                            error
                        ),
                        iter + 1,
                        &vars,
                        &f_vals,
                    );
                    return Err(SolverError::NoConvergence(diag));
                }
            }

            // Adapt the damping factor to the observed progress.
            if error > last_error * 1.5 {
                damping *= 0.5;
            } else if error < last_error * 0.5 {
                damping = (damping * 1.2).min(1.0);
            }

            last_error = error;

            for i in 0..num_equations {
                f_neg[(i, 0)] = -f_vals[(i, 0)];
            }

            let delta = {
                let j_matrix = jacobian_workspace.jacobian();
                if j_matrix.rows() == j_matrix.cols() {
                    // Square system: try a plain LU solve first, then retry
                    // once with an inflated diagonal if the factorization fails.
                    match j_matrix.solve_lu(&f_neg) {
                        Ok(d) => d,
                        Err(_) => {
                            let diag_size = j_matrix.rows().min(j_matrix.cols());
                            for i in 0..diag_size {
                                j_matrix[(i, i)] += self.regularization * 100.0;
                            }
                            match j_matrix.solve_lu(&f_neg) {
                                Ok(d) => d,
                                Err(e) => {
                                    let diag = self.build_diagnostic(
                                        format!(
                                            "Matrix is singular even with regularization: {e}"
                                        ),
                                        iter + 1,
                                        &vars,
                                        &f_vals,
                                    );
                                    return Err(SolverError::SingularMatrix(diag));
                                }
                            }
                        }
                    }
                } else {
                    // Non-square system: take a least-squares step.
                    match self.solve_least_squares_delta(
                        j_matrix,
                        &f_neg,
                        least_squares_workspace.as_mut(),
                    ) {
                        Ok(delta) => delta,
                        Err(LeastSquaresSolveError::InvalidInput(err)) => {
                            return Err(SolverError::InvalidInput(err.to_string()));
                        }
                        Err(LeastSquaresSolveError::Singular {
                            qr_err,
                            normal_eq_err,
                        }) => {
                            let diag = self.build_diagnostic(
                                format!(
                                    "Least squares system is singular (qr_err: {qr_err}; normal_eq_err: {normal_eq_err})"
                                ),
                                iter + 1,
                                &vars,
                                &f_vals,
                            );
                            return Err(SolverError::SingularMatrix(diag));
                        }
                    }
                }
            };

            let step_size = if use_line_search {
                self.line_search(&jacobian, &vars, &delta, error, &mut line_search_f_vals)?
            } else {
                damping
            };

            for (i, &var_id) in self.variables.iter().enumerate() {
                vars.insert(var_id, vars[&var_id] + step_size * delta[(i, 0)]);
            }

            // Also treat a vanishingly small step as convergence.
            if delta.norm() * step_size < self.tolerance {
                return Ok(Solution {
                    values: self.values_by_name(&vars),
                    iterations: iter + 1,
                    error,
                    converged: true,
                    convergence_history,
                });
            }

            jacobian.evaluate_checked_in_workspace(&vars, &mut jacobian_workspace)?;
        }

        let residuals = jacobian.evaluate_functions_checked(&vars)?;
        let diag = self.build_diagnostic(
            format!(
                "Failed to converge after {} iterations. Final error: {:.2e}",
                self.max_iterations,
                residuals.norm()
            ),
            self.max_iterations,
            &vars,
            &residuals,
        );
        Err(SolverError::NoConvergence(diag))
    }

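    /// Computes a least-squares Newton step for non-square systems.
    ///
    /// Strategy: first try plain QR on `J d = -f`. If that fails or the
    /// factorization looks ill-conditioned, and regularization is enabled,
    /// retry on the augmented system with escalating damping (`reg`,
    /// `10 * reg`, `100 * reg`) and return the first well-conditioned
    /// solution. If every attempt is ill-conditioned, the last computed
    /// solution is returned as a best effort; otherwise a `Singular` error
    /// carries both failure messages.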
    fn solve_least_squares_delta(
        &self,
        j_matrix: &Matrix,
        f_neg: &Matrix,
        workspace: Option<&mut LeastSquaresWorkspace>,
    ) -> Result<Matrix, LeastSquaresSolveError> {
        let mut workspace = workspace;
        // Condition-number estimate above which the factorization is treated
        // as effectively rank-deficient.
        const COND_LIMIT: f64 = 1e12;

        let full_rank = j_matrix.rows().min(j_matrix.cols());
        let is_ill_conditioned = |info: &LeastSquaresQrInfo| -> bool {
            info.rank < full_rank
                || !info.cond_est.is_finite()
                || info.cond_est > COND_LIMIT
        };

        let mut qr_err: Option<String> = None;
        let mut unreg_delta: Option<Matrix> = None;
        let mut unreg_info: Option<LeastSquaresQrInfo> = None;

        // First attempt: plain QR least squares without regularization.
        match j_matrix.solve_least_squares_qr_with_info(f_neg) {
            Ok((delta, info)) => {
                if !is_ill_conditioned(&info) {
                    return Ok(delta);
                }
                unreg_delta = Some(delta);
                unreg_info = Some(info);
            }
            Err(err) => {
                qr_err = Some(err);
            }
        }

        if self.regularization <= 0.0 {
            // Regularization disabled: return the (possibly ill-conditioned)
            // unregularized solution if we have one, otherwise report failure.
            if let Some(delta) = unreg_delta {
                return Ok(delta);
            }
            return Err(LeastSquaresSolveError::Singular {
                qr_err: qr_err.unwrap_or_else(|| "QR least squares failed".to_string()),
                normal_eq_err: "Regularization disabled".to_string(),
            });
        }

        // Retry with escalating Tikhonov damping.
        let reg_factors = [1.0, 10.0, 100.0];
        let mut last_err: Option<String> = None;
        let mut last_solution: Option<Matrix> = None;

        for factor in reg_factors {
            let lambda = self.regularization * factor;
            if lambda <= 0.0 {
                continue;
            }
            let sqrt_reg = lambda.sqrt();

            let reg_result = match workspace.as_deref_mut() {
                Some(workspace) => {
                    if workspace.num_equations != j_matrix.rows()
                        || workspace.num_variables != j_matrix.cols()
                    {
                        return Err(LeastSquaresSolveError::InvalidInput(
                            MatrixError::DimensionMismatch {
                                operation: "least_squares",
                                left: (workspace.augmented_j.rows(), workspace.augmented_j.cols()),
                                right: (j_matrix.rows(), j_matrix.cols()),
                            },
                        ));
                    }

                    // Refill the preallocated augmented system [J; sqrt(lambda) * I].
                    for i in 0..workspace.num_equations {
                        for j in 0..workspace.num_variables {
                            workspace.augmented_j[(i, j)] = j_matrix[(i, j)];
                        }
                        workspace.augmented_b[(i, 0)] = f_neg[(i, 0)];
                    }
                    for i in 0..workspace.num_variables {
                        for j in 0..workspace.num_variables {
                            workspace.augmented_j[(workspace.num_equations + i, j)] = 0.0;
                        }
                        workspace.augmented_j[(workspace.num_equations + i, i)] = sqrt_reg;
                        workspace.augmented_b[(workspace.num_equations + i, 0)] = 0.0;
                    }

                    workspace
                        .augmented_j
                        .solve_least_squares_qr_with_info(&workspace.augmented_b)
                }
                None => {
                    // No workspace available: build the augmented system from scratch.
                    let mut augmented_j =
                        Matrix::new(j_matrix.rows() + j_matrix.cols(), j_matrix.cols());
                    let mut augmented_b = Matrix::new(j_matrix.rows() + j_matrix.cols(), 1);
                    for i in 0..j_matrix.rows() {
                        for j in 0..j_matrix.cols() {
                            augmented_j[(i, j)] = j_matrix[(i, j)];
                        }
                        augmented_b[(i, 0)] = f_neg[(i, 0)];
                    }
                    for i in 0..j_matrix.cols() {
                        augmented_j[(j_matrix.rows() + i, i)] = sqrt_reg;
                        augmented_b[(j_matrix.rows() + i, 0)] = 0.0;
                    }
                    augmented_j.solve_least_squares_qr_with_info(&augmented_b)
                }
            };

            match reg_result {
                Ok((delta, info)) => {
                    if !is_ill_conditioned(&info) {
                        return Ok(delta);
                    }
                    last_solution = Some(delta);
                }
                Err(err) => {
                    last_err = Some(err);
                }
            }
        }

        // Every attempt was ill-conditioned or failed: fall back to the last
        // computed solution, or report both failure messages.
        if let Some(delta) = last_solution {
            return Ok(delta);
        }

        Err(LeastSquaresSolveError::Singular {
            qr_err: qr_err
                .or_else(|| {
                    unreg_info.map(|info| {
                        format!(
                            "QR ill-conditioned (rank {}, cond {:.2e})",
                            info.rank, info.cond_est
                        )
                    })
                })
                .unwrap_or_else(|| "QR least squares failed".to_string()),
            normal_eq_err: last_err.unwrap_or_else(|| "Regularized QR failed".to_string()),
        })
    }

    fn map_initial_guess(
        &self,
        initial_guess: HashMap<String, f64>,
    ) -> Result<HashMap<VarId, f64>, SolverError> {
        let mut unknown = Vec::new();
        let mut vars = HashMap::new();

        for (name, value) in initial_guess {
            if let Some(id) = self.var_table.get_id(&name) {
                vars.insert(id, value);
            } else {
                unknown.push(name);
            }
        }

        if !unknown.is_empty() {
            unknown.sort();
            return Err(SolverError::InvalidInput(format!(
                "Initial guess includes unknown variables: {unknown:?}"
            )));
        }

        self.validate_initial_guess(&vars)?;
        Ok(vars)
    }

    fn validate_initial_guess(&self, vars: &HashMap<VarId, f64>) -> Result<(), SolverError> {
        let mut missing = Vec::new();
        for (idx, name) in self.var_table.names().iter().enumerate() {
            if !vars.contains_key(&VarId::new(idx)) {
                missing.push(name.clone());
            }
        }
        if missing.is_empty() {
            return Ok(());
        }

        missing.sort();
        Err(SolverError::InvalidInput(format!(
            "Initial guess missing values for variables: {missing:?}"
        )))
    }

    fn values_by_name(&self, vars: &HashMap<VarId, f64>) -> HashMap<String, f64> {
        let mut values = HashMap::with_capacity(self.var_table.len());
        for (idx, name) in self.var_table.names().iter().enumerate() {
            if let Some(value) = vars.get(&VarId::new(idx)) {
                values.insert(name.clone(), *value);
            }
        }
        values
    }

    fn build_diagnostic(
        &self,
        message: String,
        iterations: usize,
        vars: &HashMap<VarId, f64>,
        residuals_matrix: &Matrix,
    ) -> SolverRunDiagnostic {
        let mut residuals = Vec::with_capacity(residuals_matrix.rows());
        for i in 0..residuals_matrix.rows() {
            residuals.push(residuals_matrix[(i, 0)]);
        }

        SolverRunDiagnostic {
            message,
            iterations,
            error: residuals_matrix.norm(),
            values: self.values_by_name(vars),
            residuals,
        }
    }

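    /// Backtracking line search along the Newton direction `delta`.
    ///
    /// Starts at a full step (`alpha = 1`) and halves it (up to 20 times)
    /// until the residual norm satisfies the sufficient-decrease test
    /// `new_error < current_error * (1 - 0.5 * alpha)`; if no step passes,
    /// the last `alpha` (bounded below by `min_damping`) is returned.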
    fn line_search(
        &self,
        jacobian: &Jacobian,
        vars: &HashMap<VarId, f64>,
        delta: &Matrix,
        current_error: f64,
        f_vals: &mut Matrix,
    ) -> Result<f64, SolverError> {
        let mut alpha = 1.0;
        let mut new_vars = vars.clone();

        for _ in 0..20 {
            new_vars.clone_from(vars);
            for (i, &var_id) in self.variables.iter().enumerate() {
                new_vars.insert(var_id, vars[&var_id] + alpha * delta[(i, 0)]);
            }

            jacobian.evaluate_functions_checked_into(&new_vars, f_vals)?;
            let new_error = f_vals.norm();
            if new_error < current_error * (1.0 - 0.5 * alpha) {
                return Ok(alpha);
            }
            alpha *= 0.5;

            if alpha < 1e-10 {
                break;
            }
        }

        Ok(alpha.max(self.min_damping))
    }

    /// Like [`Self::solve`], but scales each Newton step with a backtracking
    /// line search instead of the adaptive damping factor.
    pub fn solve_with_line_search(
        &self,
        initial_guess: HashMap<String, f64>,
    ) -> Result<Solution, SolverError> {
        let vars = self.map_initial_guess(initial_guess)?;
        self.solve_modified_internal(vars, true)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::Compiler;
    use crate::exp::Exp;

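    // Tiny deterministic linear congruential generator (the 6364136223846793005
    // multiplier from Knuth's MMIX LCG) so the tests stay reproducible without
    // pulling in an external RNG crate. `next_f64` maps into roughly [-1, 1].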
    struct TestRng {
        state: u64,
    }

    impl TestRng {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        fn next_u32(&mut self) -> u32 {
            self.state = self
                .state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1);
            (self.state >> 32) as u32
        }

        fn next_f64(&mut self) -> f64 {
            let v = self.next_u32() as f64 / u32::MAX as f64;
            2.0 * v - 1.0
        }
    }

    fn linear_equation(coeffs: &[f64], vars: &[Exp], rhs: f64) -> Exp {
        let mut sum = Exp::val(0.0);
        for (coeff, var) in coeffs.iter().zip(vars.iter()) {
            let term = Exp::mul(Exp::val(*coeff), var.clone());
            sum = Exp::add(sum, term);
        }
        Exp::sub(sum, Exp::val(rhs))
    }

    fn solver_for(equations: Vec<Exp>) -> NewtonRaphsonSolver {
        let compiled = Compiler::compile(&equations).expect("compile failed");
        NewtonRaphsonSolver::new(compiled)
    }

    fn solver_for_with_vars(equations: Vec<Exp>, variables: &[&str]) -> NewtonRaphsonSolver {
        let compiled = Compiler::compile(&equations).expect("compile failed");
        NewtonRaphsonSolver::new_with_variables(compiled, variables)
            .expect("failed to select solve variables")
    }

    #[test]
    fn test_simple_system() {
        let x = Exp::var("x");
        let y = Exp::var("y");

        let eq1 = Exp::sub(
            Exp::add(Exp::power(x.clone(), 2.0), Exp::power(y.clone(), 2.0)),
            Exp::val(1.0),
        );
        let eq2 = Exp::sub(Exp::mul(x.clone(), y.clone()), Exp::val(0.25));

        let solver = solver_for(vec![eq1, eq2]);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.5);
        initial.insert("y".to_string(), 0.866);

        let solution = match solver.solve(initial.clone()) {
            Ok(sol) => sol,
            Err(_) => match solver.solve_with_line_search(initial) {
                Ok(sol) => sol,
                Err(e) => panic!("Failed to solve: {:?}", e),
            },
        };
        assert!(solution.converged);
        assert!(solution.error < 1e-10);

        let x_sol = solution.values.get("x").copied().unwrap();
        let y_sol = solution.values.get("y").copied().unwrap();
        assert!((x_sol * x_sol + y_sol * y_sol - 1.0).abs() < 1e-10);
        assert!((x_sol * y_sol - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_transcendental_system() {
        let x = Exp::var("x");
        let y = Exp::var("y");

        let eq1 = Exp::sub(Exp::sin(x.clone()), Exp::mul(y.clone(), Exp::val(2.0)));
        let eq2 = Exp::sub(Exp::cos(y.clone()), x.clone());

        let solver = solver_for(vec![eq1, eq2])
            .with_tolerance(1e-8)
            .with_max_iterations(50)
            .with_regularization(1e-8);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.5);
        initial.insert("y".to_string(), 0.25);

        let solution = solver
            .solve_with_line_search(initial)
            .expect("Failed to solve transcendental system");

        assert!(solution.converged);
        let x_sol = solution.values.get("x").copied().unwrap();
        let y_sol = solution.values.get("y").copied().unwrap();
        assert!((x_sol.sin() - 2.0 * y_sol).abs() < 1e-8);
        assert!((y_sol.cos() - x_sol).abs() < 1e-8);
    }

    #[test]
    fn test_solver_errors_on_missing_variable() {
        let x = Exp::var("x");
        let a = Exp::var("a");
        let eq = Exp::sub(x.clone(), a.clone());

        let solver = solver_for_with_vars(vec![eq], &["x"]);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 1.0);

        let err = solver.solve(initial).expect_err("expected missing variable error");
        match err {
            SolverError::InvalidInput(msg) => {
                assert!(msg.contains("a"), "unexpected message: {msg}");
            }
            other => panic!("expected InvalidInput, got {:?}", other),
        }
    }

    #[test]
    fn test_solver_invalid_input_on_missing_initial_guess_variable() {
        let x = Exp::var("x");
        let y = Exp::var("y");
        let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));

        let solver = solver_for(vec![eq]);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.0);

        let err = solver
            .solve(initial)
            .expect_err("expected invalid input due to missing y");
        match err {
            SolverError::InvalidInput(msg) => {
                assert!(msg.contains("y"), "unexpected message: {msg}");
            }
            other => panic!("expected InvalidInput, got {:?}", other),
        }
    }

    #[test]
    fn test_try_with_equation_traces_invalid_length() {
        let x = Exp::var("x");
        let eq = Exp::sub(x.clone(), Exp::val(1.0));

        let solver = solver_for(vec![eq]);

        let traces = vec![
            Some(EquationTrace {
                constraint_id: 0,
                description: "eq0".to_string(),
            }),
            None,
        ];
        let err = match solver.try_with_equation_traces(traces) {
            Ok(_) => panic!("expected invalid input for trace length mismatch"),
            Err(err) => err,
        };
        match err {
            SolverError::InvalidInput(msg) => {
                assert!(msg.contains("Equation trace length (2)"), "{}", msg);
                assert!(msg.contains("number of equations (1)"), "{}", msg);
            }
            other => panic!("expected InvalidInput, got {:?}", other),
        }
    }

    #[test]
    fn test_line_search_preserves_fixed_variables() {
        let x = Exp::var("x");
        let a = Exp::var("a");
        let eq = Exp::sub(x.clone(), a.clone());

        let solver = solver_for_with_vars(vec![eq], &["x"]);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.0);
        initial.insert("a".to_string(), 2.0);

        let solution = solver
            .solve_with_line_search(initial)
            .expect("solver should converge when fixed variables are provided");
        assert!(solution.converged);
        assert!((solution.values.get("x").copied().unwrap() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_least_squares_qr_handles_ill_conditioned_overdetermined_system_without_regularization()
    {
        // eps is small enough that the three equations are nearly linearly
        // dependent (ill-conditioned), yet the system stays consistent with
        // the exact solution x = y = 1.
        let eps = 2f64.powi(-27);
        let x = Exp::var("x");
        let y = Exp::var("y");

        let eq1 = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(2.0));
        let eq2 = Exp::sub(
            Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 + eps))),
            Exp::val(2.0 + eps),
        );
        let eq3 = Exp::sub(
            Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 - eps))),
            Exp::val(2.0 - eps),
        );

        let solver = solver_for(vec![eq1, eq2, eq3])
            .with_regularization(0.0)
            .with_damping(1.0)
            .with_max_iterations(10)
            .with_tolerance(1e-10);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.0);
        initial.insert("y".to_string(), 0.0);

        let solution = solver
            .solve(initial)
            .expect("expected solver to converge with QR least squares");

        assert!(solution.converged, "{:?}", solution);
        assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-8);
        assert!((solution.values.get("y").copied().unwrap() - 1.0).abs() < 1e-8);
    }

    #[test]
    fn test_underconstrained_system_returns_min_norm_solution() {
        let x = Exp::var("x");
        let y = Exp::var("y");

        let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));
        let solver = solver_for(vec![eq])
            .with_tolerance(1e-12)
            .with_max_iterations(10);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.0);
        initial.insert("y".to_string(), 0.0);

        let solution = solver.solve(initial).expect("expected solver to converge");
        assert!(solution.converged, "{:?}", solution);
        assert!((solution.values.get("x").copied().unwrap() - 0.5).abs() < 1e-10);
        assert!((solution.values.get("y").copied().unwrap() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_rank_deficient_overconstrained_recovers_with_regularization() {
        let x = Exp::var("x");
        let y = Exp::var("y");
        let y_zero = Exp::mul(y.clone(), Exp::val(0.0));

        let eq1 = Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0));
        let eq2 = Exp::sub(
            Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()),
            Exp::val(2.0),
        );
        let eq3 = Exp::sub(
            Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()),
            Exp::val(3.0),
        );

        let solver = solver_for(vec![eq1, eq2, eq3])
            .with_tolerance(1e-12)
            .with_max_iterations(10);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.0);
        initial.insert("y".to_string(), 0.0);

        let solution = solver.solve(initial).expect("expected solver to converge");
        assert!(solution.converged, "{:?}", solution);
        assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-10);
        assert!((solution.values.get("y").copied().unwrap() - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_rank_deficient_overconstrained_fails_without_regularization() {
        let x = Exp::var("x");
        let y = Exp::var("y");
        let y_zero = Exp::mul(y.clone(), Exp::val(0.0));

        let eq1 = Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0));
        let eq2 = Exp::sub(
            Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()),
            Exp::val(2.0),
        );
        let eq3 = Exp::sub(
            Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()),
            Exp::val(3.0),
        );

        let solver = solver_for(vec![eq1, eq2, eq3])
            .with_regularization(0.0)
            .with_tolerance(1e-12)
            .with_max_iterations(10);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.0);
        initial.insert("y".to_string(), 0.0);

        let err = solver
            .solve(initial)
            .expect_err("expected singular matrix error");
        assert!(matches!(err, SolverError::SingularMatrix(_)), "{:?}", err);
    }

    #[test]
    fn test_overconstrained_consistent_system_solves() {
        let x = Exp::var("x");
        let eq1 = Exp::sub(x.clone(), Exp::val(1.0));
        let eq2 = Exp::sub(Exp::mul(x.clone(), Exp::val(2.0)), Exp::val(2.0));

        let solver = solver_for(vec![eq1, eq2])
            .with_tolerance(1e-12)
            .with_max_iterations(10);

        let mut initial = HashMap::new();
        initial.insert("x".to_string(), 0.0);

        let solution = solver.solve(initial).expect("expected solver to converge");
        assert!(solution.converged, "{:?}", solution);
        assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_random_square_linear_systems() {
        let mut rng = TestRng::new(0x51ab_1e55_cafe_f00d);
        for _ in 0..5 {
            let n = 3;
            let mut a = vec![vec![0.0; n]; n];
            for i in 0..n {
                for j in 0..n {
                    a[i][j] = rng.next_f64();
                }
                a[i][i] += 2.0;
            }

            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
            let mut b = vec![0.0; n];
            for i in 0..n {
                for j in 0..n {
                    b[i] += a[i][j] * x_true[j];
                }
            }

            let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
            let equations: Vec<Exp> = (0..n)
                .map(|i| linear_equation(&a[i], &vars, b[i]))
                .collect();

            let solver = solver_for(equations)
                .with_tolerance(1e-12)
                .with_max_iterations(20);

            let mut initial = HashMap::new();
            for i in 0..n {
                initial.insert(format!("x{i}"), 0.0);
            }

            let solution = solver.solve(initial).expect("expected to solve");
            for i in 0..n {
                let value = solution.values.get(&format!("x{i}")).copied().unwrap();
                assert!((value - x_true[i]).abs() < 1e-8);
            }
        }
    }

    #[test]
    fn test_random_overconstrained_linear_systems() {
        let mut rng = TestRng::new(0xa11c_e551_dead_beef);
        let m = 5;
        let n = 3;

        for _ in 0..5 {
            let mut a = vec![vec![0.0; n]; m];
            for i in 0..m {
                for j in 0..n {
                    a[i][j] = rng.next_f64();
                }
            }
            for i in 0..n {
                a[i][i] += 2.0;
            }

            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
            let mut b = vec![0.0; m];
            for i in 0..m {
                for j in 0..n {
                    b[i] += a[i][j] * x_true[j];
                }
            }

            let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
            let equations: Vec<Exp> = (0..m)
                .map(|i| linear_equation(&a[i], &vars, b[i]))
                .collect();

            let solver = solver_for(equations)
                .with_tolerance(1e-12)
                .with_max_iterations(20);

            let mut initial = HashMap::new();
            for i in 0..n {
                initial.insert(format!("x{i}"), 0.0);
            }

            let solution = solver.solve(initial).expect("expected to solve");
            for i in 0..n {
                let value = solution.values.get(&format!("x{i}")).copied().unwrap();
                assert!((value - x_true[i]).abs() < 1e-8);
            }
        }
    }

    #[test]
    fn test_random_underdetermined_min_norm_structure() {
        let mut rng = TestRng::new(0x0ddc_affe_fade_bead);
        let _m = 2;
        let n = 4;
        let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();

        for _ in 0..5 {
            let b0 = rng.next_f64();
            let b1 = rng.next_f64();
            let x2_zero = Exp::mul(vars[2].clone(), Exp::val(0.0));
            let x3_zero = Exp::mul(vars[3].clone(), Exp::val(0.0));
            let eq1 = Exp::sub(Exp::add(vars[0].clone(), x2_zero.clone()), Exp::val(b0));
            let eq2 = Exp::sub(Exp::add(vars[1].clone(), x3_zero.clone()), Exp::val(b1));

            let solver = solver_for(vec![eq1, eq2])
                .with_tolerance(1e-12)
                .with_max_iterations(20);

            let mut initial = HashMap::new();
            for i in 0..n {
                initial.insert(format!("x{i}"), 0.0);
            }

            let solution = solver.solve(initial).expect("expected to solve");
            assert!((solution.values.get("x0").copied().unwrap() - b0).abs() < 1e-10);
            assert!((solution.values.get("x1").copied().unwrap() - b1).abs() < 1e-10);
            assert!(solution.values.get("x2").copied().unwrap().abs() < 1e-10);
            assert!(solution.values.get("x3").copied().unwrap().abs() < 1e-10);
        }
    }
}