1use crate::compiler::{CompiledExp, CompiledSystem, VarId, VarTable};
26use crate::exp::MissingVarError;
27use crate::jacobian::Jacobian;
28use crate::matrix::{LeastSquaresQrInfo, Matrix, MatrixError};
29use crate::mode::{build_thread_pool, Mode};
30use rayon::ThreadPool;
31use std::collections::HashMap;
32
33struct LeastSquaresWorkspace {
34 augmented_j: Matrix,
35 augmented_b: Matrix,
36 num_equations: usize,
37 num_variables: usize,
38}
39
40impl LeastSquaresWorkspace {
41 fn new(num_equations: usize, num_variables: usize, sqrt_reg: f64) -> Self {
42 let mut augmented_j = Matrix::new(num_equations + num_variables, num_variables);
43 for i in 0..num_variables {
44 augmented_j[(num_equations + i, i)] = sqrt_reg;
45 }
46
47 LeastSquaresWorkspace {
48 augmented_j,
49 augmented_b: Matrix::new(num_equations + num_variables, 1),
50 num_equations,
51 num_variables,
52 }
53 }
54}
55
56enum LeastSquaresSolveError {
57 InvalidInput(MatrixError),
58 Singular { qr_err: String, normal_eq_err: String },
59}
60
61pub struct NewtonRaphsonSolver {
70 equations: Vec<CompiledExp>,
72 variables: Vec<VarId>,
74 var_table: VarTable,
76 max_iterations: usize,
78 tolerance: f64,
80 damping_factor: f64,
83 min_damping: f64,
85 regularization: f64,
87 equation_traces: Vec<Option<EquationTrace>>,
89 mode: Mode,
91 pool: Option<ThreadPool>,
93}
94
/// Result of a successful solve.
#[derive(Debug, Clone)]
pub struct Solution {
    /// Final value of every variable in the system, keyed by name (including variables
    /// that were held fixed).
    pub values: HashMap<String, f64>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Residual norm reported when the solve stopped.
    pub error: f64,
    /// Whether a convergence criterion was met.
    pub converged: bool,
    /// Residual norm recorded at the start of each iteration, in order.
    pub convergence_history: Vec<f64>,
}
109
/// Caller-supplied provenance for a single equation, used to explain failures.
#[derive(Debug, Clone)]
pub struct EquationTrace {
    /// Identifier of the constraint this equation came from (assigned by the caller).
    pub constraint_id: usize,
    /// Human-readable description of the equation.
    pub description: String,
}
116
/// Snapshot of solver state attached to a failed run.
#[derive(Debug, Clone)]
pub struct SolverRunDiagnostic {
    /// Explanation of why the run stopped.
    pub message: String,
    /// Iteration count at which the run stopped.
    pub iterations: usize,
    /// Norm of `residuals`.
    pub error: f64,
    /// Variable values at the point of failure, keyed by name.
    pub values: HashMap<String, f64>,
    /// Per-equation residual values, in equation order.
    pub residuals: Vec<f64>,
}
127
128#[derive(Debug, Clone)]
130pub enum SolverError {
131 SingularMatrix(SolverRunDiagnostic),
133 NoConvergence(SolverRunDiagnostic),
135 MissingVariable(MissingVarError),
137 InvalidInput(String),
139}
140
141impl From<MissingVarError> for SolverError {
142 fn from(value: MissingVarError) -> Self {
143 SolverError::MissingVariable(value)
144 }
145}
146
147impl From<MatrixError> for SolverError {
148 fn from(value: MatrixError) -> Self {
149 SolverError::InvalidInput(value.to_string())
150 }
151}
152
153impl NewtonRaphsonSolver {
154 pub fn new(compiled: CompiledSystem) -> Self {
163 let variables = compiled.var_table.all_var_ids();
164 Self::build(compiled, variables)
165 }
166
167 pub fn new_with_mode(compiled: CompiledSystem, mode: Mode) -> Result<Self, SolverError> {
169 let variables = compiled.var_table.all_var_ids();
170 Self::build_with_mode(compiled, variables, mode)
171 }
172
173 pub fn new_with_variables(
178 compiled: CompiledSystem,
179 variables: &[&str],
180 ) -> Result<Self, SolverError> {
181 let mut ids = Vec::with_capacity(variables.len());
182 let mut seen = std::collections::HashSet::new();
183 for name in variables {
184 let id = compiled
185 .var_table
186 .get_id(name)
187 .ok_or_else(|| SolverError::InvalidInput(format!("Unknown variable '{name}'")))?;
188 if seen.insert(id) {
189 ids.push(id);
190 }
191 }
192 Ok(Self::build(compiled, ids))
193 }
194
195 pub fn new_with_variables_and_mode(
197 compiled: CompiledSystem,
198 variables: &[&str],
199 mode: Mode,
200 ) -> Result<Self, SolverError> {
201 let mut ids = Vec::with_capacity(variables.len());
202 let mut seen = std::collections::HashSet::new();
203 for name in variables {
204 let id = compiled
205 .var_table
206 .get_id(name)
207 .ok_or_else(|| SolverError::InvalidInput(format!("Unknown variable '{name}'")))?;
208 if seen.insert(id) {
209 ids.push(id);
210 }
211 }
212
213 Self::build_with_mode(compiled, ids, mode)
214 }
215
216 fn build(compiled: CompiledSystem, variables: Vec<VarId>) -> Self {
217 let is_over_constrained = compiled.equations.len() > variables.len();
220
221 NewtonRaphsonSolver {
222 equations: compiled.equations,
223 variables,
224 var_table: compiled.var_table,
225 max_iterations: if is_over_constrained { 200 } else { 100 },
227 tolerance: if is_over_constrained { 1e-8 } else { 1e-10 },
229 damping_factor: if is_over_constrained { 0.7 } else { 1.0 },
231 min_damping: 0.001,
232 regularization: if is_over_constrained { 1e-4 } else { 1e-8 },
234 equation_traces: Vec::new(),
235 mode: Mode::Serial,
236 pool: None,
237 }
238 }
239
240 fn build_with_mode(
241 compiled: CompiledSystem,
242 variables: Vec<VarId>,
243 mode: Mode,
244 ) -> Result<Self, SolverError> {
245 let mut solver = Self::build(compiled, variables);
246 solver.set_mode(mode)?;
247 Ok(solver)
248 }
249
250 pub fn with_mode(mut self, mode: Mode) -> Result<Self, SolverError> {
252 self.set_mode(mode)?;
253 Ok(self)
254 }
255
256 pub fn mode(&self) -> &Mode {
258 &self.mode
259 }
260
261 fn set_mode(&mut self, mode: Mode) -> Result<(), SolverError> {
262 let pool = build_thread_pool(&mode).map_err(SolverError::InvalidInput)?;
263 self.mode = mode;
264 self.pool = pool;
265 Ok(())
266 }
267
268 fn parallel_enabled(&self) -> bool {
269 self.pool.is_some()
270 }
271
272 fn run_in_mode<F, R>(&self, f: F) -> R
273 where
274 F: FnOnce() -> R + Send,
275 R: Send,
276 {
277 if let Some(pool) = &self.pool {
278 pool.install(f)
279 } else {
280 f()
281 }
282 }
283
284 pub fn try_with_equation_traces(
286 mut self,
287 traces: Vec<Option<EquationTrace>>,
288 ) -> Result<Self, SolverError> {
289 if !self.equations.is_empty() && traces.len() != self.equations.len() {
290 return Err(SolverError::InvalidInput(format!(
291 "Equation trace length ({}) does not match number of equations ({})",
292 traces.len(),
293 self.equations.len()
294 )));
295 }
296 self.equation_traces = traces;
297 Ok(self)
298 }
299
300 pub fn with_equation_traces(self, traces: Vec<Option<EquationTrace>>) -> Self {
302 self.try_with_equation_traces(traces)
303 .unwrap_or_else(|err| match err {
304 SolverError::InvalidInput(msg) => panic!("{}", msg),
305 other => panic!("{other:?}"),
306 })
307 }
308
309 pub fn trace_for_equation(&self, equation_index: usize) -> Option<&EquationTrace> {
311 self.equation_traces
312 .get(equation_index)
313 .and_then(|entry| entry.as_ref())
314 }
315
316 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
318 self.max_iterations = max_iterations;
319 self
320 }
321
322 pub fn with_tolerance(mut self, tolerance: f64) -> Self {
324 self.tolerance = tolerance;
325 self
326 }
327
328 pub fn with_damping(mut self, damping_factor: f64) -> Self {
330 self.damping_factor = damping_factor.clamp(0.1, 1.0);
331 self
332 }
333
334 pub fn with_regularization(mut self, regularization: f64) -> Self {
336 self.regularization = regularization.max(0.0);
337 self
338 }
339
340 pub fn solve(&self, initial_guess: HashMap<String, f64>) -> Result<Solution, SolverError> {
342 let vars = self.map_initial_guess(initial_guess)?;
343 self.run_in_mode(|| self.solve_modified_internal(vars, false))
344 }
345
346 fn solve_modified_internal(
361 &self,
362 initial_guess: HashMap<VarId, f64>,
363 use_line_search: bool,
364 ) -> Result<Solution, SolverError> {
365 self.validate_initial_guess(&initial_guess)?;
366
367 let mut vars = initial_guess.clone();
368 let jacobian = Jacobian::new(self.equations.clone(), self.variables.clone());
369 let mut jacobian_workspace = jacobian.workspace();
371 let mut convergence_history = Vec::with_capacity(self.max_iterations);
372 let mut damping = self.damping_factor;
374 let mut last_error = f64::INFINITY;
375
376 let num_equations = self.equations.len();
377 let num_variables = self.variables.len();
378
379 let mut f_vals = Matrix::new(num_equations, 1);
380 let mut f_neg = Matrix::new(num_equations, 1);
381 jacobian_workspace.replace_jacobian(jacobian.evaluate_checked(&vars)?);
382
383 let mut line_search_f_vals = Matrix::new(num_equations, 1);
384
385 let mut least_squares_workspace =
386 if self.regularization > 0.0 && num_equations != num_variables {
387 Some(LeastSquaresWorkspace::new(
388 num_equations,
389 num_variables,
390 self.regularization.sqrt(),
391 ))
392 } else {
393 None
394 };
395 let parallel = self.parallel_enabled();
396
397 for iter in 0..self.max_iterations {
398 jacobian.evaluate_functions_checked_into(&vars, &mut f_vals)?;
400 let error = f_vals.norm();
401 convergence_history.push(error);
402
403 if error < self.tolerance {
405 return Ok(Solution {
406 values: self.values_by_name(&vars),
407 iterations: iter + 1,
408 error,
409 converged: true,
410 convergence_history,
411 });
412 }
413
414 if iter > 0 && (error - last_error).abs() < self.tolerance * 0.1 {
416 damping *= 0.5; if damping < self.min_damping {
418 let diag = self.build_diagnostic(
419 format!(
420 "Solver stagnated at iteration {} with error {:.2e}",
421 iter + 1,
422 error
423 ),
424 iter + 1,
425 &vars,
426 &f_vals,
427 );
428 return Err(SolverError::NoConvergence(diag));
429 }
430 }
431
432 if error > last_error * 1.5 {
434 damping *= 0.5;
436 } else if error < last_error * 0.5 {
437 damping = (damping * 1.2).min(1.0);
439 }
440
441 last_error = error;
442
443 for i in 0..num_equations {
444 f_neg[(i, 0)] = -f_vals[(i, 0)];
445 }
446
447 let delta = {
449 let j_matrix = jacobian_workspace.jacobian();
450 if j_matrix.rows() == j_matrix.cols() {
452 match j_matrix.solve_lu(&f_neg) {
455 Ok(d) => d,
456 Err(_) => {
457 let diag_size = j_matrix.rows().min(j_matrix.cols());
459 for i in 0..diag_size {
460 j_matrix[(i, i)] += self.regularization * 100.0;
461 }
462 match j_matrix.solve_lu(&f_neg) {
463 Ok(d) => d,
464 Err(e) => {
465 let diag = self.build_diagnostic(
466 format!(
467 "Matrix is singular even with regularization: {e}"
468 ),
469 iter + 1,
470 &vars,
471 &f_vals,
472 );
473 return Err(SolverError::SingularMatrix(diag));
474 }
475 }
476 }
477 }
478 } else {
479 match self.solve_least_squares_delta(
480 j_matrix,
481 &f_neg,
482 least_squares_workspace.as_mut(),
483 parallel,
484 ) {
485 Ok(delta) => delta,
486 Err(LeastSquaresSolveError::InvalidInput(err)) => {
487 return Err(SolverError::InvalidInput(err.to_string()));
488 }
489 Err(LeastSquaresSolveError::Singular {
490 qr_err,
491 normal_eq_err,
492 }) => {
493 let diag = self.build_diagnostic(
494 format!(
495 "Least squares system is singular (qr_err: {qr_err}; normal_eq_err: {normal_eq_err})"
496 ),
497 iter + 1,
498 &vars,
499 &f_vals,
500 );
501 return Err(SolverError::SingularMatrix(diag));
502 }
503 }
504 }
505 };
506
507 let step_size = if use_line_search {
509 self.line_search(&jacobian, &vars, &delta, error, &mut line_search_f_vals)?
511 } else {
512 damping
514 };
515
516 for (i, &var_id) in self.variables.iter().enumerate() {
518 vars.insert(var_id, vars[&var_id] + step_size * delta[(i, 0)]);
519 }
520
521 if delta.norm() * step_size < self.tolerance {
523 return Ok(Solution {
524 values: self.values_by_name(&vars),
525 iterations: iter + 1,
526 error,
527 converged: true,
528 convergence_history,
529 });
530 }
531
532 jacobian.evaluate_checked_in_workspace(&vars, &mut jacobian_workspace)?;
534 }
535
536 let residuals = jacobian.evaluate_functions_checked(&vars)?;
538 let diag = self.build_diagnostic(
539 format!(
540 "Failed to converge after {} iterations. Final error: {:.2e}",
541 self.max_iterations,
542 residuals.norm()
543 ),
544 self.max_iterations,
545 &vars,
546 &residuals,
547 );
548 Err(SolverError::NoConvergence(diag))
549 }
550
551 fn solve_least_squares_delta(
552 &self,
553 j_matrix: &Matrix,
554 f_neg: &Matrix,
555 workspace: Option<&mut LeastSquaresWorkspace>,
556 parallel: bool,
557 ) -> Result<Matrix, LeastSquaresSolveError> {
558 let mut workspace = workspace;
559 const COND_LIMIT: f64 = 1e12;
564
565 let full_rank = j_matrix.rows().min(j_matrix.cols());
566 let is_ill_conditioned = |info: &LeastSquaresQrInfo| -> bool {
567 info.rank < full_rank
568 || !info.cond_est.is_finite()
569 || info.cond_est > COND_LIMIT
570 };
571
572 let mut qr_err: Option<String> = None;
573 let mut unreg_delta: Option<Matrix> = None;
574 let mut unreg_info: Option<LeastSquaresQrInfo> = None;
575
576 match j_matrix.solve_least_squares_qr_with_info_with_parallel(f_neg, parallel) {
577 Ok((delta, info)) => {
578 if !is_ill_conditioned(&info) {
579 return Ok(delta);
580 }
581 unreg_delta = Some(delta);
582 unreg_info = Some(info);
583 }
584 Err(err) => {
585 qr_err = Some(err);
586 }
587 }
588
589 if self.regularization <= 0.0 {
590 if let Some(delta) = unreg_delta {
591 return Ok(delta);
592 }
593 return Err(LeastSquaresSolveError::Singular {
594 qr_err: qr_err.unwrap_or_else(|| "QR least squares failed".to_string()),
595 normal_eq_err: "Regularization disabled".to_string(),
596 });
597 }
598
599 let reg_factors = [1.0, 10.0, 100.0];
600 let mut last_err: Option<String> = None;
601 let mut last_solution: Option<Matrix> = None;
602
603 for factor in reg_factors {
604 let lambda = self.regularization * factor;
605 if lambda <= 0.0 {
606 continue;
607 }
608 let sqrt_reg = lambda.sqrt();
609
610 let reg_result = match workspace.as_deref_mut() {
611 Some(workspace) => {
612 if workspace.num_equations != j_matrix.rows()
613 || workspace.num_variables != j_matrix.cols()
614 {
615 return Err(LeastSquaresSolveError::InvalidInput(
616 MatrixError::DimensionMismatch {
617 operation: "least_squares",
618 left: (workspace.augmented_j.rows(), workspace.augmented_j.cols()),
619 right: (j_matrix.rows(), j_matrix.cols()),
620 },
621 ));
622 }
623
624 for i in 0..workspace.num_equations {
625 for j in 0..workspace.num_variables {
626 workspace.augmented_j[(i, j)] = j_matrix[(i, j)];
627 }
628 workspace.augmented_b[(i, 0)] = f_neg[(i, 0)];
629 }
630 for i in 0..workspace.num_variables {
631 for j in 0..workspace.num_variables {
632 workspace.augmented_j[(workspace.num_equations + i, j)] = 0.0;
633 }
634 workspace.augmented_j[(workspace.num_equations + i, i)] = sqrt_reg;
635 workspace.augmented_b[(workspace.num_equations + i, 0)] = 0.0;
636 }
637
638 workspace
639 .augmented_j
640 .solve_least_squares_qr_with_info_with_parallel(
641 &workspace.augmented_b,
642 parallel,
643 )
644 }
645 None => {
646 let mut augmented_j =
647 Matrix::new(j_matrix.rows() + j_matrix.cols(), j_matrix.cols());
648 let mut augmented_b = Matrix::new(j_matrix.rows() + j_matrix.cols(), 1);
649 for i in 0..j_matrix.rows() {
650 for j in 0..j_matrix.cols() {
651 augmented_j[(i, j)] = j_matrix[(i, j)];
652 }
653 augmented_b[(i, 0)] = f_neg[(i, 0)];
654 }
655 for i in 0..j_matrix.cols() {
656 augmented_j[(j_matrix.rows() + i, i)] = sqrt_reg;
657 augmented_b[(j_matrix.rows() + i, 0)] = 0.0;
658 }
659 augmented_j.solve_least_squares_qr_with_info_with_parallel(
660 &augmented_b,
661 parallel,
662 )
663 }
664 };
665
666 match reg_result {
667 Ok((delta, info)) => {
668 last_solution = Some(delta);
669 if !is_ill_conditioned(&info) {
670 return Ok(last_solution.unwrap());
671 }
672 }
673 Err(err) => {
674 last_err = Some(err);
675 }
676 }
677 }
678
679 if let Some(delta) = last_solution {
680 return Ok(delta);
681 }
682
683 Err(LeastSquaresSolveError::Singular {
684 qr_err: qr_err
685 .or_else(|| {
686 unreg_info
687 .map(|info| format!("QR ill-conditioned (rank {}, cond {:.2e})", info.rank, info.cond_est))
688 })
689 .unwrap_or_else(|| "QR least squares failed".to_string()),
690 normal_eq_err: last_err.unwrap_or_else(|| "Regularized QR failed".to_string()),
691 })
692 }
693
694 fn map_initial_guess(
695 &self,
696 initial_guess: HashMap<String, f64>,
697 ) -> Result<HashMap<VarId, f64>, SolverError> {
698 let mut unknown = Vec::new();
699 let mut vars = HashMap::new();
700
701 for (name, value) in initial_guess {
702 if let Some(id) = self.var_table.get_id(&name) {
703 vars.insert(id, value);
704 } else {
705 unknown.push(name);
706 }
707 }
708
709 if !unknown.is_empty() {
710 unknown.sort();
711 return Err(SolverError::InvalidInput(format!(
712 "Initial guess includes unknown variables: {unknown:?}"
713 )));
714 }
715
716 self.validate_initial_guess(&vars)?;
717 Ok(vars)
718 }
719
720 fn validate_initial_guess(&self, vars: &HashMap<VarId, f64>) -> Result<(), SolverError> {
721 let mut missing = Vec::new();
722 for (idx, name) in self.var_table.names().iter().enumerate() {
723 if !vars.contains_key(&VarId::new(idx)) {
724 missing.push(name.clone());
725 }
726 }
727 if missing.is_empty() {
728 return Ok(());
729 }
730
731 missing.sort();
732 Err(SolverError::InvalidInput(format!(
733 "Initial guess missing values for variables: {missing:?}"
734 )))
735 }
736
737 fn values_by_name(&self, vars: &HashMap<VarId, f64>) -> HashMap<String, f64> {
738 let mut values = HashMap::with_capacity(self.var_table.len());
739 for (idx, name) in self.var_table.names().iter().enumerate() {
740 if let Some(value) = vars.get(&VarId::new(idx)) {
741 values.insert(name.clone(), *value);
742 }
743 }
744 values
745 }
746
747 fn build_diagnostic(
748 &self,
749 message: String,
750 iterations: usize,
751 vars: &HashMap<VarId, f64>,
752 residuals_matrix: &Matrix,
753 ) -> SolverRunDiagnostic {
754 let mut residuals = Vec::with_capacity(residuals_matrix.rows());
755 for i in 0..residuals_matrix.rows() {
756 residuals.push(residuals_matrix[(i, 0)]);
757 }
758
759 SolverRunDiagnostic {
760 message,
761 iterations,
762 error: residuals_matrix.norm(),
763 values: self.values_by_name(vars),
764 residuals,
765 }
766 }
767
768 fn line_search(
782 &self,
783 jacobian: &Jacobian,
784 vars: &HashMap<VarId, f64>,
785 delta: &Matrix,
786 current_error: f64,
787 f_vals: &mut Matrix,
788 ) -> Result<f64, SolverError> {
789 let mut alpha = 1.0; let mut new_vars = vars.clone();
791
792 for _ in 0..20 {
794 new_vars.clone_from(vars);
795 for (i, &var_id) in self.variables.iter().enumerate() {
797 new_vars.insert(var_id, vars[&var_id] + alpha * delta[(i, 0)]);
798 }
799
800 jacobian.evaluate_functions_checked_into(&new_vars, f_vals)?;
802 let new_error = f_vals.norm();
803 if new_error < current_error * (1.0 - 0.5 * alpha) {
805 return Ok(alpha);
806 }
807 alpha *= 0.5;
809
810 if alpha < 1e-10 {
812 break;
813 }
814 }
815
816 Ok(alpha.max(self.min_damping))
818 }
819
820 pub fn solve_with_line_search(
825 &self,
826 initial_guess: HashMap<String, f64>,
827 ) -> Result<Solution, SolverError> {
828 let vars = self.map_initial_guess(initial_guess)?;
829 self.run_in_mode(|| self.solve_modified_internal(vars, true))
830 }
831}
832
833#[cfg(test)]
834mod tests {
835 use super::*;
836 use crate::compiler::Compiler;
837 use crate::exp::Exp;
838 use crate::Mode;
839
840 struct TestRng {
841 state: u64,
842 }
843
844 impl TestRng {
845 fn new(seed: u64) -> Self {
846 Self { state: seed }
847 }
848
849 fn next_u32(&mut self) -> u32 {
850 self.state = self
851 .state
852 .wrapping_mul(6364136223846793005)
853 .wrapping_add(1);
854 (self.state >> 32) as u32
855 }
856
857 fn next_f64(&mut self) -> f64 {
858 let v = self.next_u32() as f64 / u32::MAX as f64;
859 2.0 * v - 1.0
860 }
861 }
862
863 fn linear_equation(coeffs: &[f64], vars: &[Exp], rhs: f64) -> Exp {
864 let mut sum = Exp::val(0.0);
865 for (coeff, var) in coeffs.iter().zip(vars.iter()) {
866 let term = Exp::mul(Exp::val(*coeff), var.clone());
867 sum = Exp::add(sum, term);
868 }
869 Exp::sub(sum, Exp::val(rhs))
870 }
871
872 fn test_modes() -> Vec<Mode> {
873 let threads = std::thread::available_parallelism()
874 .map(|n| n.get())
875 .unwrap_or(1)
876 .min(4)
877 .max(1);
878 vec![Mode::Serial, Mode::Parallel { thread_count: threads }]
879 }
880
881 fn assert_solution_close(serial: &Solution, parallel: &Solution, tol: f64) {
882 assert_eq!(serial.values.len(), parallel.values.len());
883 for (name, value) in &serial.values {
884 let other = parallel.values.get(name).expect("missing variable");
885 let diff = (value - other).abs();
886 assert!(diff <= tol, "var {name} mismatch: {value} vs {other}");
887 }
888 }
889
890 fn assert_error_matches(serial: &SolverError, parallel: &SolverError) {
891 assert_eq!(
892 std::mem::discriminant(serial),
893 std::mem::discriminant(parallel)
894 );
895 match (serial, parallel) {
896 (SolverError::InvalidInput(a), SolverError::InvalidInput(b)) => assert_eq!(a, b),
897 (SolverError::MissingVariable(a), SolverError::MissingVariable(b)) => {
898 assert_eq!(a, b)
899 }
900 _ => {}
901 }
902 }
903
904 fn solver_for_mode(equations: Vec<Exp>, mode: Mode) -> NewtonRaphsonSolver {
905 let compiled = Compiler::compile(&equations).expect("compile failed");
906 NewtonRaphsonSolver::new_with_mode(compiled, mode).expect("failed to set solver mode")
907 }
908
909 fn solver_for_with_vars_mode(
910 equations: Vec<Exp>,
911 variables: &[&str],
912 mode: Mode,
913 ) -> NewtonRaphsonSolver {
914 let compiled = Compiler::compile(&equations).expect("compile failed");
915 NewtonRaphsonSolver::new_with_variables_and_mode(compiled, variables, mode)
916 .expect("failed to select solve variables or set mode")
917 }
918
919 #[test]
920 fn test_simple_system() {
921 let mut serial_solution: Option<Solution> = None;
922 for mode in test_modes() {
923 let is_serial = matches!(mode, Mode::Serial);
924 let x = Exp::var("x");
927 let y = Exp::var("y");
928
929 let eq1 = Exp::sub(
930 Exp::add(Exp::power(x.clone(), 2.0), Exp::power(y.clone(), 2.0)),
931 Exp::val(1.0),
932 );
933 let eq2 = Exp::sub(Exp::mul(x.clone(), y.clone()), Exp::val(0.25));
934
935 let solver = solver_for_mode(vec![eq1, eq2], mode);
936
937 let mut initial = HashMap::new();
938 initial.insert("x".to_string(), 0.5);
939 initial.insert("y".to_string(), 0.866);
940
941 let solution = match solver.solve(initial.clone()) {
942 Ok(sol) => sol,
943 Err(_) => match solver.solve_with_line_search(initial) {
944 Ok(sol) => sol,
945 Err(e) => panic!("Failed to solve: {:?}", e),
946 },
947 };
948 if is_serial {
949 serial_solution = Some(solution.clone());
950 } else {
951 let serial = serial_solution.as_ref().expect("serial solution missing");
952 assert_solution_close(serial, &solution, 1e-8);
953 }
954 assert!(solution.converged);
955 assert!(solution.error < 1e-10);
956
957 let x_sol = solution.values.get("x").copied().unwrap();
958 let y_sol = solution.values.get("y").copied().unwrap();
959 assert!((x_sol * x_sol + y_sol * y_sol - 1.0).abs() < 1e-10);
960 assert!((x_sol * y_sol - 0.25).abs() < 1e-10);
961 }
962 }
963
964 #[test]
965 fn test_transcendental_system() {
966 let mut serial_solution: Option<Solution> = None;
967 for mode in test_modes() {
968 let is_serial = matches!(mode, Mode::Serial);
969 let x = Exp::var("x");
971 let y = Exp::var("y");
972
973 let eq1 = Exp::sub(Exp::sin(x.clone()), Exp::mul(y.clone(), Exp::val(2.0)));
974 let eq2 = Exp::sub(Exp::cos(y.clone()), x.clone());
975
976 let solver = solver_for_mode(vec![eq1, eq2], mode)
977 .with_tolerance(1e-8)
978 .with_max_iterations(50)
979 .with_regularization(1e-8);
980
981 let mut initial = HashMap::new();
982 initial.insert("x".to_string(), 0.5);
983 initial.insert("y".to_string(), 0.25);
984
985 let solution = solver
986 .solve_with_line_search(initial)
987 .expect("Failed to solve transcendental system");
988
989 if is_serial {
990 serial_solution = Some(solution.clone());
991 } else {
992 let serial = serial_solution.as_ref().expect("serial solution missing");
993 assert_solution_close(serial, &solution, 1e-8);
994 }
995 assert!(solution.converged);
996 let x_sol = solution.values.get("x").copied().unwrap();
997 let y_sol = solution.values.get("y").copied().unwrap();
998 assert!((x_sol.sin() - 2.0 * y_sol).abs() < 1e-8);
999 assert!((y_sol.cos() - x_sol).abs() < 1e-8);
1000 }
1001 }
1002
1003 #[test]
1004 fn test_solver_errors_on_missing_variable() {
1005 let mut serial_error: Option<SolverError> = None;
1006 for mode in test_modes() {
1007 let is_serial = matches!(mode, Mode::Serial);
1008 let x = Exp::var("x");
1011 let a = Exp::var("a");
1012 let eq = Exp::sub(x.clone(), a.clone());
1013
1014 let solver = solver_for_with_vars_mode(vec![eq], &["x"], mode);
1015
1016 let mut initial = HashMap::new();
1017 initial.insert("x".to_string(), 1.0);
1018
1019 let err = solver.solve(initial).expect_err("expected missing variable error");
1020 if is_serial {
1021 serial_error = Some(err.clone());
1022 } else {
1023 let serial = serial_error.as_ref().expect("serial error missing");
1024 assert_error_matches(serial, &err);
1025 }
1026 match err {
1027 SolverError::InvalidInput(msg) => {
1028 assert!(msg.contains("a"), "unexpected message: {msg}");
1029 }
1030 other => panic!("expected InvalidInput, got {:?}", other),
1031 }
1032 }
1033 }
1034
1035 #[test]
1036 fn test_solver_invalid_input_on_missing_initial_guess_variable() {
1037 let mut serial_error: Option<SolverError> = None;
1038 for mode in test_modes() {
1039 let is_serial = matches!(mode, Mode::Serial);
1040 let x = Exp::var("x");
1041 let y = Exp::var("y");
1042 let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));
1043
1044 let solver = solver_for_mode(vec![eq], mode);
1045
1046 let mut initial = HashMap::new();
1047 initial.insert("x".to_string(), 0.0);
1048
1049 let err = solver
1050 .solve(initial)
1051 .expect_err("expected invalid input due to missing y");
1052 if is_serial {
1053 serial_error = Some(err.clone());
1054 } else {
1055 let serial = serial_error.as_ref().expect("serial error missing");
1056 assert_error_matches(serial, &err);
1057 }
1058 match err {
1059 SolverError::InvalidInput(msg) => {
1060 assert!(msg.contains("y"), "unexpected message: {msg}");
1061 }
1062 other => panic!("expected InvalidInput, got {:?}", other),
1063 }
1064 }
1065 }
1066
1067 #[test]
1068 fn test_try_with_equation_traces_invalid_length() {
1069 let mut serial_error: Option<SolverError> = None;
1070 for mode in test_modes() {
1071 let is_serial = matches!(mode, Mode::Serial);
1072 let x = Exp::var("x");
1073 let eq = Exp::sub(x.clone(), Exp::val(1.0));
1074
1075 let solver = solver_for_mode(vec![eq], mode);
1076
1077 let traces = vec![
1078 Some(EquationTrace {
1079 constraint_id: 0,
1080 description: "eq0".to_string(),
1081 }),
1082 None,
1083 ];
1084 let err = match solver.try_with_equation_traces(traces) {
1085 Ok(_) => panic!("expected invalid input for trace length mismatch"),
1086 Err(err) => err,
1087 };
1088 if is_serial {
1089 serial_error = Some(err.clone());
1090 } else {
1091 let serial = serial_error.as_ref().expect("serial error missing");
1092 assert_error_matches(serial, &err);
1093 }
1094 match err {
1095 SolverError::InvalidInput(msg) => {
1096 assert!(msg.contains("Equation trace length (2)"), "{}", msg);
1097 assert!(msg.contains("number of equations (1)"), "{}", msg);
1098 }
1099 other => panic!("expected InvalidInput, got {:?}", other),
1100 }
1101 }
1102 }
1103
1104 #[test]
1105 fn test_line_search_preserves_fixed_variables() {
1106 let mut serial_solution: Option<Solution> = None;
1107 for mode in test_modes() {
1108 let is_serial = matches!(mode, Mode::Serial);
1109 let x = Exp::var("x");
1114 let a = Exp::var("a");
1115 let eq = Exp::sub(x.clone(), a.clone());
1116
1117 let solver = solver_for_with_vars_mode(vec![eq], &["x"], mode);
1118
1119 let mut initial = HashMap::new();
1120 initial.insert("x".to_string(), 0.0);
1121 initial.insert("a".to_string(), 2.0);
1122
1123 let solution = solver
1124 .solve_with_line_search(initial)
1125 .expect("solver should converge when fixed variables are provided");
1126 if is_serial {
1127 serial_solution = Some(solution.clone());
1128 } else {
1129 let serial = serial_solution.as_ref().expect("serial solution missing");
1130 assert_solution_close(serial, &solution, 1e-10);
1131 }
1132 assert!(solution.converged);
1133 assert!((solution.values.get("x").copied().unwrap() - 2.0).abs() < 1e-10);
1134 }
1135 }
1136
1137 #[test]
1138 fn test_least_squares_qr_handles_ill_conditioned_overdetermined_system_without_regularization()
1139 {
1140 let mut serial_solution: Option<Solution> = None;
1141 for mode in test_modes() {
1142 let is_serial = matches!(mode, Mode::Serial);
1143 let eps = 2f64.powi(-27); let x = Exp::var("x");
1150 let y = Exp::var("y");
1151
1152 let eq1 = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(2.0));
1155 let eq2 = Exp::sub(
1156 Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 + eps))),
1157 Exp::val(2.0 + eps),
1158 );
1159 let eq3 = Exp::sub(
1160 Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 - eps))),
1161 Exp::val(2.0 - eps),
1162 );
1163
1164 let solver = solver_for_mode(vec![eq1, eq2, eq3], mode)
1165 .with_regularization(0.0)
1166 .with_damping(1.0)
1167 .with_max_iterations(10)
1168 .with_tolerance(1e-10);
1169
1170 let mut initial = HashMap::new();
1171 initial.insert("x".to_string(), 0.0);
1172 initial.insert("y".to_string(), 0.0);
1173
1174 let solution = solver
1175 .solve(initial)
1176 .expect("expected solver to converge with QR least squares");
1177
1178 if is_serial {
1179 serial_solution = Some(solution.clone());
1180 } else {
1181 let serial = serial_solution.as_ref().expect("serial solution missing");
1182 assert_solution_close(serial, &solution, 1e-8);
1183 }
1184 assert!(solution.converged, "{:?}", solution);
1185 assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-8);
1186 assert!((solution.values.get("y").copied().unwrap() - 1.0).abs() < 1e-8);
1187 }
1188 }
1189
1190 #[test]
1191 fn test_underconstrained_system_returns_min_norm_solution() {
1192 let mut serial_solution: Option<Solution> = None;
1193 for mode in test_modes() {
1194 let is_serial = matches!(mode, Mode::Serial);
1195 let x = Exp::var("x");
1198 let y = Exp::var("y");
1199
1200 let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));
1201 let solver = solver_for_mode(vec![eq], mode)
1202 .with_tolerance(1e-12)
1203 .with_max_iterations(10);
1204
1205 let mut initial = HashMap::new();
1206 initial.insert("x".to_string(), 0.0);
1207 initial.insert("y".to_string(), 0.0);
1208
1209 let solution = solver.solve(initial).expect("expected solver to converge");
1210 if is_serial {
1211 serial_solution = Some(solution.clone());
1212 } else {
1213 let serial = serial_solution.as_ref().expect("serial solution missing");
1214 assert_solution_close(serial, &solution, 1e-10);
1215 }
1216 assert!(solution.converged, "{:?}", solution);
1217 assert!((solution.values.get("x").copied().unwrap() - 0.5).abs() < 1e-10);
1218 assert!((solution.values.get("y").copied().unwrap() - 0.5).abs() < 1e-10);
1219 }
1220 }
1221
1222 #[test]
1223 fn test_rank_deficient_overconstrained_recovers_with_regularization() {
1224 let mut serial_solution: Option<Solution> = None;
1225 for mode in test_modes() {
1226 let is_serial = matches!(mode, Mode::Serial);
1227 let x = Exp::var("x");
1229 let y = Exp::var("y");
1230 let y_zero = Exp::mul(y.clone(), Exp::val(0.0));
1231
1232 let eq1 = Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0));
1233 let eq2 = Exp::sub(
1234 Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()),
1235 Exp::val(2.0),
1236 );
1237 let eq3 = Exp::sub(
1238 Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()),
1239 Exp::val(3.0),
1240 );
1241
1242 let solver = solver_for_mode(vec![eq1, eq2, eq3], mode)
1243 .with_tolerance(1e-12)
1244 .with_max_iterations(10);
1245
1246 let mut initial = HashMap::new();
1247 initial.insert("x".to_string(), 0.0);
1248 initial.insert("y".to_string(), 0.0);
1249
1250 let solution = solver.solve(initial).expect("expected solver to converge");
1251 if is_serial {
1252 serial_solution = Some(solution.clone());
1253 } else {
1254 let serial = serial_solution.as_ref().expect("serial solution missing");
1255 assert_solution_close(serial, &solution, 1e-10);
1256 }
1257 assert!(solution.converged, "{:?}", solution);
1258 assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-10);
1259 assert!((solution.values.get("y").copied().unwrap() - 0.0).abs() < 1e-10);
1260 }
1261 }
1262
#[test]
fn test_rank_deficient_overconstrained_fails_without_regularization() {
    // Three proportional equations (x = 1, 2x = 2, 3x = 3) where `y` only appears
    // multiplied by zero, giving a Jacobian with an all-zero `y` column. With
    // regularization explicitly set to 0.0 the solver must report SingularMatrix,
    // and every mode must produce an error matching the serial one.
    let mut reference: Option<SolverError> = None;

    for mode in test_modes() {
        let serial_run = matches!(mode, Mode::Serial);

        let x = Exp::var("x");
        let y_zero = Exp::mul(Exp::var("y"), Exp::val(0.0));
        let equations = vec![
            Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0)),
            Exp::sub(
                Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()),
                Exp::val(2.0),
            ),
            Exp::sub(
                Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()),
                Exp::val(3.0),
            ),
        ];

        let solver = solver_for_mode(equations, mode)
            .with_regularization(0.0)
            .with_tolerance(1e-12)
            .with_max_iterations(10);
        let initial = HashMap::from([("x".to_string(), 0.0), ("y".to_string(), 0.0)]);
        let err = solver
            .solve(initial)
            .expect_err("expected singular matrix error");

        // Serial error is the reference; relies on Mode::Serial being yielded first.
        if !serial_run {
            let serial = reference.as_ref().expect("serial error missing");
            assert_error_matches(serial, &err);
        } else {
            reference = Some(err.clone());
        }

        assert!(matches!(err, SolverError::SingularMatrix(_)), "{:?}", err);
    }
}
1301
#[test]
fn test_overconstrained_consistent_system_solves() {
    // Two equations in one unknown (x = 1 and 2x = 2). The system is overdetermined
    // but consistent, so every mode should converge to the exact root x = 1.
    let mut reference: Option<Solution> = None;

    for mode in test_modes() {
        let serial_run = matches!(mode, Mode::Serial);

        let x = Exp::var("x");
        let equations = vec![
            Exp::sub(x.clone(), Exp::val(1.0)),
            Exp::sub(Exp::mul(x, Exp::val(2.0)), Exp::val(2.0)),
        ];

        let solver = solver_for_mode(equations, mode)
            .with_tolerance(1e-12)
            .with_max_iterations(10);
        let initial = HashMap::from([("x".to_string(), 0.0)]);
        let solution = solver.solve(initial).expect("expected solver to converge");

        // Serial result is the reference; relies on Mode::Serial being yielded first.
        if !serial_run {
            let serial = reference.as_ref().expect("serial solution missing");
            assert_solution_close(serial, &solution, 1e-10);
        } else {
            reference = Some(solution.clone());
        }

        assert!(solution.converged, "{:?}", solution);
        let x_value = solution.values.get("x").copied().unwrap();
        assert!((x_value - 1.0).abs() < 1e-10);
    }
}
1330
#[test]
fn test_random_square_linear_systems() {
    // Solves several seeded random `n x n` linear systems `A x = b` whose diagonal is
    // boosted by 2.0, and checks that every mode recovers the generating `x_true`.
    // Non-serial modes must additionally agree with the serial run case by case.
    let mut serial_solutions: Vec<Solution> = Vec::new();
    for mode in test_modes() {
        let is_serial = matches!(mode, Mode::Serial);
        // Re-seeded per mode so every mode sees the identical sequence of systems.
        let mut rng = TestRng::new(0x51ab_1e55_cafe_f00d);
        for case_index in 0..5 {
            let n = 3;

            // RNG draw order: all of A (row-major), then x_true.
            let mut a: Vec<Vec<f64>> = (0..n)
                .map(|_| (0..n).map(|_| rng.next_f64()).collect())
                .collect();
            for (i, row) in a.iter_mut().enumerate() {
                row[i] += 2.0;
            }

            // Sized from `n` rather than three hard-coded draws, so changing the
            // dimension cannot silently desynchronize x_true from A.
            let x_true: Vec<f64> = (0..n).map(|_| rng.next_f64()).collect();
            let b: Vec<f64> = a
                .iter()
                .map(|row| row.iter().zip(&x_true).fold(0.0, |acc, (aij, xj)| acc + aij * xj))
                .collect();

            let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
            let equations: Vec<Exp> = a
                .iter()
                .zip(&b)
                .map(|(row, &rhs)| linear_equation(row, &vars, rhs))
                .collect();

            let solver = solver_for_mode(equations, mode.clone())
                .with_tolerance(1e-12)
                .with_max_iterations(20);

            let initial: HashMap<String, f64> =
                (0..n).map(|i| (format!("x{i}"), 0.0)).collect();

            let solution = solver.solve(initial).expect("expected to solve");
            if is_serial {
                serial_solutions.push(solution.clone());
            } else {
                // Relies on `test_modes()` yielding Mode::Serial first; fail with a clear
                // message rather than an index-out-of-bounds panic if it does not.
                let serial = serial_solutions
                    .get(case_index)
                    .expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-8);
            }

            assert!(solution.converged, "{:?}", solution);
            for (i, &expected) in x_true.iter().enumerate() {
                let value = solution.values.get(&format!("x{i}")).copied().unwrap();
                assert!(
                    (value - expected).abs() < 1e-8,
                    "case {case_index}, x{i}: got {value}, expected {expected}"
                );
            }
        }
    }
}
1382
#[test]
fn test_random_overconstrained_linear_systems() {
    // Solves several seeded random consistent `m x n` (m > n) linear systems
    // `A x = b`, with b built from a known x_true, and checks that every mode
    // recovers x_true. Non-serial modes must also agree with the serial run.
    let mut serial_solutions: Vec<Solution> = Vec::new();
    for mode in test_modes() {
        let is_serial = matches!(mode, Mode::Serial);
        // Re-seeded per mode so every mode sees the identical sequence of systems.
        let mut rng = TestRng::new(0xa11c_e551_dead_beef);
        let m = 5;
        let n = 3;

        for case_index in 0..5 {
            // RNG draw order: all of A (row-major), then x_true.
            let mut a: Vec<Vec<f64>> = (0..m)
                .map(|_| (0..n).map(|_| rng.next_f64()).collect())
                .collect();
            // Boost the leading n x n diagonal so the columns stay well separated.
            for (i, row) in a.iter_mut().take(n).enumerate() {
                row[i] += 2.0;
            }

            // Sized from `n` rather than three hard-coded draws, so changing the
            // dimension cannot silently desynchronize x_true from A.
            let x_true: Vec<f64> = (0..n).map(|_| rng.next_f64()).collect();
            let b: Vec<f64> = a
                .iter()
                .map(|row| row.iter().zip(&x_true).fold(0.0, |acc, (aij, xj)| acc + aij * xj))
                .collect();

            let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
            let equations: Vec<Exp> = a
                .iter()
                .zip(&b)
                .map(|(row, &rhs)| linear_equation(row, &vars, rhs))
                .collect();

            let solver = solver_for_mode(equations, mode.clone())
                .with_tolerance(1e-12)
                .with_max_iterations(20);

            let initial: HashMap<String, f64> =
                (0..n).map(|i| (format!("x{i}"), 0.0)).collect();

            let solution = solver.solve(initial).expect("expected to solve");
            if is_serial {
                serial_solutions.push(solution.clone());
            } else {
                // Relies on `test_modes()` yielding Mode::Serial first; fail with a clear
                // message rather than an index-out-of-bounds panic if it does not.
                let serial = serial_solutions
                    .get(case_index)
                    .expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-8);
            }

            assert!(solution.converged, "{:?}", solution);
            for (i, &expected) in x_true.iter().enumerate() {
                let value = solution.values.get(&format!("x{i}")).copied().unwrap();
                assert!(
                    (value - expected).abs() < 1e-8,
                    "case {case_index}, x{i}: got {value}, expected {expected}"
                );
            }
        }
    }
}
1438
#[test]
fn test_random_underdetermined_min_norm_structure() {
    // Two equations in four unknowns: x0 = b0 and x1 = b1, where x2 and x3 appear
    // only multiplied by zero, so their Jacobian columns are identically zero.
    // Starting from all zeros, the solver is expected to solve x0 and x1 exactly and
    // leave x2 and x3 at 0 (the minimum-norm solution).
    let mut serial_solutions: Vec<Solution> = Vec::new();
    for mode in test_modes() {
        let is_serial = matches!(mode, Mode::Serial);
        // Re-seeded per mode so every mode sees the identical sequence of systems.
        let mut rng = TestRng::new(0x0ddc_affe_fade_bead);
        let n = 4;
        let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();

        for case_index in 0..5 {
            let b0 = rng.next_f64();
            let b1 = rng.next_f64();
            let x2_zero = Exp::mul(vars[2].clone(), Exp::val(0.0));
            let x3_zero = Exp::mul(vars[3].clone(), Exp::val(0.0));
            let eq1 = Exp::sub(Exp::add(vars[0].clone(), x2_zero), Exp::val(b0));
            let eq2 = Exp::sub(Exp::add(vars[1].clone(), x3_zero), Exp::val(b1));

            let solver = solver_for_mode(vec![eq1, eq2], mode.clone())
                .with_tolerance(1e-12)
                .with_max_iterations(20);

            let initial: HashMap<String, f64> =
                (0..n).map(|i| (format!("x{i}"), 0.0)).collect();

            let solution = solver.solve(initial).expect("expected to solve");
            if is_serial {
                serial_solutions.push(solution.clone());
            } else {
                // Relies on `test_modes()` yielding Mode::Serial first; fail with a clear
                // message rather than an index-out-of-bounds panic if it does not.
                let serial = serial_solutions
                    .get(case_index)
                    .expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-10);
            }

            assert!(solution.converged, "{:?}", solution);
            let value = |name: &str| solution.values.get(name).copied().unwrap();
            assert!(
                (value("x0") - b0).abs() < 1e-10,
                "case {case_index}: x0 = {}, expected {b0}",
                value("x0")
            );
            assert!(
                (value("x1") - b1).abs() < 1e-10,
                "case {case_index}: x1 = {}, expected {b1}",
                value("x1")
            );
            assert!(
                value("x2").abs() < 1e-10,
                "case {case_index}: x2 = {}, expected 0",
                value("x2")
            );
            assert!(
                value("x3").abs() < 1e-10,
                "case {case_index}: x3 = {}, expected 0",
                value("x3")
            );
        }
    }
}
1479}