Skip to main content

constraint_solver/
solver.rs

1/*
2MIT License
3
4Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby
5
6Permission is hereby granted, free of charge, to any person obtaining a copy
7of this software and associated documentation files (the "Software"), to deal
8in the Software without restriction, including without limitation the rights
9to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10copies of the Software, and to permit persons to whom the Software is
11furnished to do so, subject to the following conditions:
12
13The above copyright notice and this permission notice shall be included in all
14copies or substantial portions of the Software.
15
16THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22SOFTWARE.
23*/
24
25use crate::compiler::{CompiledExp, CompiledSystem, VarId, VarTable};
26use crate::exp::MissingVarError;
27use crate::jacobian::Jacobian;
28use crate::matrix::{LeastSquaresQrInfo, Matrix, MatrixError};
29use crate::mode::{build_thread_pool, Mode};
30use rayon::ThreadPool;
31use std::collections::HashMap;
32
/// Pre-allocated storage for the regularized least-squares fallback solve.
///
/// Holds the augmented system [J; sqrt(lambda) I] and its right-hand side
/// [-f; 0] so they are not reallocated on every solver iteration.
struct LeastSquaresWorkspace {
    /// Augmented Jacobian, (num_equations + num_variables) x num_variables.
    augmented_j: Matrix,
    /// Augmented right-hand side, (num_equations + num_variables) x 1.
    augmented_b: Matrix,
    /// Row count of the unaugmented Jacobian (number of equations).
    num_equations: usize,
    /// Column count of the Jacobian (number of unknowns).
    num_variables: usize,
}
39
40impl LeastSquaresWorkspace {
41    fn new(num_equations: usize, num_variables: usize, sqrt_reg: f64) -> Self {
42        let mut augmented_j = Matrix::new(num_equations + num_variables, num_variables);
43        for i in 0..num_variables {
44            augmented_j[(num_equations + i, i)] = sqrt_reg;
45        }
46
47        LeastSquaresWorkspace {
48            augmented_j,
49            augmented_b: Matrix::new(num_equations + num_variables, 1),
50            num_equations,
51            num_variables,
52        }
53    }
54}
55
/// Internal error type for the least-squares delta solve.
enum LeastSquaresSolveError {
    /// Workspace dimensions disagreed with the Jacobian, or other invalid matrix input.
    InvalidInput(MatrixError),
    /// Both the plain QR solve and every regularized retry failed; carries
    /// the message from each attempt for diagnostics.
    Singular { qr_err: String, normal_eq_err: String },
}
60
/// Newton-Raphson solver for systems of nonlinear equations
///
/// This solver can handle:
/// - Square systems (equations == variables)
/// - Under-constrained systems (equations < variables) - uses least squares
/// - Over-constrained systems (equations > variables) - uses least squares
///
/// The solver uses adaptive damping and regularization for numerical stability.
///
/// Construct via [`NewtonRaphsonSolver::new`] (all registered variables are
/// unknowns) or [`NewtonRaphsonSolver::new_with_variables`] (explicit subset);
/// tune with the `with_*` builder methods.
pub struct NewtonRaphsonSolver {
    /// The system of equations to solve (each equation should equal zero)
    equations: Vec<CompiledExp>,
    /// Variable IDs that correspond to unknowns in the system
    variables: Vec<VarId>,
    /// Registry of all variables referenced by the compiled system
    var_table: VarTable,
    /// Maximum number of iterations before giving up
    max_iterations: usize,
    /// Convergence tolerance (solution found when |f(x)| < tolerance)
    tolerance: f64,
    /// Initial damping factor (step size multiplier, 0 < damping <= 1)
    /// Lower values make convergence more stable but slower
    damping_factor: f64,
    /// Minimum allowed damping before declaring convergence failure
    min_damping: f64,
    /// Base regularization parameter used when the system is ill-conditioned
    regularization: f64,
    /// Optional metadata describing the source of each equation in the system
    equation_traces: Vec<Option<EquationTrace>>,
    /// Execution mode for linear algebra (serial vs parallel)
    mode: Mode,
    /// Optional thread pool for parallel execution
    pool: Option<ThreadPool>,
}
94
/// Solution result containing the solved variable values and convergence info
#[derive(Debug, Clone)]
pub struct Solution {
    /// Final values of all variables (variable name -> value)
    pub values: HashMap<String, f64>,
    /// Number of iterations taken to converge
    pub iterations: usize,
    /// Final residual error |f(x)|
    pub error: f64,
    /// Whether the solver successfully converged
    pub converged: bool,
    /// History of error values throughout the iteration process
    /// (one entry per iteration, in order).
    pub convergence_history: Vec<f64>,
}
109
/// Provides traceability information back to the constraint that generated an equation
#[derive(Debug, Clone)]
pub struct EquationTrace {
    /// Identifier of the originating constraint.
    pub constraint_id: usize,
    /// Human-readable description of the constraint for diagnostics.
    pub description: String,
}
116
/// Diagnostic data captured when the solver fails
#[derive(Debug, Clone)]
pub struct SolverRunDiagnostic {
    /// Human-readable failure description (includes iteration count / error).
    pub message: String,
    /// Number of iterations completed before the failure.
    pub iterations: usize,
    /// Residual norm |f(x)| at the end of the run.
    pub error: f64,
    /// Variable values at the end of the run (variable name -> value)
    pub values: HashMap<String, f64>,
    /// Per-equation residual values at the end of the run.
    pub residuals: Vec<f64>,
}
127
/// Possible solver error conditions
#[derive(Debug, Clone)]
pub enum SolverError {
    /// Jacobian matrix became singular and couldn't be inverted
    SingularMatrix(SolverRunDiagnostic),
    /// Failed to converge within the maximum number of iterations
    NoConvergence(SolverRunDiagnostic),
    /// A required variable was missing during evaluation
    MissingVariable(MissingVarError),
    /// Invalid input parameters or system setup
    /// (unknown variable names, bad trace lengths, thread-pool failures, ...)
    InvalidInput(String),
}
140
141impl From<MissingVarError> for SolverError {
142    fn from(value: MissingVarError) -> Self {
143        SolverError::MissingVariable(value)
144    }
145}
146
147impl From<MatrixError> for SolverError {
148    fn from(value: MatrixError) -> Self {
149        SolverError::InvalidInput(value.to_string())
150    }
151}
152
153impl NewtonRaphsonSolver {
154    /// Create a new Newton-Raphson solver with adaptive parameters
155    ///
156    /// Parameters are automatically adjusted based on system type:
157    /// - Over-constrained systems get more iterations, relaxed tolerance, conservative damping
158    /// - Normal systems use standard parameters for fast convergence
159    ///
160    /// # Arguments
161    /// * `compiled` - Compiled system of equations (each should evaluate to 0 at solution)
162    pub fn new(compiled: CompiledSystem) -> Self {
163        let variables = compiled.var_table.all_var_ids();
164        Self::build(compiled, variables)
165    }
166
167    /// Create a new solver with an explicit execution mode.
168    pub fn new_with_mode(compiled: CompiledSystem, mode: Mode) -> Result<Self, SolverError> {
169        let variables = compiled.var_table.all_var_ids();
170        Self::build_with_mode(compiled, variables, mode)
171    }
172
173    /// Create a solver while specifying which variables to solve for.
174    ///
175    /// Variables not listed here are treated as fixed parameters and must be
176    /// provided in the initial guess.
177    pub fn new_with_variables(
178        compiled: CompiledSystem,
179        variables: &[&str],
180    ) -> Result<Self, SolverError> {
181        let mut ids = Vec::with_capacity(variables.len());
182        let mut seen = std::collections::HashSet::new();
183        for name in variables {
184            let id = compiled
185                .var_table
186                .get_id(name)
187                .ok_or_else(|| SolverError::InvalidInput(format!("Unknown variable '{name}'")))?;
188            if seen.insert(id) {
189                ids.push(id);
190            }
191        }
192        Ok(Self::build(compiled, ids))
193    }
194
195    /// Create a solver with explicit variables and execution mode.
196    pub fn new_with_variables_and_mode(
197        compiled: CompiledSystem,
198        variables: &[&str],
199        mode: Mode,
200    ) -> Result<Self, SolverError> {
201        let mut ids = Vec::with_capacity(variables.len());
202        let mut seen = std::collections::HashSet::new();
203        for name in variables {
204            let id = compiled
205                .var_table
206                .get_id(name)
207                .ok_or_else(|| SolverError::InvalidInput(format!("Unknown variable '{name}'")))?;
208            if seen.insert(id) {
209                ids.push(id);
210            }
211        }
212
213        Self::build_with_mode(compiled, ids, mode)
214    }
215
216    fn build(compiled: CompiledSystem, variables: Vec<VarId>) -> Self {
217        // Allow both under-constrained and over-constrained systems
218        // We'll use least squares to solve them
219        let is_over_constrained = compiled.equations.len() > variables.len();
220
221        NewtonRaphsonSolver {
222            equations: compiled.equations,
223            variables,
224            var_table: compiled.var_table,
225            // Over-constrained systems are harder to converge, so give them more iterations
226            max_iterations: if is_over_constrained { 200 } else { 100 },
227            // Over-constrained systems use relaxed tolerance since exact solutions may not exist
228            tolerance: if is_over_constrained { 1e-8 } else { 1e-10 },
229            // Over-constrained systems use conservative damping for stability
230            damping_factor: if is_over_constrained { 0.7 } else { 1.0 },
231            min_damping: 0.001,
232            // Over-constrained systems need more regularization for numerical stability
233            regularization: if is_over_constrained { 1e-4 } else { 1e-8 },
234            equation_traces: Vec::new(),
235            mode: Mode::Serial,
236            pool: None,
237        }
238    }
239
240    fn build_with_mode(
241        compiled: CompiledSystem,
242        variables: Vec<VarId>,
243        mode: Mode,
244    ) -> Result<Self, SolverError> {
245        let mut solver = Self::build(compiled, variables);
246        solver.set_mode(mode)?;
247        Ok(solver)
248    }
249
250    /// Update the execution mode for this solver.
251    pub fn with_mode(mut self, mode: Mode) -> Result<Self, SolverError> {
252        self.set_mode(mode)?;
253        Ok(self)
254    }
255
    /// Returns the current execution mode.
    ///
    /// This is the configured mode; whether parallel execution is actually
    /// active depends on whether a thread pool was built for it.
    pub fn mode(&self) -> &Mode {
        &self.mode
    }
260
261    fn set_mode(&mut self, mode: Mode) -> Result<(), SolverError> {
262        let pool = build_thread_pool(&mode).map_err(SolverError::InvalidInput)?;
263        self.mode = mode;
264        self.pool = pool;
265        Ok(())
266    }
267
    /// True when a thread pool was built for the current mode, i.e. the
    /// linear-algebra routines should take their parallel code paths.
    fn parallel_enabled(&self) -> bool {
        self.pool.is_some()
    }
271
272    fn run_in_mode<F, R>(&self, f: F) -> R
273    where
274        F: FnOnce() -> R + Send,
275        R: Send,
276    {
277        if let Some(pool) = &self.pool {
278            pool.install(f)
279        } else {
280            f()
281        }
282    }
283
284    /// Attach metadata describing the origin of each equation in the system
285    pub fn try_with_equation_traces(
286        mut self,
287        traces: Vec<Option<EquationTrace>>,
288    ) -> Result<Self, SolverError> {
289        if !self.equations.is_empty() && traces.len() != self.equations.len() {
290            return Err(SolverError::InvalidInput(format!(
291                "Equation trace length ({}) does not match number of equations ({})",
292                traces.len(),
293                self.equations.len()
294            )));
295        }
296        self.equation_traces = traces;
297        Ok(self)
298    }
299
    /// Attach metadata describing the origin of each equation in the system
    ///
    /// # Panics
    /// Panics when the trace list length does not match the number of
    /// equations; use [`Self::try_with_equation_traces`] for the fallible
    /// variant.
    pub fn with_equation_traces(self, traces: Vec<Option<EquationTrace>>) -> Self {
        self.try_with_equation_traces(traces)
            .unwrap_or_else(|err| match err {
                SolverError::InvalidInput(msg) => panic!("{}", msg),
                other => panic!("{other:?}"),
            })
    }
308
309    /// Fetch trace metadata for a specific equation, if available
310    pub fn trace_for_equation(&self, equation_index: usize) -> Option<&EquationTrace> {
311        self.equation_traces
312            .get(equation_index)
313            .and_then(|entry| entry.as_ref())
314    }
315
316    /// Override the maximum number of iterations
317    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
318        self.max_iterations = max_iterations;
319        self
320    }
321
322    /// Override the convergence tolerance
323    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
324        self.tolerance = tolerance;
325        self
326    }
327
328    /// Override the damping factor (clamped to [0.1, 1.0] for stability)
329    pub fn with_damping(mut self, damping_factor: f64) -> Self {
330        self.damping_factor = damping_factor.clamp(0.1, 1.0);
331        self
332    }
333
334    /// Override the base regularization parameter used when the system is ill-conditioned
335    pub fn with_regularization(mut self, regularization: f64) -> Self {
336        self.regularization = regularization.max(0.0);
337        self
338    }
339
340    /// Solve the system using standard Newton-Raphson method
341    pub fn solve(&self, initial_guess: HashMap<String, f64>) -> Result<Solution, SolverError> {
342        let vars = self.map_initial_guess(initial_guess)?;
343        self.run_in_mode(|| self.solve_modified_internal(vars, false))
344    }
345
    /// Core Newton-Raphson solver implementation with optional line search
    ///
    /// # Arguments
    /// * `initial_guess` - Starting values for all variables, keyed by `VarId`
    ///   (must already be mapped/validated from names)
    /// * `use_line_search` - Whether to use backtracking line search for step size
    ///
    /// # Algorithm
    /// 1. Evaluate f(x) and check for convergence
    /// 2. Compute Jacobian J = df/dx
    /// 3. Solve linear system: J * delta = -f(x)
    ///    - Square systems: Direct LU decomposition
    ///    - Under/over-constrained: QR least squares (with adaptive regularization if needed)
    /// 4. Update: x_new = x_old + damping * delta
    /// 5. Adapt damping based on error change
    fn solve_modified_internal(
        &self,
        initial_guess: HashMap<VarId, f64>,
        use_line_search: bool,
    ) -> Result<Solution, SolverError> {
        self.validate_initial_guess(&initial_guess)?;

        // Working copy of the variable state; mutated in place each iteration.
        let mut vars = initial_guess.clone();
        let jacobian = Jacobian::new(self.equations.clone(), self.variables.clone());
        // Cache Jacobian storage to avoid per-iteration allocations.
        let mut jacobian_workspace = jacobian.workspace();
        let mut convergence_history = Vec::with_capacity(self.max_iterations);
        // Start with configured damping factor, will be adapted during iterations
        let mut damping = self.damping_factor;
        let mut last_error = f64::INFINITY;

        let num_equations = self.equations.len();
        let num_variables = self.variables.len();

        // Reusable buffers: f(x), -f(x), and a line-search scratch vector.
        let mut f_vals = Matrix::new(num_equations, 1);
        let mut f_neg = Matrix::new(num_equations, 1);
        jacobian_workspace.replace_jacobian(jacobian.evaluate_checked(&vars)?);

        let mut line_search_f_vals = Matrix::new(num_equations, 1);

        // Preallocate the augmented least-squares system only when it can be
        // needed: non-square system with regularization enabled.
        let mut least_squares_workspace =
            if self.regularization > 0.0 && num_equations != num_variables {
                Some(LeastSquaresWorkspace::new(
                    num_equations,
                    num_variables,
                    self.regularization.sqrt(),
                ))
            } else {
                None
            };
        let parallel = self.parallel_enabled();

        for iter in 0..self.max_iterations {
            // Step 1: Evaluate function values f(x)
            jacobian.evaluate_functions_checked_into(&vars, &mut f_vals)?;
            let error = f_vals.norm();
            convergence_history.push(error);

            // Check for convergence: |f(x)| < tolerance
            if error < self.tolerance {
                return Ok(Solution {
                    values: self.values_by_name(&vars),
                    iterations: iter + 1,
                    error,
                    converged: true,
                    convergence_history,
                });
            }

            // Check for stagnation (error not decreasing significantly)
            if iter > 0 && (error - last_error).abs() < self.tolerance * 0.1 {
                damping *= 0.5; // Reduce step size
                if damping < self.min_damping {
                    // Damping fell below the floor while making no progress:
                    // give up with a diagnostic rather than looping.
                    let diag = self.build_diagnostic(
                        format!(
                            "Solver stagnated at iteration {} with error {:.2e}",
                            iter + 1,
                            error
                        ),
                        iter + 1,
                        &vars,
                        &f_vals,
                    );
                    return Err(SolverError::NoConvergence(diag));
                }
            }

            // Adaptive damping based on error change
            if error > last_error * 1.5 {
                // Error increased significantly - reduce damping for stability
                damping *= 0.5;
            } else if error < last_error * 0.5 {
                // Error decreased significantly - can increase damping for faster convergence
                damping = (damping * 1.2).min(1.0);
            }

            last_error = error;

            // Right-hand side of the Newton system: -f(x).
            for i in 0..num_equations {
                f_neg[(i, 0)] = -f_vals[(i, 0)];
            }

            // Step 2: Solve linear system J * delta = -f(x)
            let delta = {
                let j_matrix = jacobian_workspace.jacobian();
                // Handle different system types with appropriate solving methods
                if j_matrix.rows() == j_matrix.cols() {
                    // Square system (equations == variables)
                    // Use standard LU decomposition: J * delta = -f
                    match j_matrix.solve_lu(&f_neg) {
                        Ok(d) => d,
                        Err(_) => {
                            // Matrix is singular - try with increased regularization.
                            // This modifies the cached Jacobian's diagonal in
                            // place; it is rebuilt at the end of the iteration.
                            let diag_size = j_matrix.rows().min(j_matrix.cols());
                            for i in 0..diag_size {
                                j_matrix[(i, i)] += self.regularization * 100.0;
                            }
                            match j_matrix.solve_lu(&f_neg) {
                                Ok(d) => d,
                                Err(e) => {
                                    let diag = self.build_diagnostic(
                                        format!(
                                            "Matrix is singular even with regularization: {e}"
                                        ),
                                        iter + 1,
                                        &vars,
                                        &f_vals,
                                    );
                                    return Err(SolverError::SingularMatrix(diag));
                                }
                            }
                        }
                    }
                } else {
                    // Non-square system: QR least squares with regularized fallback.
                    match self.solve_least_squares_delta(
                        j_matrix,
                        &f_neg,
                        least_squares_workspace.as_mut(),
                        parallel,
                    ) {
                        Ok(delta) => delta,
                        Err(LeastSquaresSolveError::InvalidInput(err)) => {
                            return Err(SolverError::InvalidInput(err.to_string()));
                        }
                        Err(LeastSquaresSolveError::Singular {
                            qr_err,
                            normal_eq_err,
                        }) => {
                            let diag = self.build_diagnostic(
                                format!(
                                    "Least squares system is singular (qr_err: {qr_err}; normal_eq_err: {normal_eq_err})"
                                ),
                                iter + 1,
                                &vars,
                                &f_vals,
                            );
                            return Err(SolverError::SingularMatrix(diag));
                        }
                    }
                }
            };

            // Step 4: Determine step size (damping or line search)
            let step_size = if use_line_search {
                // Use backtracking line search to find optimal step size
                self.line_search(&jacobian, &vars, &delta, error, &mut line_search_f_vals)?
            } else {
                // Use current damping factor as step size
                damping
            };

            // Step 5: Update variables: x_new = x_old + step_size * delta
            for (i, &var_id) in self.variables.iter().enumerate() {
                vars.insert(var_id, vars[&var_id] + step_size * delta[(i, 0)]);
            }

            // Additional convergence check: small step size indicates convergence
            if delta.norm() * step_size < self.tolerance {
                return Ok(Solution {
                    values: self.values_by_name(&vars),
                    iterations: iter + 1,
                    error,
                    converged: true,
                    convergence_history,
                });
            }

            // Refresh Jacobian for next iteration
            jacobian.evaluate_checked_in_workspace(&vars, &mut jacobian_workspace)?;
        }

        // Failed to converge within max iterations
        let residuals = jacobian.evaluate_functions_checked(&vars)?;
        let diag = self.build_diagnostic(
            format!(
                "Failed to converge after {} iterations. Final error: {:.2e}",
                self.max_iterations,
                residuals.norm()
            ),
            self.max_iterations,
            &vars,
            &residuals,
        );
        Err(SolverError::NoConvergence(diag))
    }
550
    /// Solve J * delta ~= -f in the least-squares sense, with a regularized
    /// fallback ladder when the plain QR solve is rank deficient or
    /// ill-conditioned.
    ///
    /// `workspace`, when present, supplies preallocated augmented matrices so
    /// the fallback does not allocate per iteration. `parallel` selects the
    /// parallel code path of the underlying QR routine.
    ///
    /// On total failure, returns `Singular` carrying the messages from both
    /// the plain QR attempt and the last regularized attempt.
    fn solve_least_squares_delta(
        &self,
        j_matrix: &Matrix,
        f_neg: &Matrix,
        workspace: Option<&mut LeastSquaresWorkspace>,
        parallel: bool,
    ) -> Result<Matrix, LeastSquaresSolveError> {
        let mut workspace = workspace;
        // Prefer QR-based least squares to avoid forming normal equations (J^T J), which can
        // severely amplify conditioning issues. If the QR solve is rank deficient or
        // ill-conditioned, fall back to a regularized augmented system
        // [J; sqrt(lambda) I] * delta ~= [-f; 0].
        const COND_LIMIT: f64 = 1e12;

        let full_rank = j_matrix.rows().min(j_matrix.cols());
        // A solve is considered unusable when rank-deficient or when the
        // condition estimate is non-finite or above COND_LIMIT.
        let is_ill_conditioned = |info: &LeastSquaresQrInfo| -> bool {
            info.rank < full_rank
                || !info.cond_est.is_finite()
                || info.cond_est > COND_LIMIT
        };

        let mut qr_err: Option<String> = None;
        let mut unreg_delta: Option<Matrix> = None;
        let mut unreg_info: Option<LeastSquaresQrInfo> = None;

        // First attempt: unregularized QR. Keep an ill-conditioned result
        // around as a last resort rather than discarding it.
        match j_matrix.solve_least_squares_qr_with_info_with_parallel(f_neg, parallel) {
            Ok((delta, info)) => {
                if !is_ill_conditioned(&info) {
                    return Ok(delta);
                }
                unreg_delta = Some(delta);
                unreg_info = Some(info);
            }
            Err(err) => {
                qr_err = Some(err);
            }
        }

        // Regularization disabled: return the ill-conditioned solution if we
        // got one, otherwise report the QR failure.
        if self.regularization <= 0.0 {
            if let Some(delta) = unreg_delta {
                return Ok(delta);
            }
            return Err(LeastSquaresSolveError::Singular {
                qr_err: qr_err.unwrap_or_else(|| "QR least squares failed".to_string()),
                normal_eq_err: "Regularization disabled".to_string(),
            });
        }

        // Escalating regularization ladder: lambda, 10*lambda, 100*lambda.
        let reg_factors = [1.0, 10.0, 100.0];
        let mut last_err: Option<String> = None;
        let mut last_solution: Option<Matrix> = None;

        for factor in reg_factors {
            let lambda = self.regularization * factor;
            if lambda <= 0.0 {
                continue;
            }
            let sqrt_reg = lambda.sqrt();

            let reg_result = match workspace.as_deref_mut() {
                Some(workspace) => {
                    // Workspace dimensions must match this Jacobian exactly.
                    if workspace.num_equations != j_matrix.rows()
                        || workspace.num_variables != j_matrix.cols()
                    {
                        return Err(LeastSquaresSolveError::InvalidInput(
                            MatrixError::DimensionMismatch {
                                operation: "least_squares",
                                left: (workspace.augmented_j.rows(), workspace.augmented_j.cols()),
                                right: (j_matrix.rows(), j_matrix.cols()),
                            },
                        ));
                    }

                    // Top block: copy J and -f into the reused buffers.
                    for i in 0..workspace.num_equations {
                        for j in 0..workspace.num_variables {
                            workspace.augmented_j[(i, j)] = j_matrix[(i, j)];
                        }
                        workspace.augmented_b[(i, 0)] = f_neg[(i, 0)];
                    }
                    // Bottom block: rewrite sqrt(lambda) * I (clearing any
                    // diagonal left over from a previous lambda) and zero rhs.
                    for i in 0..workspace.num_variables {
                        for j in 0..workspace.num_variables {
                            workspace.augmented_j[(workspace.num_equations + i, j)] = 0.0;
                        }
                        workspace.augmented_j[(workspace.num_equations + i, i)] = sqrt_reg;
                        workspace.augmented_b[(workspace.num_equations + i, 0)] = 0.0;
                    }

                    workspace
                        .augmented_j
                        .solve_least_squares_qr_with_info_with_parallel(
                            &workspace.augmented_b,
                            parallel,
                        )
                }
                None => {
                    // No workspace: build the augmented system fresh.
                    let mut augmented_j =
                        Matrix::new(j_matrix.rows() + j_matrix.cols(), j_matrix.cols());
                    let mut augmented_b = Matrix::new(j_matrix.rows() + j_matrix.cols(), 1);
                    for i in 0..j_matrix.rows() {
                        for j in 0..j_matrix.cols() {
                            augmented_j[(i, j)] = j_matrix[(i, j)];
                        }
                        augmented_b[(i, 0)] = f_neg[(i, 0)];
                    }
                    for i in 0..j_matrix.cols() {
                        augmented_j[(j_matrix.rows() + i, i)] = sqrt_reg;
                        augmented_b[(j_matrix.rows() + i, 0)] = 0.0;
                    }
                    augmented_j.solve_least_squares_qr_with_info_with_parallel(
                        &augmented_b,
                        parallel,
                    )
                }
            };

            match reg_result {
                Ok((delta, info)) => {
                    // Remember this solution even if ill-conditioned; a later,
                    // stronger lambda may do better, but this is usable.
                    last_solution = Some(delta);
                    if !is_ill_conditioned(&info) {
                        return Ok(last_solution.unwrap());
                    }
                }
                Err(err) => {
                    last_err = Some(err);
                }
            }
        }

        // All rungs were ill-conditioned: return the best-effort solution.
        if let Some(delta) = last_solution {
            return Ok(delta);
        }

        Err(LeastSquaresSolveError::Singular {
            qr_err: qr_err
                .or_else(|| {
                    unreg_info
                        .map(|info| format!("QR ill-conditioned (rank {}, cond {:.2e})", info.rank, info.cond_est))
                })
                .unwrap_or_else(|| "QR least squares failed".to_string()),
            normal_eq_err: last_err.unwrap_or_else(|| "Regularized QR failed".to_string()),
        })
    }
693
694    fn map_initial_guess(
695        &self,
696        initial_guess: HashMap<String, f64>,
697    ) -> Result<HashMap<VarId, f64>, SolverError> {
698        let mut unknown = Vec::new();
699        let mut vars = HashMap::new();
700
701        for (name, value) in initial_guess {
702            if let Some(id) = self.var_table.get_id(&name) {
703                vars.insert(id, value);
704            } else {
705                unknown.push(name);
706            }
707        }
708
709        if !unknown.is_empty() {
710            unknown.sort();
711            return Err(SolverError::InvalidInput(format!(
712                "Initial guess includes unknown variables: {unknown:?}"
713            )));
714        }
715
716        self.validate_initial_guess(&vars)?;
717        Ok(vars)
718    }
719
720    fn validate_initial_guess(&self, vars: &HashMap<VarId, f64>) -> Result<(), SolverError> {
721        let mut missing = Vec::new();
722        for (idx, name) in self.var_table.names().iter().enumerate() {
723            if !vars.contains_key(&VarId::new(idx)) {
724                missing.push(name.clone());
725            }
726        }
727        if missing.is_empty() {
728            return Ok(());
729        }
730
731        missing.sort();
732        Err(SolverError::InvalidInput(format!(
733            "Initial guess missing values for variables: {missing:?}"
734        )))
735    }
736
737    fn values_by_name(&self, vars: &HashMap<VarId, f64>) -> HashMap<String, f64> {
738        let mut values = HashMap::with_capacity(self.var_table.len());
739        for (idx, name) in self.var_table.names().iter().enumerate() {
740            if let Some(value) = vars.get(&VarId::new(idx)) {
741                values.insert(name.clone(), *value);
742            }
743        }
744        values
745    }
746
747    fn build_diagnostic(
748        &self,
749        message: String,
750        iterations: usize,
751        vars: &HashMap<VarId, f64>,
752        residuals_matrix: &Matrix,
753    ) -> SolverRunDiagnostic {
754        let mut residuals = Vec::with_capacity(residuals_matrix.rows());
755        for i in 0..residuals_matrix.rows() {
756            residuals.push(residuals_matrix[(i, 0)]);
757        }
758
759        SolverRunDiagnostic {
760            message,
761            iterations,
762            error: residuals_matrix.norm(),
763            values: self.values_by_name(vars),
764            residuals,
765        }
766    }
767
    /// Backtracking line search to find optimal step size
    ///
    /// Tries progressively smaller step sizes until one is found that reduces the error.
    /// This helps prevent oscillation and improves convergence robustness.
    ///
    /// # Arguments
    /// * `jacobian` - Jacobian evaluator for computing function values
    /// * `vars` - Current variable values
    /// * `delta` - Newton step direction
    /// * `current_error` - Current function error |f(x)|
    /// * `f_vals` - Scratch buffer for trial function evaluations (reused to
    ///   avoid per-trial allocation; overwritten on each call)
    ///
    /// # Returns
    /// Optimal step size alpha such that x_new = x + alpha * delta reduces error,
    /// or `min_damping` if no trial step satisfied the decrease condition.
    fn line_search(
        &self,
        jacobian: &Jacobian,
        vars: &HashMap<VarId, f64>,
        delta: &Matrix,
        current_error: f64,
        f_vals: &mut Matrix,
    ) -> Result<f64, SolverError> {
        let mut alpha = 1.0; // Start with full Newton step
        let mut new_vars = vars.clone();

        // Try up to 20 different step sizes (alpha halves each trial)
        for _ in 0..20 {
            // Reset trial state to the current point before applying the step.
            new_vars.clone_from(vars);
            // Compute new variable values: x_new = x + alpha * delta
            for (i, &var_id) in self.variables.iter().enumerate() {
                new_vars.insert(var_id, vars[&var_id] + alpha * delta[(i, 0)]);
            }

            // Check if this step size reduces the error sufficiently
            jacobian.evaluate_functions_checked_into(&new_vars, f_vals)?;
            let new_error = f_vals.norm();
            // Armijo condition: require sufficient decrease
            // (decrease factor 0.5 scaled by the step size alpha)
            if new_error < current_error * (1.0 - 0.5 * alpha) {
                return Ok(alpha);
            }
            // Reduce step size and try again
            alpha *= 0.5;

            // Give up if step size becomes too small
            if alpha < 1e-10 {
                break;
            }
        }

        // Return at least the minimum damping to avoid complete stagnation
        Ok(alpha.max(self.min_damping))
    }
819
820    /// Solve the system using Newton-Raphson with line search
821    ///
822    /// Line search helps improve convergence robustness by automatically
823    /// finding good step sizes, especially useful for difficult systems.
824    pub fn solve_with_line_search(
825        &self,
826        initial_guess: HashMap<String, f64>,
827    ) -> Result<Solution, SolverError> {
828        let vars = self.map_initial_guess(initial_guess)?;
829        self.run_in_mode(|| self.solve_modified_internal(vars, true))
830    }
831}
832
833#[cfg(test)]
834mod tests {
835    use super::*;
836    use crate::compiler::Compiler;
837    use crate::exp::Exp;
838    use crate::Mode;
839
840    struct TestRng {
841        state: u64,
842    }
843
844    impl TestRng {
845        fn new(seed: u64) -> Self {
846            Self { state: seed }
847        }
848
849        fn next_u32(&mut self) -> u32 {
850            self.state = self
851                .state
852                .wrapping_mul(6364136223846793005)
853                .wrapping_add(1);
854            (self.state >> 32) as u32
855        }
856
857        fn next_f64(&mut self) -> f64 {
858            let v = self.next_u32() as f64 / u32::MAX as f64;
859            2.0 * v - 1.0
860        }
861    }
862
863    fn linear_equation(coeffs: &[f64], vars: &[Exp], rhs: f64) -> Exp {
864        let mut sum = Exp::val(0.0);
865        for (coeff, var) in coeffs.iter().zip(vars.iter()) {
866            let term = Exp::mul(Exp::val(*coeff), var.clone());
867            sum = Exp::add(sum, term);
868        }
869        Exp::sub(sum, Exp::val(rhs))
870    }
871
872    fn test_modes() -> Vec<Mode> {
873        let threads = std::thread::available_parallelism()
874            .map(|n| n.get())
875            .unwrap_or(1)
876            .min(4)
877            .max(1);
878        vec![Mode::Serial, Mode::Parallel { thread_count: threads }]
879    }
880
881    fn assert_solution_close(serial: &Solution, parallel: &Solution, tol: f64) {
882        assert_eq!(serial.values.len(), parallel.values.len());
883        for (name, value) in &serial.values {
884            let other = parallel.values.get(name).expect("missing variable");
885            let diff = (value - other).abs();
886            assert!(diff <= tol, "var {name} mismatch: {value} vs {other}");
887        }
888    }
889
    /// Assert that serial and parallel runs failed with the same error variant
    /// and, for the message-carrying variants compared here, the same payload.
    fn assert_error_matches(serial: &SolverError, parallel: &SolverError) {
        assert_eq!(
            std::mem::discriminant(serial),
            std::mem::discriminant(parallel)
        );
        match (serial, parallel) {
            (SolverError::InvalidInput(a), SolverError::InvalidInput(b)) => assert_eq!(a, b),
            (SolverError::MissingVariable(a), SolverError::MissingVariable(b)) => {
                assert_eq!(a, b)
            }
            // Other variants are treated as equal once discriminants match.
            _ => {}
        }
    }
903
    /// Compile `equations` and build a solver configured for `mode`.
    fn solver_for_mode(equations: Vec<Exp>, mode: Mode) -> NewtonRaphsonSolver {
        let compiled = Compiler::compile(&equations).expect("compile failed");
        NewtonRaphsonSolver::new_with_mode(compiled, mode).expect("failed to set solver mode")
    }
908
    /// Compile `equations`, then build a solver that only solves for the named
    /// `variables` (others act as fixed parameters), configured for `mode`.
    fn solver_for_with_vars_mode(
        equations: Vec<Exp>,
        variables: &[&str],
        mode: Mode,
    ) -> NewtonRaphsonSolver {
        let compiled = Compiler::compile(&equations).expect("compile failed");
        NewtonRaphsonSolver::new_with_variables_and_mode(compiled, variables, mode)
            .expect("failed to select solve variables or set mode")
    }
918
    #[test]
    fn test_simple_system() {
        // The serial run goes first and serves as the reference that the
        // parallel run must reproduce bit-for-bit within tolerance.
        let mut serial_solution: Option<Solution> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // Test system: x^2 + y^2 = 1, xy = 0.25
            // Solution should be approximately x ~= 0.5, y ~= 0.866 (or vice versa)
            let x = Exp::var("x");
            let y = Exp::var("y");

            let eq1 = Exp::sub(
                Exp::add(Exp::power(x.clone(), 2.0), Exp::power(y.clone(), 2.0)),
                Exp::val(1.0),
            );
            let eq2 = Exp::sub(Exp::mul(x.clone(), y.clone()), Exp::val(0.25));

            let solver = solver_for_mode(vec![eq1, eq2], mode);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.5);
            initial.insert("y".to_string(), 0.866);

            // Fall back to line search if plain Newton-Raphson fails.
            let solution = match solver.solve(initial.clone()) {
                Ok(sol) => sol,
                Err(_) => match solver.solve_with_line_search(initial) {
                    Ok(sol) => sol,
                    Err(e) => panic!("Failed to solve: {:?}", e),
                },
            };
            if is_serial {
                serial_solution = Some(solution.clone());
            } else {
                let serial = serial_solution.as_ref().expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-8);
            }
            assert!(solution.converged);
            assert!(solution.error < 1e-10);

            // Verify the returned values actually satisfy both equations.
            let x_sol = solution.values.get("x").copied().unwrap();
            let y_sol = solution.values.get("y").copied().unwrap();
            assert!((x_sol * x_sol + y_sol * y_sol - 1.0).abs() < 1e-10);
            assert!((x_sol * y_sol - 0.25).abs() < 1e-10);
        }
    }
963
    #[test]
    fn test_transcendental_system() {
        // Serial result captured first; parallel run must match it.
        let mut serial_solution: Option<Solution> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // Test transcendental system: sin(x) = 2y, cos(y) = x
            let x = Exp::var("x");
            let y = Exp::var("y");

            let eq1 = Exp::sub(Exp::sin(x.clone()), Exp::mul(y.clone(), Exp::val(2.0)));
            let eq2 = Exp::sub(Exp::cos(y.clone()), x.clone());

            let solver = solver_for_mode(vec![eq1, eq2], mode)
                .with_tolerance(1e-8)
                .with_max_iterations(50)
                .with_regularization(1e-8);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.5);
            initial.insert("y".to_string(), 0.25);

            let solution = solver
                .solve_with_line_search(initial)
                .expect("Failed to solve transcendental system");

            if is_serial {
                serial_solution = Some(solution.clone());
            } else {
                let serial = serial_solution.as_ref().expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-8);
            }
            assert!(solution.converged);
            // Residual check against the original transcendental equations.
            let x_sol = solution.values.get("x").copied().unwrap();
            let y_sol = solution.values.get("y").copied().unwrap();
            assert!((x_sol.sin() - 2.0 * y_sol).abs() < 1e-8);
            assert!((y_sol.cos() - x_sol).abs() < 1e-8);
        }
    }
1002
    #[test]
    fn test_solver_errors_on_missing_variable() {
        // Both modes must produce the same error variant and message.
        let mut serial_error: Option<SolverError> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // System: x - a = 0, solving for x with a treated as a fixed parameter.
            // Missing `a` should be an error (not implicitly treated as 0).
            let x = Exp::var("x");
            let a = Exp::var("a");
            let eq = Exp::sub(x.clone(), a.clone());

            let solver = solver_for_with_vars_mode(vec![eq], &["x"], mode);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 1.0);

            let err = solver.solve(initial).expect_err("expected missing variable error");
            if is_serial {
                serial_error = Some(err.clone());
            } else {
                let serial = serial_error.as_ref().expect("serial error missing");
                assert_error_matches(serial, &err);
            }
            // The message must name the parameter that was missing.
            match err {
                SolverError::InvalidInput(msg) => {
                    assert!(msg.contains("a"), "unexpected message: {msg}");
                }
                other => panic!("expected InvalidInput, got {:?}", other),
            }
        }
    }
1034
    #[test]
    fn test_solver_invalid_input_on_missing_initial_guess_variable() {
        // An initial guess that omits a solve variable must be rejected
        // identically in serial and parallel modes.
        let mut serial_error: Option<SolverError> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            let x = Exp::var("x");
            let y = Exp::var("y");
            let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));

            let solver = solver_for_mode(vec![eq], mode);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.0);

            let err = solver
                .solve(initial)
                .expect_err("expected invalid input due to missing y");
            if is_serial {
                serial_error = Some(err.clone());
            } else {
                let serial = serial_error.as_ref().expect("serial error missing");
                assert_error_matches(serial, &err);
            }
            // The message must name the missing guess variable.
            match err {
                SolverError::InvalidInput(msg) => {
                    assert!(msg.contains("y"), "unexpected message: {msg}");
                }
                other => panic!("expected InvalidInput, got {:?}", other),
            }
        }
    }
1066
    #[test]
    fn test_try_with_equation_traces_invalid_length() {
        // Supplying 2 traces for a 1-equation system must fail with a message
        // that states both counts.
        let mut serial_error: Option<SolverError> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            let x = Exp::var("x");
            let eq = Exp::sub(x.clone(), Exp::val(1.0));

            let solver = solver_for_mode(vec![eq], mode);

            let traces = vec![
                Some(EquationTrace {
                    constraint_id: 0,
                    description: "eq0".to_string(),
                }),
                None,
            ];
            let err = match solver.try_with_equation_traces(traces) {
                Ok(_) => panic!("expected invalid input for trace length mismatch"),
                Err(err) => err,
            };
            if is_serial {
                serial_error = Some(err.clone());
            } else {
                let serial = serial_error.as_ref().expect("serial error missing");
                assert_error_matches(serial, &err);
            }
            match err {
                SolverError::InvalidInput(msg) => {
                    assert!(msg.contains("Equation trace length (2)"), "{}", msg);
                    assert!(msg.contains("number of equations (1)"), "{}", msg);
                }
                other => panic!("expected InvalidInput, got {:?}", other),
            }
        }
    }
1103
    #[test]
    fn test_line_search_preserves_fixed_variables() {
        // Serial result is the reference for the parallel run.
        let mut serial_solution: Option<Solution> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // Regression test: line search must preserve "fixed parameter" variables that are
            // not part of `self.variables` but are referenced by equations.
            //
            // If line search drops them, evaluation errors and the solver fails.
            let x = Exp::var("x");
            let a = Exp::var("a");
            let eq = Exp::sub(x.clone(), a.clone());

            let solver = solver_for_with_vars_mode(vec![eq], &["x"], mode);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.0);
            initial.insert("a".to_string(), 2.0);

            let solution = solver
                .solve_with_line_search(initial)
                .expect("solver should converge when fixed variables are provided");
            if is_serial {
                serial_solution = Some(solution.clone());
            } else {
                let serial = serial_solution.as_ref().expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-10);
            }
            assert!(solution.converged);
            // With a fixed at 2.0, x must land on 2.0 as well.
            assert!((solution.values.get("x").copied().unwrap() - 2.0).abs() < 1e-10);
        }
    }
1136
    #[test]
    fn test_least_squares_qr_handles_ill_conditioned_overdetermined_system_without_regularization()
    {
        // Serial result is the reference for the parallel run.
        let mut serial_solution: Option<Solution> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // This system is overdetermined (3 equations, 2 unknowns) with an ill-conditioned
            // Jacobian whose columns are nearly linearly dependent. Normal equations (J^T J) can
            // lose the tiny distinguishing term and become singular in floating point. QR-based
            // least squares should still solve it.
            let eps = 2f64.powi(-27); // exactly representable; eps^2 is below 1 ulp at ~3.0

            let x = Exp::var("x");
            let y = Exp::var("y");

            // A = [[1, 1], [1, 1+eps], [1, 1-eps]]
            // b corresponds to solution (x,y) = (1,1)
            let eq1 = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(2.0));
            let eq2 = Exp::sub(
                Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 + eps))),
                Exp::val(2.0 + eps),
            );
            let eq3 = Exp::sub(
                Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 - eps))),
                Exp::val(2.0 - eps),
            );

            // Regularization explicitly disabled so the QR path is exercised.
            let solver = solver_for_mode(vec![eq1, eq2, eq3], mode)
                .with_regularization(0.0)
                .with_damping(1.0)
                .with_max_iterations(10)
                .with_tolerance(1e-10);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.0);
            initial.insert("y".to_string(), 0.0);

            let solution = solver
                .solve(initial)
                .expect("expected solver to converge with QR least squares");

            if is_serial {
                serial_solution = Some(solution.clone());
            } else {
                let serial = serial_solution.as_ref().expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-8);
            }
            assert!(solution.converged, "{:?}", solution);
            assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-8);
            assert!((solution.values.get("y").copied().unwrap() - 1.0).abs() < 1e-8);
        }
    }
1189
    #[test]
    fn test_underconstrained_system_returns_min_norm_solution() {
        // Serial result is the reference for the parallel run.
        let mut serial_solution: Option<Solution> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // Underdetermined system: x + y = 1 has infinitely many solutions.
            // The solver should return the minimum-norm solution: x = 0.5, y = 0.5.
            let x = Exp::var("x");
            let y = Exp::var("y");

            let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));
            let solver = solver_for_mode(vec![eq], mode)
                .with_tolerance(1e-12)
                .with_max_iterations(10);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.0);
            initial.insert("y".to_string(), 0.0);

            let solution = solver.solve(initial).expect("expected solver to converge");
            if is_serial {
                serial_solution = Some(solution.clone());
            } else {
                let serial = serial_solution.as_ref().expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-10);
            }
            assert!(solution.converged, "{:?}", solution);
            assert!((solution.values.get("x").copied().unwrap() - 0.5).abs() < 1e-10);
            assert!((solution.values.get("y").copied().unwrap() - 0.5).abs() < 1e-10);
        }
    }
1221
    #[test]
    fn test_rank_deficient_overconstrained_recovers_with_regularization() {
        // Serial result is the reference for the parallel run.
        let mut serial_solution: Option<Solution> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // Equations depend only on x; y is unconstrained. The Jacobian is rank deficient.
            let x = Exp::var("x");
            let y = Exp::var("y");
            // y * 0 keeps y in the variable table without constraining it.
            let y_zero = Exp::mul(y.clone(), Exp::val(0.0));

            let eq1 = Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0));
            let eq2 = Exp::sub(
                Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()),
                Exp::val(2.0),
            );
            let eq3 = Exp::sub(
                Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()),
                Exp::val(3.0),
            );

            // Default regularization is kept, which makes the system solvable.
            let solver = solver_for_mode(vec![eq1, eq2, eq3], mode)
                .with_tolerance(1e-12)
                .with_max_iterations(10);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.0);
            initial.insert("y".to_string(), 0.0);

            let solution = solver.solve(initial).expect("expected solver to converge");
            if is_serial {
                serial_solution = Some(solution.clone());
            } else {
                let serial = serial_solution.as_ref().expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-10);
            }
            assert!(solution.converged, "{:?}", solution);
            assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-10);
            assert!((solution.values.get("y").copied().unwrap() - 0.0).abs() < 1e-10);
        }
    }
1262
    #[test]
    fn test_rank_deficient_overconstrained_fails_without_regularization() {
        // Same rank-deficient system as the test above, but with regularization
        // disabled — the solver is expected to report a singular matrix.
        let mut serial_error: Option<SolverError> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            let x = Exp::var("x");
            let y = Exp::var("y");
            let y_zero = Exp::mul(y.clone(), Exp::val(0.0));

            let eq1 = Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0));
            let eq2 = Exp::sub(
                Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()),
                Exp::val(2.0),
            );
            let eq3 = Exp::sub(
                Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()),
                Exp::val(3.0),
            );

            let solver = solver_for_mode(vec![eq1, eq2, eq3], mode)
                .with_regularization(0.0)
                .with_tolerance(1e-12)
                .with_max_iterations(10);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.0);
            initial.insert("y".to_string(), 0.0);

            let err = solver.solve(initial).expect_err("expected singular matrix error");
            if is_serial {
                serial_error = Some(err.clone());
            } else {
                let serial = serial_error.as_ref().expect("serial error missing");
                assert_error_matches(serial, &err);
            }
            assert!(matches!(err, SolverError::SingularMatrix(_)), "{:?}", err);
        }
    }
1301
    #[test]
    fn test_overconstrained_consistent_system_solves() {
        // Serial result is the reference for the parallel run.
        let mut serial_solution: Option<Solution> = None;
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            // Overdetermined but consistent system: x = 1 and 2x = 2.
            let x = Exp::var("x");
            let eq1 = Exp::sub(x.clone(), Exp::val(1.0));
            let eq2 = Exp::sub(Exp::mul(x.clone(), Exp::val(2.0)), Exp::val(2.0));

            let solver = solver_for_mode(vec![eq1, eq2], mode)
                .with_tolerance(1e-12)
                .with_max_iterations(10);

            let mut initial = HashMap::new();
            initial.insert("x".to_string(), 0.0);

            let solution = solver.solve(initial).expect("expected solver to converge");
            if is_serial {
                serial_solution = Some(solution.clone());
            } else {
                let serial = serial_solution.as_ref().expect("serial solution missing");
                assert_solution_close(serial, &solution, 1e-10);
            }
            assert!(solution.converged, "{:?}", solution);
            assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-10);
        }
    }
1330
    #[test]
    fn test_random_square_linear_systems() {
        // One serial solution is stored per random case; the parallel pass
        // reuses the same RNG seed so the cases line up by index.
        let mut serial_solutions: Vec<Solution> = Vec::new();
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            let mut rng = TestRng::new(0x51ab_1e55_cafe_f00d);
            for case_index in 0..5 {
                let n = 3;
                // Random 3x3 matrix with +2.0 on the diagonal to keep it
                // diagonally dominant (and thus well conditioned).
                let mut a = vec![vec![0.0; n]; n];
                for i in 0..n {
                    for j in 0..n {
                        a[i][j] = rng.next_f64();
                    }
                    a[i][i] += 2.0;
                }

                // Construct b = A * x_true so the exact solution is known.
                let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
                let mut b = vec![0.0; n];
                for i in 0..n {
                    for j in 0..n {
                        b[i] += a[i][j] * x_true[j];
                    }
                }

                let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
                let equations: Vec<Exp> = (0..n)
                    .map(|i| linear_equation(&a[i], &vars, b[i]))
                    .collect();

                let solver = solver_for_mode(equations, mode.clone())
                    .with_tolerance(1e-12)
                    .with_max_iterations(20);

                let mut initial = HashMap::new();
                for i in 0..n {
                    initial.insert(format!("x{i}"), 0.0);
                }

                let solution = solver.solve(initial).expect("expected to solve");
                if is_serial {
                    serial_solutions.push(solution.clone());
                } else {
                    assert_solution_close(&serial_solutions[case_index], &solution, 1e-8);
                }
                for i in 0..n {
                    let value = solution.values.get(&format!("x{i}")).copied().unwrap();
                    assert!((value - x_true[i]).abs() < 1e-8);
                }
            }
        }
    }
1382
    #[test]
    fn test_random_overconstrained_linear_systems() {
        // Consistent overdetermined systems (5 equations, 3 unknowns) built so
        // the exact solution is known; serial and parallel must agree per case.
        let mut serial_solutions: Vec<Solution> = Vec::new();
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            let mut rng = TestRng::new(0xa11c_e551_dead_beef);
            let m = 5;
            let n = 3;

            for case_index in 0..5 {
                let mut a = vec![vec![0.0; n]; m];
                for i in 0..m {
                    for j in 0..n {
                        a[i][j] = rng.next_f64();
                    }
                }
                // Boost the leading square block's diagonal for conditioning.
                for i in 0..n {
                    a[i][i] += 2.0;
                }

                // b = A * x_true, so the overdetermined system is consistent.
                let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
                let mut b = vec![0.0; m];
                for i in 0..m {
                    for j in 0..n {
                        b[i] += a[i][j] * x_true[j];
                    }
                }

                let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
                let equations: Vec<Exp> = (0..m)
                    .map(|i| linear_equation(&a[i], &vars, b[i]))
                    .collect();

                let solver = solver_for_mode(equations, mode.clone())
                    .with_tolerance(1e-12)
                    .with_max_iterations(20);

                let mut initial = HashMap::new();
                for i in 0..n {
                    initial.insert(format!("x{i}"), 0.0);
                }

                let solution = solver.solve(initial).expect("expected to solve");
                if is_serial {
                    serial_solutions.push(solution.clone());
                } else {
                    assert_solution_close(&serial_solutions[case_index], &solution, 1e-8);
                }
                for i in 0..n {
                    let value = solution.values.get(&format!("x{i}")).copied().unwrap();
                    assert!((value - x_true[i]).abs() < 1e-8);
                }
            }
        }
    }
1438
    #[test]
    fn test_random_underdetermined_min_norm_structure() {
        // 2 equations in 4 unknowns: x0 and x1 are pinned; x2 and x3 appear
        // only via zero coefficients, so the minimum-norm answer leaves them 0.
        let mut serial_solutions: Vec<Solution> = Vec::new();
        for mode in test_modes() {
            let is_serial = matches!(mode, Mode::Serial);
            let mut rng = TestRng::new(0x0ddc_affe_fade_bead);
            let _m = 2;
            let n = 4;
            let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();

            for case_index in 0..5 {
                let b0 = rng.next_f64();
                let b1 = rng.next_f64();
                // Zero-coefficient terms keep x2/x3 in the variable table
                // without constraining them.
                let x2_zero = Exp::mul(vars[2].clone(), Exp::val(0.0));
                let x3_zero = Exp::mul(vars[3].clone(), Exp::val(0.0));
                let eq1 = Exp::sub(Exp::add(vars[0].clone(), x2_zero.clone()), Exp::val(b0));
                let eq2 = Exp::sub(Exp::add(vars[1].clone(), x3_zero.clone()), Exp::val(b1));

                let solver = solver_for_mode(vec![eq1, eq2], mode.clone())
                    .with_tolerance(1e-12)
                    .with_max_iterations(20);

                let mut initial = HashMap::new();
                for i in 0..n {
                    initial.insert(format!("x{i}"), 0.0);
                }

                let solution = solver.solve(initial).expect("expected to solve");
                if is_serial {
                    serial_solutions.push(solution.clone());
                } else {
                    assert_solution_close(&serial_solutions[case_index], &solution, 1e-10);
                }
                assert!((solution.values.get("x0").copied().unwrap() - b0).abs() < 1e-10);
                assert!((solution.values.get("x1").copied().unwrap() - b1).abs() < 1e-10);
                assert!(solution.values.get("x2").copied().unwrap().abs() < 1e-10);
                assert!(solution.values.get("x3").copied().unwrap().abs() < 1e-10);
            }
        }
    }
1479}