Skip to main content

constraint_solver/
solver.rs

1/*
2MIT License
3
4Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby
5
6Permission is hereby granted, free of charge, to any person obtaining a copy
7of this software and associated documentation files (the "Software"), to deal
8in the Software without restriction, including without limitation the rights
9to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10copies of the Software, and to permit persons to whom the Software is
11furnished to do so, subject to the following conditions:
12
13The above copyright notice and this permission notice shall be included in all
14copies or substantial portions of the Software.
15
16THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22SOFTWARE.
23*/
24
25use crate::compiler::{CompiledExp, CompiledSystem, VarId, VarTable};
26use crate::exp::MissingVarError;
27use crate::jacobian::Jacobian;
28use crate::matrix::{LeastSquaresQrInfo, Matrix, MatrixError};
29use std::collections::HashMap;
30
31struct LeastSquaresWorkspace {
32    augmented_j: Matrix,
33    augmented_b: Matrix,
34    num_equations: usize,
35    num_variables: usize,
36}
37
38impl LeastSquaresWorkspace {
39    fn new(num_equations: usize, num_variables: usize, sqrt_reg: f64) -> Self {
40        let mut augmented_j = Matrix::new(num_equations + num_variables, num_variables);
41        for i in 0..num_variables {
42            augmented_j[(num_equations + i, i)] = sqrt_reg;
43        }
44
45        LeastSquaresWorkspace {
46            augmented_j,
47            augmented_b: Matrix::new(num_equations + num_variables, 1),
48            num_equations,
49            num_variables,
50        }
51    }
52}
53
54enum LeastSquaresSolveError {
55    InvalidInput(MatrixError),
56    Singular { qr_err: String, normal_eq_err: String },
57}
58
59/// Newton-Raphson solver for systems of nonlinear equations
60///
61/// This solver can handle:
62/// - Square systems (equations == variables)
63/// - Under-constrained systems (equations < variables) - uses least squares
64/// - Over-constrained systems (equations > variables) - uses least squares
65///
66/// The solver uses adaptive damping and regularization for numerical stability.
67pub struct NewtonRaphsonSolver {
68    /// The system of equations to solve (each equation should equal zero)
69    equations: Vec<CompiledExp>,
70    /// Variable IDs that correspond to unknowns in the system
71    variables: Vec<VarId>,
72    /// Registry of all variables referenced by the compiled system
73    var_table: VarTable,
74    /// Maximum number of iterations before giving up
75    max_iterations: usize,
76    /// Convergence tolerance (solution found when |f(x)| < tolerance)
77    tolerance: f64,
78    /// Initial damping factor (step size multiplier, 0 < damping <= 1)
79    /// Lower values make convergence more stable but slower
80    damping_factor: f64,
81    /// Minimum allowed damping before declaring convergence failure
82    min_damping: f64,
83    /// Base regularization parameter used when the system is ill-conditioned
84    regularization: f64,
85    /// Optional metadata describing the source of each equation in the system
86    equation_traces: Vec<Option<EquationTrace>>,
87}
88
89/// Solution result containing the solved variable values and convergence info
90#[derive(Debug, Clone)]
91pub struct Solution {
92    /// Final values of all variables (variable name -> value)
93    pub values: HashMap<String, f64>,
94    /// Number of iterations taken to converge
95    pub iterations: usize,
96    /// Final residual error |f(x)|
97    pub error: f64,
98    /// Whether the solver successfully converged
99    pub converged: bool,
100    /// History of error values throughout the iteration process
101    pub convergence_history: Vec<f64>,
102}
103
104/// Provides traceability information back to the constraint that generated an equation
105#[derive(Debug, Clone)]
106pub struct EquationTrace {
107    pub constraint_id: usize,
108    pub description: String,
109}
110
111/// Diagnostic data captured when the solver fails
112#[derive(Debug, Clone)]
113pub struct SolverRunDiagnostic {
114    pub message: String,
115    pub iterations: usize,
116    pub error: f64,
117    /// Variable values at the end of the run (variable name -> value)
118    pub values: HashMap<String, f64>,
119    pub residuals: Vec<f64>,
120}
121
122/// Possible solver error conditions
123#[derive(Debug, Clone)]
124pub enum SolverError {
125    /// Jacobian matrix became singular and couldn't be inverted
126    SingularMatrix(SolverRunDiagnostic),
127    /// Failed to converge within the maximum number of iterations
128    NoConvergence(SolverRunDiagnostic),
129    /// A required variable was missing during evaluation
130    MissingVariable(MissingVarError),
131    /// Invalid input parameters or system setup
132    InvalidInput(String),
133}
134
135impl From<MissingVarError> for SolverError {
136    fn from(value: MissingVarError) -> Self {
137        SolverError::MissingVariable(value)
138    }
139}
140
141impl From<MatrixError> for SolverError {
142    fn from(value: MatrixError) -> Self {
143        SolverError::InvalidInput(value.to_string())
144    }
145}
146
147impl NewtonRaphsonSolver {
148    /// Create a new Newton-Raphson solver with adaptive parameters
149    ///
150    /// Parameters are automatically adjusted based on system type:
151    /// - Over-constrained systems get more iterations, relaxed tolerance, conservative damping
152    /// - Normal systems use standard parameters for fast convergence
153    ///
154    /// # Arguments
155    /// * `compiled` - Compiled system of equations (each should evaluate to 0 at solution)
156    pub fn new(compiled: CompiledSystem) -> Self {
157        let variables = compiled.var_table.all_var_ids();
158        Self::build(compiled, variables)
159    }
160
161    /// Create a solver while specifying which variables to solve for.
162    ///
163    /// Variables not listed here are treated as fixed parameters and must be
164    /// provided in the initial guess.
165    pub fn new_with_variables(
166        compiled: CompiledSystem,
167        variables: &[&str],
168    ) -> Result<Self, SolverError> {
169        let mut ids = Vec::with_capacity(variables.len());
170        let mut seen = std::collections::HashSet::new();
171        for name in variables {
172            let id = compiled
173                .var_table
174                .get_id(name)
175                .ok_or_else(|| SolverError::InvalidInput(format!("Unknown variable '{name}'")))?;
176            if seen.insert(id) {
177                ids.push(id);
178            }
179        }
180        Ok(Self::build(compiled, ids))
181    }
182
183    fn build(compiled: CompiledSystem, variables: Vec<VarId>) -> Self {
184        // Allow both under-constrained and over-constrained systems
185        // We'll use least squares to solve them
186        let is_over_constrained = compiled.equations.len() > variables.len();
187
188        NewtonRaphsonSolver {
189            equations: compiled.equations,
190            variables,
191            var_table: compiled.var_table,
192            // Over-constrained systems are harder to converge, so give them more iterations
193            max_iterations: if is_over_constrained { 200 } else { 100 },
194            // Over-constrained systems use relaxed tolerance since exact solutions may not exist
195            tolerance: if is_over_constrained { 1e-8 } else { 1e-10 },
196            // Over-constrained systems use conservative damping for stability
197            damping_factor: if is_over_constrained { 0.7 } else { 1.0 },
198            min_damping: 0.001,
199            // Over-constrained systems need more regularization for numerical stability
200            regularization: if is_over_constrained { 1e-4 } else { 1e-8 },
201            equation_traces: Vec::new(),
202        }
203    }
204
205    /// Attach metadata describing the origin of each equation in the system
206    pub fn try_with_equation_traces(
207        mut self,
208        traces: Vec<Option<EquationTrace>>,
209    ) -> Result<Self, SolverError> {
210        if !self.equations.is_empty() && traces.len() != self.equations.len() {
211            return Err(SolverError::InvalidInput(format!(
212                "Equation trace length ({}) does not match number of equations ({})",
213                traces.len(),
214                self.equations.len()
215            )));
216        }
217        self.equation_traces = traces;
218        Ok(self)
219    }
220
221    /// Attach metadata describing the origin of each equation in the system
222    pub fn with_equation_traces(self, traces: Vec<Option<EquationTrace>>) -> Self {
223        self.try_with_equation_traces(traces)
224            .unwrap_or_else(|err| match err {
225                SolverError::InvalidInput(msg) => panic!("{}", msg),
226                other => panic!("{other:?}"),
227            })
228    }
229
230    /// Fetch trace metadata for a specific equation, if available
231    pub fn trace_for_equation(&self, equation_index: usize) -> Option<&EquationTrace> {
232        self.equation_traces
233            .get(equation_index)
234            .and_then(|entry| entry.as_ref())
235    }
236
237    /// Override the maximum number of iterations
238    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
239        self.max_iterations = max_iterations;
240        self
241    }
242
243    /// Override the convergence tolerance
244    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
245        self.tolerance = tolerance;
246        self
247    }
248
249    /// Override the damping factor (clamped to [0.1, 1.0] for stability)
250    pub fn with_damping(mut self, damping_factor: f64) -> Self {
251        self.damping_factor = damping_factor.clamp(0.1, 1.0);
252        self
253    }
254
255    /// Override the base regularization parameter used when the system is ill-conditioned
256    pub fn with_regularization(mut self, regularization: f64) -> Self {
257        self.regularization = regularization.max(0.0);
258        self
259    }
260
261    /// Solve the system using standard Newton-Raphson method
262    pub fn solve(&self, initial_guess: HashMap<String, f64>) -> Result<Solution, SolverError> {
263        let vars = self.map_initial_guess(initial_guess)?;
264        self.solve_modified_internal(vars, false)
265    }
266
267    /// Core Newton-Raphson solver implementation with optional line search
268    ///
269    /// # Arguments
270    /// * `initial_guess` - Starting values for all variables (by name)
271    /// * `use_line_search` - Whether to use backtracking line search for step size
272    ///
273    /// # Algorithm
274    /// 1. Evaluate f(x) and check for convergence
275    /// 2. Compute Jacobian J = df/dx
276    /// 3. Solve linear system: J * delta = -f(x)
277    ///    - Square systems: Direct LU decomposition
278    ///    - Under/over-constrained: QR least squares (with adaptive regularization if needed)
279    /// 4. Update: x_new = x_old + damping * delta
280    /// 5. Adapt damping based on error change
281    fn solve_modified_internal(
282        &self,
283        initial_guess: HashMap<VarId, f64>,
284        use_line_search: bool,
285    ) -> Result<Solution, SolverError> {
286        self.validate_initial_guess(&initial_guess)?;
287
288        let mut vars = initial_guess.clone();
289        let jacobian = Jacobian::new(self.equations.clone(), self.variables.clone());
290        // Cache Jacobian storage to avoid per-iteration allocations.
291        let mut jacobian_workspace = jacobian.workspace();
292        let mut convergence_history = Vec::with_capacity(self.max_iterations);
293        // Start with configured damping factor, will be adapted during iterations
294        let mut damping = self.damping_factor;
295        let mut last_error = f64::INFINITY;
296
297        let num_equations = self.equations.len();
298        let num_variables = self.variables.len();
299
300        let mut f_vals = Matrix::new(num_equations, 1);
301        let mut f_neg = Matrix::new(num_equations, 1);
302        jacobian_workspace.replace_jacobian(jacobian.evaluate_checked(&vars)?);
303
304        let mut line_search_f_vals = Matrix::new(num_equations, 1);
305
306        let mut least_squares_workspace =
307            if self.regularization > 0.0 && num_equations != num_variables {
308                Some(LeastSquaresWorkspace::new(
309                    num_equations,
310                    num_variables,
311                    self.regularization.sqrt(),
312                ))
313            } else {
314                None
315            };
316
317        for iter in 0..self.max_iterations {
318            // Step 1: Evaluate function values f(x)
319            jacobian.evaluate_functions_checked_into(&vars, &mut f_vals)?;
320            let error = f_vals.norm();
321            convergence_history.push(error);
322
323            // Check for convergence: |f(x)| < tolerance
324            if error < self.tolerance {
325                return Ok(Solution {
326                    values: self.values_by_name(&vars),
327                    iterations: iter + 1,
328                    error,
329                    converged: true,
330                    convergence_history,
331                });
332            }
333
334            // Check for stagnation (error not decreasing significantly)
335            if iter > 0 && (error - last_error).abs() < self.tolerance * 0.1 {
336                damping *= 0.5; // Reduce step size
337                if damping < self.min_damping {
338                    let diag = self.build_diagnostic(
339                        format!(
340                            "Solver stagnated at iteration {} with error {:.2e}",
341                            iter + 1,
342                            error
343                        ),
344                        iter + 1,
345                        &vars,
346                        &f_vals,
347                    );
348                    return Err(SolverError::NoConvergence(diag));
349                }
350            }
351
352            // Adaptive damping based on error change
353            if error > last_error * 1.5 {
354                // Error increased significantly - reduce damping for stability
355                damping *= 0.5;
356            } else if error < last_error * 0.5 {
357                // Error decreased significantly - can increase damping for faster convergence
358                damping = (damping * 1.2).min(1.0);
359            }
360
361            last_error = error;
362
363            for i in 0..num_equations {
364                f_neg[(i, 0)] = -f_vals[(i, 0)];
365            }
366
367            // Step 2: Solve linear system J * delta = -f(x)
368            let delta = {
369                let j_matrix = jacobian_workspace.jacobian();
370                // Handle different system types with appropriate solving methods
371                if j_matrix.rows() == j_matrix.cols() {
372                    // Square system (equations == variables)
373                    // Use standard LU decomposition: J * delta = -f
374                    match j_matrix.solve_lu(&f_neg) {
375                        Ok(d) => d,
376                        Err(_) => {
377                            // Matrix is singular - try with increased regularization
378                            let diag_size = j_matrix.rows().min(j_matrix.cols());
379                            for i in 0..diag_size {
380                                j_matrix[(i, i)] += self.regularization * 100.0;
381                            }
382                            match j_matrix.solve_lu(&f_neg) {
383                                Ok(d) => d,
384                                Err(e) => {
385                                    let diag = self.build_diagnostic(
386                                        format!(
387                                            "Matrix is singular even with regularization: {e}"
388                                        ),
389                                        iter + 1,
390                                        &vars,
391                                        &f_vals,
392                                    );
393                                    return Err(SolverError::SingularMatrix(diag));
394                                }
395                            }
396                        }
397                    }
398                } else {
399                    match self.solve_least_squares_delta(
400                        j_matrix,
401                        &f_neg,
402                        least_squares_workspace.as_mut(),
403                    ) {
404                        Ok(delta) => delta,
405                        Err(LeastSquaresSolveError::InvalidInput(err)) => {
406                            return Err(SolverError::InvalidInput(err.to_string()));
407                        }
408                        Err(LeastSquaresSolveError::Singular {
409                            qr_err,
410                            normal_eq_err,
411                        }) => {
412                            let diag = self.build_diagnostic(
413                                format!(
414                                    "Least squares system is singular (qr_err: {qr_err}; normal_eq_err: {normal_eq_err})"
415                                ),
416                                iter + 1,
417                                &vars,
418                                &f_vals,
419                            );
420                            return Err(SolverError::SingularMatrix(diag));
421                        }
422                    }
423                }
424            };
425
426            // Step 4: Determine step size (damping or line search)
427            let step_size = if use_line_search {
428                // Use backtracking line search to find optimal step size
429                self.line_search(&jacobian, &vars, &delta, error, &mut line_search_f_vals)?
430            } else {
431                // Use current damping factor as step size
432                damping
433            };
434
435            // Step 5: Update variables: x_new = x_old + step_size * delta
436            for (i, &var_id) in self.variables.iter().enumerate() {
437                vars.insert(var_id, vars[&var_id] + step_size * delta[(i, 0)]);
438            }
439
440            // Additional convergence check: small step size indicates convergence
441            if delta.norm() * step_size < self.tolerance {
442                return Ok(Solution {
443                    values: self.values_by_name(&vars),
444                    iterations: iter + 1,
445                    error,
446                    converged: true,
447                    convergence_history,
448                });
449            }
450
451            // Refresh Jacobian for next iteration
452            jacobian.evaluate_checked_in_workspace(&vars, &mut jacobian_workspace)?;
453        }
454
455        // Failed to converge within max iterations
456        let residuals = jacobian.evaluate_functions_checked(&vars)?;
457        let diag = self.build_diagnostic(
458            format!(
459                "Failed to converge after {} iterations. Final error: {:.2e}",
460                self.max_iterations,
461                residuals.norm()
462            ),
463            self.max_iterations,
464            &vars,
465            &residuals,
466        );
467        Err(SolverError::NoConvergence(diag))
468    }
469
470    fn solve_least_squares_delta(
471        &self,
472        j_matrix: &Matrix,
473        f_neg: &Matrix,
474        workspace: Option<&mut LeastSquaresWorkspace>,
475    ) -> Result<Matrix, LeastSquaresSolveError> {
476        let mut workspace = workspace;
477        // Prefer QR-based least squares to avoid forming normal equations (J^T J), which can
478        // severely amplify conditioning issues. If the QR solve is rank deficient or
479        // ill-conditioned, fall back to a regularized augmented system
480        // [J; sqrt(lambda) I] * delta ~= [-f; 0].
481        const COND_LIMIT: f64 = 1e12;
482
483        let full_rank = j_matrix.rows().min(j_matrix.cols());
484        let is_ill_conditioned = |info: &LeastSquaresQrInfo| -> bool {
485            info.rank < full_rank
486                || !info.cond_est.is_finite()
487                || info.cond_est > COND_LIMIT
488        };
489
490        let mut qr_err: Option<String> = None;
491        let mut unreg_delta: Option<Matrix> = None;
492        let mut unreg_info: Option<LeastSquaresQrInfo> = None;
493
494        match j_matrix.solve_least_squares_qr_with_info(f_neg) {
495            Ok((delta, info)) => {
496                if !is_ill_conditioned(&info) {
497                    return Ok(delta);
498                }
499                unreg_delta = Some(delta);
500                unreg_info = Some(info);
501            }
502            Err(err) => {
503                qr_err = Some(err);
504            }
505        }
506
507        if self.regularization <= 0.0 {
508            if let Some(delta) = unreg_delta {
509                return Ok(delta);
510            }
511            return Err(LeastSquaresSolveError::Singular {
512                qr_err: qr_err.unwrap_or_else(|| "QR least squares failed".to_string()),
513                normal_eq_err: "Regularization disabled".to_string(),
514            });
515        }
516
517        let reg_factors = [1.0, 10.0, 100.0];
518        let mut last_err: Option<String> = None;
519        let mut last_solution: Option<Matrix> = None;
520
521        for factor in reg_factors {
522            let lambda = self.regularization * factor;
523            if lambda <= 0.0 {
524                continue;
525            }
526            let sqrt_reg = lambda.sqrt();
527
528            let reg_result = match workspace.as_deref_mut() {
529                Some(workspace) => {
530                    if workspace.num_equations != j_matrix.rows()
531                        || workspace.num_variables != j_matrix.cols()
532                    {
533                        return Err(LeastSquaresSolveError::InvalidInput(
534                            MatrixError::DimensionMismatch {
535                                operation: "least_squares",
536                                left: (workspace.augmented_j.rows(), workspace.augmented_j.cols()),
537                                right: (j_matrix.rows(), j_matrix.cols()),
538                            },
539                        ));
540                    }
541
542                    for i in 0..workspace.num_equations {
543                        for j in 0..workspace.num_variables {
544                            workspace.augmented_j[(i, j)] = j_matrix[(i, j)];
545                        }
546                        workspace.augmented_b[(i, 0)] = f_neg[(i, 0)];
547                    }
548                    for i in 0..workspace.num_variables {
549                        for j in 0..workspace.num_variables {
550                            workspace.augmented_j[(workspace.num_equations + i, j)] = 0.0;
551                        }
552                        workspace.augmented_j[(workspace.num_equations + i, i)] = sqrt_reg;
553                        workspace.augmented_b[(workspace.num_equations + i, 0)] = 0.0;
554                    }
555
556                    workspace
557                        .augmented_j
558                        .solve_least_squares_qr_with_info(&workspace.augmented_b)
559                }
560                None => {
561                    let mut augmented_j =
562                        Matrix::new(j_matrix.rows() + j_matrix.cols(), j_matrix.cols());
563                    let mut augmented_b = Matrix::new(j_matrix.rows() + j_matrix.cols(), 1);
564                    for i in 0..j_matrix.rows() {
565                        for j in 0..j_matrix.cols() {
566                            augmented_j[(i, j)] = j_matrix[(i, j)];
567                        }
568                        augmented_b[(i, 0)] = f_neg[(i, 0)];
569                    }
570                    for i in 0..j_matrix.cols() {
571                        augmented_j[(j_matrix.rows() + i, i)] = sqrt_reg;
572                        augmented_b[(j_matrix.rows() + i, 0)] = 0.0;
573                    }
574                    augmented_j.solve_least_squares_qr_with_info(&augmented_b)
575                }
576            };
577
578            match reg_result {
579                Ok((delta, info)) => {
580                    last_solution = Some(delta);
581                    if !is_ill_conditioned(&info) {
582                        return Ok(last_solution.unwrap());
583                    }
584                }
585                Err(err) => {
586                    last_err = Some(err);
587                }
588            }
589        }
590
591        if let Some(delta) = last_solution {
592            return Ok(delta);
593        }
594
595        Err(LeastSquaresSolveError::Singular {
596            qr_err: qr_err
597                .or_else(|| {
598                    unreg_info
599                        .map(|info| format!("QR ill-conditioned (rank {}, cond {:.2e})", info.rank, info.cond_est))
600                })
601                .unwrap_or_else(|| "QR least squares failed".to_string()),
602            normal_eq_err: last_err.unwrap_or_else(|| "Regularized QR failed".to_string()),
603        })
604    }
605
606    fn map_initial_guess(
607        &self,
608        initial_guess: HashMap<String, f64>,
609    ) -> Result<HashMap<VarId, f64>, SolverError> {
610        let mut unknown = Vec::new();
611        let mut vars = HashMap::new();
612
613        for (name, value) in initial_guess {
614            if let Some(id) = self.var_table.get_id(&name) {
615                vars.insert(id, value);
616            } else {
617                unknown.push(name);
618            }
619        }
620
621        if !unknown.is_empty() {
622            unknown.sort();
623            return Err(SolverError::InvalidInput(format!(
624                "Initial guess includes unknown variables: {unknown:?}"
625            )));
626        }
627
628        self.validate_initial_guess(&vars)?;
629        Ok(vars)
630    }
631
632    fn validate_initial_guess(&self, vars: &HashMap<VarId, f64>) -> Result<(), SolverError> {
633        let mut missing = Vec::new();
634        for (idx, name) in self.var_table.names().iter().enumerate() {
635            if !vars.contains_key(&VarId::new(idx)) {
636                missing.push(name.clone());
637            }
638        }
639        if missing.is_empty() {
640            return Ok(());
641        }
642
643        missing.sort();
644        Err(SolverError::InvalidInput(format!(
645            "Initial guess missing values for variables: {missing:?}"
646        )))
647    }
648
649    fn values_by_name(&self, vars: &HashMap<VarId, f64>) -> HashMap<String, f64> {
650        let mut values = HashMap::with_capacity(self.var_table.len());
651        for (idx, name) in self.var_table.names().iter().enumerate() {
652            if let Some(value) = vars.get(&VarId::new(idx)) {
653                values.insert(name.clone(), *value);
654            }
655        }
656        values
657    }
658
659    fn build_diagnostic(
660        &self,
661        message: String,
662        iterations: usize,
663        vars: &HashMap<VarId, f64>,
664        residuals_matrix: &Matrix,
665    ) -> SolverRunDiagnostic {
666        let mut residuals = Vec::with_capacity(residuals_matrix.rows());
667        for i in 0..residuals_matrix.rows() {
668            residuals.push(residuals_matrix[(i, 0)]);
669        }
670
671        SolverRunDiagnostic {
672            message,
673            iterations,
674            error: residuals_matrix.norm(),
675            values: self.values_by_name(vars),
676            residuals,
677        }
678    }
679
680    /// Backtracking line search to find optimal step size
681    ///
682    /// Tries progressively smaller step sizes until one is found that reduces the error.
683    /// This helps prevent oscillation and improves convergence robustness.
684    ///
685    /// # Arguments
686    /// * `jacobian` - Jacobian evaluator for computing function values
687    /// * `vars` - Current variable values
688    /// * `delta` - Newton step direction
689    /// * `current_error` - Current function error |f(x)|
690    ///
691    /// # Returns
692    /// Optimal step size alpha such that x_new = x + alpha * delta reduces error
693    fn line_search(
694        &self,
695        jacobian: &Jacobian,
696        vars: &HashMap<VarId, f64>,
697        delta: &Matrix,
698        current_error: f64,
699        f_vals: &mut Matrix,
700    ) -> Result<f64, SolverError> {
701        let mut alpha = 1.0; // Start with full Newton step
702        let mut new_vars = vars.clone();
703
704        // Try up to 20 different step sizes
705        for _ in 0..20 {
706            new_vars.clone_from(vars);
707            // Compute new variable values: x_new = x + alpha * delta
708            for (i, &var_id) in self.variables.iter().enumerate() {
709                new_vars.insert(var_id, vars[&var_id] + alpha * delta[(i, 0)]);
710            }
711
712            // Check if this step size reduces the error sufficiently
713            jacobian.evaluate_functions_checked_into(&new_vars, f_vals)?;
714            let new_error = f_vals.norm();
715            // Armijo condition: require sufficient decrease
716            if new_error < current_error * (1.0 - 0.5 * alpha) {
717                return Ok(alpha);
718            }
719            // Reduce step size and try again
720            alpha *= 0.5;
721
722            // Give up if step size becomes too small
723            if alpha < 1e-10 {
724                break;
725            }
726        }
727
728        // Return at least the minimum damping to avoid complete stagnation
729        Ok(alpha.max(self.min_damping))
730    }
731
732    /// Solve the system using Newton-Raphson with line search
733    ///
734    /// Line search helps improve convergence robustness by automatically
735    /// finding good step sizes, especially useful for difficult systems.
736    pub fn solve_with_line_search(
737        &self,
738        initial_guess: HashMap<String, f64>,
739    ) -> Result<Solution, SolverError> {
740        let vars = self.map_initial_guess(initial_guess)?;
741        self.solve_modified_internal(vars, true)
742    }
743}
744
745#[cfg(test)]
746mod tests {
747    use super::*;
748    use crate::compiler::Compiler;
749    use crate::exp::Exp;
750
751    struct TestRng {
752        state: u64,
753    }
754
755    impl TestRng {
756        fn new(seed: u64) -> Self {
757            Self { state: seed }
758        }
759
760        fn next_u32(&mut self) -> u32 {
761            self.state = self
762                .state
763                .wrapping_mul(6364136223846793005)
764                .wrapping_add(1);
765            (self.state >> 32) as u32
766        }
767
768        fn next_f64(&mut self) -> f64 {
769            let v = self.next_u32() as f64 / u32::MAX as f64;
770            2.0 * v - 1.0
771        }
772    }
773
774    fn linear_equation(coeffs: &[f64], vars: &[Exp], rhs: f64) -> Exp {
775        let mut sum = Exp::val(0.0);
776        for (coeff, var) in coeffs.iter().zip(vars.iter()) {
777            let term = Exp::mul(Exp::val(*coeff), var.clone());
778            sum = Exp::add(sum, term);
779        }
780        Exp::sub(sum, Exp::val(rhs))
781    }
782
783    fn solver_for(equations: Vec<Exp>) -> NewtonRaphsonSolver {
784        let compiled = Compiler::compile(&equations).expect("compile failed");
785        NewtonRaphsonSolver::new(compiled)
786    }
787
788    fn solver_for_with_vars(equations: Vec<Exp>, variables: &[&str]) -> NewtonRaphsonSolver {
789        let compiled = Compiler::compile(&equations).expect("compile failed");
790        NewtonRaphsonSolver::new_with_variables(compiled, variables)
791            .expect("failed to select solve variables")
792    }
793
794    #[test]
795    fn test_simple_system() {
796        // Test system: x^2 + y^2 = 1, xy = 0.25
797        // Solution should be approximately x ~= 0.5, y ~= 0.866 (or vice versa)
798        let x = Exp::var("x");
799        let y = Exp::var("y");
800
801        let eq1 = Exp::sub(
802            Exp::add(Exp::power(x.clone(), 2.0), Exp::power(y.clone(), 2.0)),
803            Exp::val(1.0),
804        );
805        let eq2 = Exp::sub(Exp::mul(x.clone(), y.clone()), Exp::val(0.25));
806
807        let solver = solver_for(vec![eq1, eq2]);
808
809        let mut initial = HashMap::new();
810        initial.insert("x".to_string(), 0.5);
811        initial.insert("y".to_string(), 0.866);
812
813        let solution = match solver.solve(initial.clone()) {
814            Ok(sol) => sol,
815            Err(_) => match solver.solve_with_line_search(initial) {
816                Ok(sol) => sol,
817                Err(e) => panic!("Failed to solve: {:?}", e),
818            },
819        };
820        assert!(solution.converged);
821        assert!(solution.error < 1e-10);
822
823        let x_sol = solution.values.get("x").copied().unwrap();
824        let y_sol = solution.values.get("y").copied().unwrap();
825        assert!((x_sol * x_sol + y_sol * y_sol - 1.0).abs() < 1e-10);
826        assert!((x_sol * y_sol - 0.25).abs() < 1e-10);
827    }
828
829    #[test]
830    fn test_transcendental_system() {
831        // Test transcendental system: sin(x) = 2y, cos(y) = x
832        let x = Exp::var("x");
833        let y = Exp::var("y");
834
835        let eq1 = Exp::sub(Exp::sin(x.clone()), Exp::mul(y.clone(), Exp::val(2.0)));
836        let eq2 = Exp::sub(Exp::cos(y.clone()), x.clone());
837
838        let solver = solver_for(vec![eq1, eq2])
839            .with_tolerance(1e-8)
840            .with_max_iterations(50)
841            .with_regularization(1e-8);
842
843        let mut initial = HashMap::new();
844        initial.insert("x".to_string(), 0.5);
845        initial.insert("y".to_string(), 0.25);
846
847        let solution = solver
848            .solve_with_line_search(initial)
849            .expect("Failed to solve transcendental system");
850
851        assert!(solution.converged);
852        let x_sol = solution.values.get("x").copied().unwrap();
853        let y_sol = solution.values.get("y").copied().unwrap();
854        assert!((x_sol.sin() - 2.0 * y_sol).abs() < 1e-8);
855        assert!((y_sol.cos() - x_sol).abs() < 1e-8);
856    }
857
858    #[test]
859    fn test_solver_errors_on_missing_variable() {
860        // System: x - a = 0, solving for x with a treated as a fixed parameter.
861        // Missing `a` should be an error (not implicitly treated as 0).
862        let x = Exp::var("x");
863        let a = Exp::var("a");
864        let eq = Exp::sub(x.clone(), a.clone());
865
866        let solver = solver_for_with_vars(vec![eq], &["x"]);
867
868        let mut initial = HashMap::new();
869        initial.insert("x".to_string(), 1.0);
870
871        let err = solver.solve(initial).expect_err("expected missing variable error");
872        match err {
873            SolverError::InvalidInput(msg) => {
874                assert!(msg.contains("a"), "unexpected message: {msg}");
875            }
876            other => panic!("expected InvalidInput, got {:?}", other),
877        }
878    }
879
880    #[test]
881    fn test_solver_invalid_input_on_missing_initial_guess_variable() {
882        let x = Exp::var("x");
883        let y = Exp::var("y");
884        let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));
885
886        let solver = solver_for(vec![eq]);
887
888        let mut initial = HashMap::new();
889        initial.insert("x".to_string(), 0.0);
890
891        let err = solver
892            .solve(initial)
893            .expect_err("expected invalid input due to missing y");
894        match err {
895            SolverError::InvalidInput(msg) => {
896                assert!(msg.contains("y"), "unexpected message: {msg}");
897            }
898            other => panic!("expected InvalidInput, got {:?}", other),
899        }
900    }
901
902    #[test]
903    fn test_try_with_equation_traces_invalid_length() {
904        let x = Exp::var("x");
905        let eq = Exp::sub(x.clone(), Exp::val(1.0));
906
907        let solver = solver_for(vec![eq]);
908
909        let traces = vec![
910            Some(EquationTrace {
911                constraint_id: 0,
912                description: "eq0".to_string(),
913            }),
914            None,
915        ];
916        let err = match solver.try_with_equation_traces(traces) {
917            Ok(_) => panic!("expected invalid input for trace length mismatch"),
918            Err(err) => err,
919        };
920        match err {
921            SolverError::InvalidInput(msg) => {
922                assert!(msg.contains("Equation trace length (2)"), "{}", msg);
923                assert!(msg.contains("number of equations (1)"), "{}", msg);
924            }
925            other => panic!("expected InvalidInput, got {:?}", other),
926        }
927    }
928
929    #[test]
930    fn test_line_search_preserves_fixed_variables() {
931        // Regression test: line search must preserve "fixed parameter" variables that are
932        // not part of `self.variables` but are referenced by equations.
933        //
934        // If line search drops them, evaluation errors and the solver fails.
935        let x = Exp::var("x");
936        let a = Exp::var("a");
937        let eq = Exp::sub(x.clone(), a.clone());
938
939        let solver = solver_for_with_vars(vec![eq], &["x"]);
940
941        let mut initial = HashMap::new();
942        initial.insert("x".to_string(), 0.0);
943        initial.insert("a".to_string(), 2.0);
944
945        let solution = solver
946            .solve_with_line_search(initial)
947            .expect("solver should converge when fixed variables are provided");
948        assert!(solution.converged);
949        assert!((solution.values.get("x").copied().unwrap() - 2.0).abs() < 1e-10);
950    }
951
952    #[test]
953    fn test_least_squares_qr_handles_ill_conditioned_overdetermined_system_without_regularization()
954    {
955        // This system is overdetermined (3 equations, 2 unknowns) with an ill-conditioned
956        // Jacobian whose columns are nearly linearly dependent. Normal equations (J^T J) can
957        // lose the tiny distinguishing term and become singular in floating point. QR-based
958        // least squares should still solve it.
959        let eps = 2f64.powi(-27); // exactly representable; eps^2 is below 1 ulp at ~3.0
960
961        let x = Exp::var("x");
962        let y = Exp::var("y");
963
964        // A = [[1, 1], [1, 1+eps], [1, 1-eps]]
965        // b corresponds to solution (x,y) = (1,1)
966        let eq1 = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(2.0));
967        let eq2 = Exp::sub(
968            Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 + eps))),
969            Exp::val(2.0 + eps),
970        );
971        let eq3 = Exp::sub(
972            Exp::add(x.clone(), Exp::mul(y.clone(), Exp::val(1.0 - eps))),
973            Exp::val(2.0 - eps),
974        );
975
976        let solver = solver_for(vec![eq1, eq2, eq3])
977            .with_regularization(0.0)
978            .with_damping(1.0)
979            .with_max_iterations(10)
980            .with_tolerance(1e-10);
981
982        let mut initial = HashMap::new();
983        initial.insert("x".to_string(), 0.0);
984        initial.insert("y".to_string(), 0.0);
985
986        let solution = solver
987            .solve(initial)
988            .expect("expected solver to converge with QR least squares");
989
990        assert!(solution.converged, "{:?}", solution);
991        assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-8);
992        assert!((solution.values.get("y").copied().unwrap() - 1.0).abs() < 1e-8);
993    }
994
995    #[test]
996    fn test_underconstrained_system_returns_min_norm_solution() {
997        // Underdetermined system: x + y = 1 has infinitely many solutions.
998        // The solver should return the minimum-norm solution: x = 0.5, y = 0.5.
999        let x = Exp::var("x");
1000        let y = Exp::var("y");
1001
1002        let eq = Exp::sub(Exp::add(x.clone(), y.clone()), Exp::val(1.0));
1003        let solver = solver_for(vec![eq])
1004            .with_tolerance(1e-12)
1005            .with_max_iterations(10);
1006
1007        let mut initial = HashMap::new();
1008        initial.insert("x".to_string(), 0.0);
1009        initial.insert("y".to_string(), 0.0);
1010
1011        let solution = solver.solve(initial).expect("expected solver to converge");
1012        assert!(solution.converged, "{:?}", solution);
1013        assert!((solution.values.get("x").copied().unwrap() - 0.5).abs() < 1e-10);
1014        assert!((solution.values.get("y").copied().unwrap() - 0.5).abs() < 1e-10);
1015    }
1016
1017    #[test]
1018    fn test_rank_deficient_overconstrained_recovers_with_regularization() {
1019        // Equations depend only on x; y is unconstrained. The Jacobian is rank deficient.
1020        let x = Exp::var("x");
1021        let y = Exp::var("y");
1022        let y_zero = Exp::mul(y.clone(), Exp::val(0.0));
1023
1024        let eq1 = Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0));
1025        let eq2 = Exp::sub(Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()), Exp::val(2.0));
1026        let eq3 = Exp::sub(Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()), Exp::val(3.0));
1027
1028        let solver = solver_for(vec![eq1, eq2, eq3])
1029            .with_tolerance(1e-12)
1030            .with_max_iterations(10);
1031
1032        let mut initial = HashMap::new();
1033        initial.insert("x".to_string(), 0.0);
1034        initial.insert("y".to_string(), 0.0);
1035
1036        let solution = solver.solve(initial).expect("expected solver to converge");
1037        assert!(solution.converged, "{:?}", solution);
1038        assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-10);
1039        assert!((solution.values.get("y").copied().unwrap() - 0.0).abs() < 1e-10);
1040    }
1041
1042    #[test]
1043    fn test_rank_deficient_overconstrained_fails_without_regularization() {
1044        let x = Exp::var("x");
1045        let y = Exp::var("y");
1046        let y_zero = Exp::mul(y.clone(), Exp::val(0.0));
1047
1048        let eq1 = Exp::sub(Exp::add(x.clone(), y_zero.clone()), Exp::val(1.0));
1049        let eq2 = Exp::sub(Exp::add(Exp::mul(x.clone(), Exp::val(2.0)), y_zero.clone()), Exp::val(2.0));
1050        let eq3 = Exp::sub(Exp::add(Exp::mul(x.clone(), Exp::val(3.0)), y_zero.clone()), Exp::val(3.0));
1051
1052        let solver = solver_for(vec![eq1, eq2, eq3])
1053            .with_regularization(0.0)
1054            .with_tolerance(1e-12)
1055            .with_max_iterations(10);
1056
1057        let mut initial = HashMap::new();
1058        initial.insert("x".to_string(), 0.0);
1059        initial.insert("y".to_string(), 0.0);
1060
1061        let err = solver.solve(initial).expect_err("expected singular matrix error");
1062        assert!(matches!(err, SolverError::SingularMatrix(_)), "{:?}", err);
1063    }
1064
1065    #[test]
1066    fn test_overconstrained_consistent_system_solves() {
1067        // Overdetermined but consistent system: x = 1 and 2x = 2.
1068        let x = Exp::var("x");
1069        let eq1 = Exp::sub(x.clone(), Exp::val(1.0));
1070        let eq2 = Exp::sub(Exp::mul(x.clone(), Exp::val(2.0)), Exp::val(2.0));
1071
1072        let solver = solver_for(vec![eq1, eq2])
1073            .with_tolerance(1e-12)
1074            .with_max_iterations(10);
1075
1076        let mut initial = HashMap::new();
1077        initial.insert("x".to_string(), 0.0);
1078
1079        let solution = solver.solve(initial).expect("expected solver to converge");
1080        assert!(solution.converged, "{:?}", solution);
1081        assert!((solution.values.get("x").copied().unwrap() - 1.0).abs() < 1e-10);
1082    }
1083
1084    #[test]
1085    fn test_random_square_linear_systems() {
1086        let mut rng = TestRng::new(0x51ab_1e55_cafe_f00d);
1087        for _ in 0..5 {
1088            let n = 3;
1089            let mut a = vec![vec![0.0; n]; n];
1090            for i in 0..n {
1091                for j in 0..n {
1092                    a[i][j] = rng.next_f64();
1093                }
1094                a[i][i] += 2.0;
1095            }
1096
1097            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
1098            let mut b = vec![0.0; n];
1099            for i in 0..n {
1100                for j in 0..n {
1101                    b[i] += a[i][j] * x_true[j];
1102                }
1103            }
1104
1105            let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
1106            let equations: Vec<Exp> = (0..n)
1107                .map(|i| linear_equation(&a[i], &vars, b[i]))
1108                .collect();
1109
1110            let solver = solver_for(equations)
1111                .with_tolerance(1e-12)
1112                .with_max_iterations(20);
1113
1114            let mut initial = HashMap::new();
1115            for i in 0..n {
1116                initial.insert(format!("x{i}"), 0.0);
1117            }
1118
1119            let solution = solver.solve(initial).expect("expected to solve");
1120            for i in 0..n {
1121                let value = solution.values.get(&format!("x{i}")).copied().unwrap();
1122                assert!((value - x_true[i]).abs() < 1e-8);
1123            }
1124        }
1125    }
1126
1127    #[test]
1128    fn test_random_overconstrained_linear_systems() {
1129        let mut rng = TestRng::new(0xa11c_e551_dead_beef);
1130        let m = 5;
1131        let n = 3;
1132
1133        for _ in 0..5 {
1134            let mut a = vec![vec![0.0; n]; m];
1135            for i in 0..m {
1136                for j in 0..n {
1137                    a[i][j] = rng.next_f64();
1138                }
1139            }
1140            for i in 0..n {
1141                a[i][i] += 2.0;
1142            }
1143
1144            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
1145            let mut b = vec![0.0; m];
1146            for i in 0..m {
1147                for j in 0..n {
1148                    b[i] += a[i][j] * x_true[j];
1149                }
1150            }
1151
1152            let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
1153            let equations: Vec<Exp> = (0..m)
1154                .map(|i| linear_equation(&a[i], &vars, b[i]))
1155                .collect();
1156
1157            let solver = solver_for(equations)
1158                .with_tolerance(1e-12)
1159                .with_max_iterations(20);
1160
1161            let mut initial = HashMap::new();
1162            for i in 0..n {
1163                initial.insert(format!("x{i}"), 0.0);
1164            }
1165
1166            let solution = solver.solve(initial).expect("expected to solve");
1167            for i in 0..n {
1168                let value = solution.values.get(&format!("x{i}")).copied().unwrap();
1169                assert!((value - x_true[i]).abs() < 1e-8);
1170            }
1171        }
1172    }
1173
1174    #[test]
1175    fn test_random_underdetermined_min_norm_structure() {
1176        let mut rng = TestRng::new(0x0ddc_affe_fade_bead);
1177        let _m = 2;
1178        let n = 4;
1179        let vars: Vec<Exp> = (0..n).map(|i| Exp::var(format!("x{i}"))).collect();
1180
1181        for _ in 0..5 {
1182            let b0 = rng.next_f64();
1183            let b1 = rng.next_f64();
1184            let x2_zero = Exp::mul(vars[2].clone(), Exp::val(0.0));
1185            let x3_zero = Exp::mul(vars[3].clone(), Exp::val(0.0));
1186            let eq1 = Exp::sub(Exp::add(vars[0].clone(), x2_zero.clone()), Exp::val(b0));
1187            let eq2 = Exp::sub(Exp::add(vars[1].clone(), x3_zero.clone()), Exp::val(b1));
1188
1189            let solver = solver_for(vec![eq1, eq2])
1190                .with_tolerance(1e-12)
1191                .with_max_iterations(20);
1192
1193            let mut initial = HashMap::new();
1194            for i in 0..n {
1195                initial.insert(format!("x{i}"), 0.0);
1196            }
1197
1198            let solution = solver.solve(initial).expect("expected to solve");
1199            assert!((solution.values.get("x0").copied().unwrap() - b0).abs() < 1e-10);
1200            assert!((solution.values.get("x1").copied().unwrap() - b1).abs() < 1e-10);
1201            assert!(solution.values.get("x2").copied().unwrap().abs() < 1e-10);
1202            assert!(solution.values.get("x3").copied().unwrap().abs() < 1e-10);
1203        }
1204    }
1205}