1use super::{
2 LinearSolver, NonlinearSystem, RowMap, SolverError, SolverResult, SparseColMatRef,
3 init_global_parallelism,
4 linalg::{DenseLu, FaerLu, SparseQr},
5};
6use error_stack::Report;
7use faer::mat::Mat as FaerMat;
8use faer_traits::ComplexField;
9use num_traits::{Float, One, ToPrimitive, Zero};
10use std::panic;
11
/// Problem size below which `MatrixFormat::Auto` picks the dense LU path.
const AUTO_DENSE_THRESHOLD: usize = 100;
/// Default residual-norm convergence tolerance (ftol).
const FTOL_DEFAULT: f64 = 1e-8;
/// Default relative step-norm convergence tolerance (xtol).
const XTOL_DEFAULT: f64 = 1e-8;
/// Default gradient-norm convergence tolerance (gtol).
const GTOL_DEFAULT: f64 = 1e-8;
16
/// Jacobian storage / factorization strategy used by the Newton solver.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MatrixFormat {
    /// Always use the sparse path (LU with QR fallback when square, QR otherwise).
    Sparse,
    /// Use a dense LU factorization (only applied when the system is square).
    Dense,
    /// Choose dense for small square systems (< `AUTO_DENSE_THRESHOLD` variables),
    /// sparse otherwise.
    Auto,
}
23
24impl Default for MatrixFormat {
25 fn default() -> Self {
26 Self::Auto
28 }
29}
30
/// Norm used to measure the residual vector for convergence tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NormType {
    /// Euclidean norm — used by the least-squares (sparse QR) path.
    L2,
    /// Maximum absolute entry — used by the square LU paths.
    LInf,
}
36
/// Configuration for the damped Newton solver.
#[derive(Clone, Copy, Debug)]
pub struct NewtonCfg<T> {
    /// Residual-norm convergence tolerance (ftol).
    pub tol: T,
    /// Gradient-norm convergence tolerance (gtol); values <= 0 disable the check.
    pub tol_grad: T,
    /// Relative step-norm convergence tolerance (xtol); values <= 0 disable the check.
    pub tol_step: T,
    /// Initial step damping factor (1 = full Newton step).
    pub damping: T,
    /// Maximum number of Newton iterations before the solve fails.
    pub max_iter: usize,
    /// Jacobian storage / factorization strategy.
    pub format: MatrixFormat,

    /// Enable adaptive damping plus the divergence-guard line search.
    pub adaptive: bool,
    /// Lower bound on the adaptive damping factor (also the line-search cutoff).
    pub min_damping: T,
    /// Upper bound on the adaptive damping factor.
    pub max_damping: T,
    /// Multiplier applied to damping after an improving iteration.
    pub grow: T,
    /// Multiplier applied to damping after a worsening iteration.
    pub shrink: T,
    /// Residual growth ratio (vs. previous iteration) that triggers the line search.
    pub divergence_ratio: T,
    /// Backtracking factor applied to the line-search step length.
    pub ls_backtrack: T,
    /// Maximum number of backtracking line-search attempts.
    pub ls_max_steps: usize,

    /// Thread count forwarded to `init_global_parallelism`.
    /// NOTE(review): 0 presumably selects an automatic count — confirm in that function.
    pub n_threads: usize,
}
58
impl<T: Float> Default for NewtonCfg<T> {
    /// Builds a conservative default configuration (absolute tolerances 1e-8,
    /// full Newton steps, adaptive damping off, format auto-selected).
    fn default() -> Self {
        // Side effect: ensures the global parallelism state is initialized.
        // The result is deliberately discarded; 0 requests the default thread
        // count (NOTE(review): confirm semantics in `init_global_parallelism`).
        let _ = init_global_parallelism(0);
        Self {
            tol: T::from(FTOL_DEFAULT).unwrap(),
            tol_grad: T::from(GTOL_DEFAULT).unwrap(),
            tol_step: T::from(XTOL_DEFAULT).unwrap(),
            damping: T::one(),
            max_iter: 50,
            format: MatrixFormat::default(),
            adaptive: false,
            min_damping: T::from(0.1).unwrap(),
            max_damping: T::one(),
            grow: T::from(1.1).unwrap(),
            shrink: T::from(0.5).unwrap(),
            divergence_ratio: T::from(3.0).unwrap(),
            ls_backtrack: T::from(0.5).unwrap(),
            ls_max_steps: 10,
            n_threads: 0,
        }
    }
}
81
impl<T: Float> NewtonCfg<T> {
    /// Default configuration forced onto the sparse path.
    pub fn sparse() -> Self {
        Self {
            format: MatrixFormat::Sparse,
            ..Default::default()
        }
    }
    /// Default configuration forced onto the dense path.
    pub fn dense() -> Self {
        Self {
            format: MatrixFormat::Dense,
            ..Default::default()
        }
    }
    /// Sets the Jacobian format strategy.
    pub fn with_format(mut self, format: MatrixFormat) -> Self {
        self.format = format;
        self
    }
    /// Enables or disables adaptive damping (and the divergence guard).
    pub fn with_adaptive(mut self, enabled: bool) -> Self {
        self.adaptive = enabled;
        self
    }
    /// Sets the thread count and eagerly (re)initializes global parallelism.
    pub fn with_threads(mut self, n_threads: usize) -> Self {
        init_global_parallelism(n_threads);
        self.n_threads = n_threads;
        self
    }
    /// Sets the residual-norm tolerance (ftol).
    pub fn with_tol(mut self, tol: T) -> Self {
        self.tol = tol;
        self
    }
    /// Sets the gradient-norm tolerance (gtol); <= 0 disables the check.
    pub fn with_tol_grad(mut self, tol_grad: T) -> Self {
        self.tol_grad = tol_grad;
        self
    }
    /// Sets the relative step-norm tolerance (xtol); <= 0 disables the check.
    pub fn with_tol_step(mut self, tol_step: T) -> Self {
        self.tol_step = tol_step;
        self
    }
}
121
/// Number of Newton iterations performed before a successful return.
pub type Iterations = usize;
123
/// Per-iteration progress snapshot handed to the `solve_cb` callback.
#[derive(Clone, Debug)]
pub struct IterationStats<T> {
    /// Zero-based iteration index.
    pub iter: usize,
    /// Residual norm measured at the start of this iteration.
    pub residual: T,
    /// Damping factor currently in effect.
    pub damping: T,
}
130
/// Verdict returned by the iteration callback.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Control {
    /// Proceed with the next Newton iteration.
    Continue,
    /// Abort the solve; the solver returns a "solve cancelled" error.
    Cancel,
}
136
137fn compute_residual_norm<T: Float>(f: &[T], norm_type: NormType) -> T {
138 match norm_type {
139 NormType::LInf => f.iter().map(|&v| v.abs()).fold(T::zero(), |a, b| a.max(b)),
140 NormType::L2 => f
141 .iter()
142 .map(|&v| v.powi(2))
143 .fold(T::zero(), |a, b| a + b)
144 .sqrt(),
145 }
146}
147
148fn compute_step_norm<T: Float>(step: &[T], x: &[T], tol: T) -> T {
149 let step_norm = step
154 .iter()
155 .map(|&v| v.powi(2))
156 .fold(T::zero(), |a, b| a + b)
157 .sqrt();
158 let x_norm = x
159 .iter()
160 .map(|&v| v.powi(2))
161 .fold(T::zero(), |a, b| a + b)
162 .sqrt();
163
164 step_norm / (x_norm + tol)
165}
166
167fn compute_gradient_norm_sparse<T: Float>(
168 jacobian: &SparseColMatRef<'_, usize, T>,
169 residual: &[T],
170) -> T {
171 let mut max_grad = T::zero();
177
178 for col in 0..jacobian.ncols() {
179 let mut grad_component = T::zero();
180 let range = jacobian.col_range(col);
181 let row_idx = jacobian.symbolic().row_idx();
182 let vals = jacobian.val();
183
184 for idx in range {
185 grad_component = grad_component + vals[idx] * residual[row_idx[idx]];
186 }
187
188 let abs_grad = grad_component.abs();
189 if abs_grad > max_grad {
190 max_grad = abs_grad;
191 }
192 }
193
194 max_grad
195}
196
197fn compute_gradient_norm_dense<T: Float>(jacobian: &FaerMat<T>, residual: &[T]) -> T {
198 let mut max_grad = T::zero();
202
203 for col in 0..jacobian.ncols() {
204 let mut grad_component = T::zero();
205
206 for row in 0..jacobian.nrows() {
207 grad_component = grad_component + jacobian[(row, col)] * residual[row];
208 }
209
210 let abs_grad = grad_component.abs();
211 if abs_grad > max_grad {
212 max_grad = abs_grad;
213 }
214 }
215
216 max_grad
217}
218
/// Strategy object returned by the per-iteration linear solve, used by
/// `newton_iterate` to evaluate the gradient-norm convergence test.
trait GradientNorm<M>
where
    M: NonlinearSystem,
    M::Real: Float,
{
    /// Returns the infinity norm of J^T `residual` at the current point.
    fn compute_gradient_norm(&self, model: &mut M, residual: &[M::Real]) -> M::Real;
}
227
/// Marker strategy: gradient norm read from the model's cached sparse Jacobian.
struct Sparse;
229
impl<M> GradientNorm<M> for Sparse
where
    M: NonlinearSystem,
    M::Real: Float,
{
    fn compute_gradient_norm(&self, model: &mut M, residual: &[M::Real]) -> M::Real {
        // Borrow the model's sparse Jacobian (refreshed earlier this iteration
        // by the sparse linear solve) rather than storing a copy.
        let jac_ref = model.jacobian().attach();
        compute_gradient_norm_sparse(&jac_ref, residual)
    }
}
241
/// Gradient-norm strategy for the dense path; carries a snapshot of the dense
/// Jacobian assembled during the current iteration's linear solve.
struct Dense<M>
where
    M: NonlinearSystem,
{
    // Dense Jacobian copy taken before factorization (see `solve_dense_lu`).
    jac_dense: FaerMat<M::Real>,
}
248
249impl<M> GradientNorm<M> for Dense<M>
250where
251 M: NonlinearSystem,
252 M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
253{
254 fn compute_gradient_norm(&self, _model: &mut M, residual: &[M::Real]) -> M::Real
255 where
256 M: NonlinearSystem,
257 M::Real: Float,
258 {
259 compute_gradient_norm_dense(&self.jac_dense, residual)
261 }
262}
263
/// Core damped Newton iteration shared by the dense and sparse drivers.
///
/// Each iteration: evaluate the residual, test convergence, report progress
/// via `on_iter`, obtain a Newton step `dx` from `solve` (which also factors
/// the Jacobian and returns a gradient-norm strategy), test the step/gradient
/// tolerances, then update `x` — either with a plain damped step or, under the
/// adaptive divergence guard, via a backtracking line search.
///
/// Returns `Ok(iterations)` on convergence; `Err` if the callback cancels,
/// the divergence-guard line search fails, or `max_iter` is exhausted.
fn newton_iterate<M, F, Cb, GradNorm>(
    model: &mut M,
    x: &mut [M::Real],
    cfg: NewtonCfg<M::Real>,
    norm_type: NormType,
    mut solve: F,
    mut on_iter: Cb,
) -> SolverResult<Iterations>
where
    GradNorm: GradientNorm<M>,
    M: NonlinearSystem,
    M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
    F: FnMut(&mut M, &[M::Real], &[M::Real], &mut [M::Real]) -> SolverResult<GradNorm>,
    Cb: FnMut(&IterationStats<M::Real>) -> Control,
{
    let n_vars = model.layout().n_variables();
    let n_res = model.layout().n_residuals();

    // Work buffers, allocated once and reused across iterations.
    let mut f = vec![M::Real::zero(); n_res];
    let mut dx = vec![M::Real::zero(); n_vars];
    let mut damping = cfg.damping;
    // Residual norm of the previous iteration; infinity disables the
    // divergence guard on the first pass.
    let mut last_res = M::Real::infinity();

    // Trial buffers for the backtracking line search.
    let mut x_trial = vec![M::Real::zero(); n_vars];
    let mut f_trial = vec![M::Real::zero(); n_res];

    for iter in 0..cfg.max_iter {
        model.residual(x, &mut f);
        let res = compute_residual_norm(&f, norm_type);

        // Primary convergence test: residual norm below ftol.
        if res < cfg.tol {
            return Ok(iter);
        }

        // Progress callback; the caller can cancel the solve.
        if matches!(
            on_iter(&IterationStats {
                iter,
                residual: res,
                damping
            }),
            Control::Cancel
        ) {
            return Err(Report::new(SolverError).attach_printable("solve cancelled"));
        }

        // Factor the Jacobian and solve for the Newton step dx = -J^{-1} f.
        let jacobian = solve(model, x, &f, &mut dx)?;

        // Step-size test (xtol): relative step norm below tolerance means the
        // iterates have stalled; report convergence without applying the step.
        if cfg.tol_step > M::Real::zero() {
            let step_norm = compute_step_norm(&dx, x, cfg.tol_step);
            if step_norm < cfg.tol_step {
                return Ok(iter + 1);
            }
        }

        // Gradient test (gtol): ||J^T f||_inf below tolerance indicates a
        // (possibly local) stationary point.
        if cfg.tol_grad > M::Real::zero() {
            let grad_norm = jacobian.compute_gradient_norm(model, &f);
            if grad_norm < cfg.tol_grad {
                return Ok(iter + 1);
            }
        }

        // Set when the divergence-guard line search already updated `x`.
        let mut step_applied = false;

        if cfg.adaptive {
            // Grow damping after improvement, shrink after regression,
            // clamped to [min_damping, max_damping].
            if res < last_res {
                let nd = damping * cfg.grow;
                damping = if nd > cfg.max_damping {
                    cfg.max_damping
                } else {
                    nd
                };
            } else {
                let nd = damping * cfg.shrink;
                damping = if nd < cfg.min_damping {
                    cfg.min_damping
                } else {
                    nd
                };
            }

            // Divergence guard: residual grew by more than divergence_ratio,
            // so backtrack along dx until the residual actually decreases.
            if last_res.is_finite() && res > last_res * cfg.divergence_ratio {
                let mut alpha = if damping * cfg.shrink < cfg.min_damping {
                    cfg.min_damping
                } else {
                    damping * cfg.shrink
                };

                for _ in 0..cfg.ls_max_steps {
                    for i in 0..n_vars {
                        x_trial[i] = x[i] + alpha * dx[i];
                    }
                    model.residual(&x_trial, &mut f_trial);
                    let res_try = compute_residual_norm(&f_trial, norm_type);

                    if res_try < res {
                        // Accept the shortened step and keep its length as the
                        // new damping factor.
                        x.copy_from_slice(&x_trial);
                        damping = alpha;
                        step_applied = true;
                        break;
                    }
                    alpha = alpha * cfg.ls_backtrack;
                    if alpha < cfg.min_damping {
                        break;
                    }
                }

                // No step length improved the residual: give up rather than
                // continue along a diverging trajectory.
                if !step_applied {
                    return Err(Report::new(SolverError)
                        .attach_printable("divergence guard: line search failed"));
                }
            }
        }

        // Regular damped update when the line search did not already move x.
        if !step_applied {
            for (xi, &dxi) in x.iter_mut().zip(dx.iter()) {
                *xi = *xi + damping * dxi;
            }
        }

        last_res = res;
    }

    Err(Report::new(SolverError).attach_printable(format!(
        "Newton solver did not converge after {} iterations",
        cfg.max_iter
    )))
}
405
/// Solves the nonlinear system `model` starting from `x` (updated in place),
/// dispatching to a dense or sparse backend per `cfg.format`.
///
/// Returns the iteration count on convergence. Equivalent to [`solve_cb`]
/// with a callback that never cancels.
pub fn solve<M>(
    model: &mut M,
    x: &mut [M::Real],
    cfg: NewtonCfg<M::Real>,
) -> SolverResult<Iterations>
where
    M: NonlinearSystem,
    M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
{
    solve_cb(model, x, cfg, |_| Control::Continue)
}
417
418pub fn solve_cb<M, Cb>(
419 model: &mut M,
420 x: &mut [M::Real],
421 cfg: NewtonCfg<M::Real>,
422 on_iter: Cb,
423) -> SolverResult<Iterations>
424where
425 M: NonlinearSystem,
426 M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
427 Cb: FnMut(&IterationStats<M::Real>) -> Control,
428{
429 let n_vars = model.layout().n_variables();
430 let n_res = model.layout().n_residuals();
431 let is_square = n_vars == n_res;
432
433 let use_dense = if cfg.format == MatrixFormat::Dense {
435 is_square
438 } else if cfg.format == MatrixFormat::Sparse {
439 false
441 } else {
442 is_square && n_vars < AUTO_DENSE_THRESHOLD
444 };
445
446 if use_dense {
447 solve_dense_lu(model, x, cfg, on_iter)
448 } else if is_square {
449 solve_sparse_lu_with_qr_fallback(model, x, cfg, on_iter)
450 } else {
451 solve_sparse_qr(model, x, cfg, on_iter)
452 }
453}
454
/// Dense backend: Newton iteration with a dense LU factorization and the
/// infinity norm for convergence tests. Assumes a square system (n x n).
fn solve_dense_lu<M, Cb>(
    model: &mut M,
    x: &mut [M::Real],
    cfg: NewtonCfg<M::Real>,
    on_iter: Cb,
) -> SolverResult<Iterations>
where
    M: NonlinearSystem,
    M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
    Cb: FnMut(&IterationStats<M::Real>) -> Control,
{
    let n = model.layout().n_variables();
    // Factorization object and matrices allocated once, reused every iteration.
    let mut lu = DenseLu::<M::Real>::default();
    let mut jac = FaerMat::<M::Real>::zeros(n, n);
    let mut rhs = FaerMat::<M::Real>::zeros(n, 1);

    // Per-iteration linear solve: assemble J, factor, solve J dx = -f.
    // A named fn (not a closure) so the borrows of lu/jac/rhs stay explicit.
    #[allow(clippy::too_many_arguments)]
    fn solve_inner<T, M2>(
        model: &mut M2,
        x: &[T],
        f: &[T],
        dx: &mut [T],
        lu: &mut DenseLu<T>,
        jac: &mut FaerMat<T>,
        rhs: &mut FaerMat<T>,
    ) -> SolverResult<Dense<M2>>
    where
        M2: NonlinearSystem<Real = T>,
        T: ComplexField<Real = T> + Float + Zero + One + ToPrimitive,
    {
        model.jacobian_dense(x, jac);
        lu.factor(jac)?;

        // Right-hand side is -f for the Newton step.
        for (i, &fi) in f.iter().enumerate() {
            rhs[(i, 0)] = -fi;
        }
        lu.solve_in_place(rhs.as_mut())?;

        for (i, &val) in rhs.col(0).iter().enumerate() {
            dx[i] = val;
        }

        // Snapshot the (unfactored copy of the) Jacobian for the gradient test.
        Ok(Dense {
            jac_dense: jac.clone(),
        })
    }

    newton_iterate(
        model,
        x,
        cfg,
        NormType::LInf,
        |model, x, f, dx| solve_inner(model, x, f, dx, &mut lu, &mut jac, &mut rhs),
        on_iter,
    )
}
521
/// Sparse backend shared by the LU and QR paths: Newton iteration where each
/// step refreshes the model's sparse Jacobian, factors it with `solver`, and
/// solves for the step. Handles rectangular systems (n_res >= n_vars); only
/// the first `n_vars` entries of the solution are taken as the step.
fn solve_sparse<M, S, Cb>(
    model: &mut M,
    x: &mut [M::Real],
    cfg: NewtonCfg<M::Real>,
    norm_type: NormType,
    mut solver: S,
    on_iter: Cb,
) -> SolverResult<Iterations>
where
    M: NonlinearSystem,
    M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
    S: for<'a> LinearSolver<M::Real, SparseColMatRef<'a, usize, M::Real>>,
    Cb: FnMut(&IterationStats<M::Real>) -> Control,
{
    let n_vars = model.layout().n_variables();
    let n_res = model.layout().n_residuals();
    // Right-hand-side column, allocated once and reused every iteration.
    let mut rhs = FaerMat::<M::Real>::zeros(n_res, 1);

    // Per-iteration linear solve: refresh J, factor, solve J dx = -f.
    #[allow(clippy::too_many_arguments)]
    fn solve_inner<T, S>(
        model: &mut impl NonlinearSystem<Real = T>,
        x: &[T],
        f: &[T],
        dx: &mut [T],
        solver: &mut S,
        rhs: &mut FaerMat<T>,
        n_vars: usize,
    ) -> SolverResult<Sparse>
    where
        T: ComplexField<Real = T> + Float + Zero + One + ToPrimitive,
        S: for<'a> LinearSolver<T, SparseColMatRef<'a, usize, T>>,
    {
        model.refresh_jacobian(x);
        let jac_ref = model.jacobian().attach();
        solver.factor(&jac_ref)?;

        // rhs = -f (negated residual).
        rhs.col_mut(0)
            .as_mut()
            .iter_mut()
            .zip(f.iter())
            .for_each(|(dst, &src)| *dst = -src);

        solver.solve_in_place(rhs.as_mut())?;

        // For rectangular (least-squares) solves only the leading n_vars
        // entries hold the step.
        for (i, &val) in rhs.col(0).iter().take(n_vars).enumerate() {
            dx[i] = val;
        }

        // Gradient norms are computed from the model's cached Jacobian.
        Ok(Sparse)
    }

    newton_iterate(
        model,
        x,
        cfg,
        norm_type,
        |model, x, f, dx| solve_inner(model, x, f, dx, &mut solver, &mut rhs, n_vars),
        on_iter,
    )
}
585
586fn solve_sparse_lu<M, Cb>(
587 model: &mut M,
588 x: &mut [M::Real],
589 cfg: NewtonCfg<M::Real>,
590 on_iter: Cb,
591) -> SolverResult<Iterations>
592where
593 M: NonlinearSystem,
594 M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
595 Cb: FnMut(&IterationStats<M::Real>) -> Control,
596{
597 solve_sparse(
598 model,
599 x,
600 cfg,
601 NormType::LInf,
602 FaerLu::<M::Real>::default(),
603 on_iter,
604 )
605}
606
607fn solve_sparse_qr<M, Cb>(
608 model: &mut M,
609 x: &mut [M::Real],
610 cfg: NewtonCfg<M::Real>,
611 on_iter: Cb,
612) -> SolverResult<Iterations>
613where
614 M: NonlinearSystem,
615 M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
616 Cb: FnMut(&IterationStats<M::Real>) -> Control,
617{
618 solve_sparse(
619 model,
620 x,
621 cfg,
622 NormType::L2,
623 SparseQr::<M::Real>::default(),
624 on_iter,
625 )
626}
627
/// Square sparse path with a safety net: try sparse LU first and fall back to
/// sparse QR only if the LU path *panics*. A regular `Err` from the LU path
/// is propagated as-is (no fallback).
///
/// NOTE(review): if LU panics after some iterations, `x` may already be
/// partially updated, so the QR fallback restarts from that intermediate
/// point rather than the caller's original guess — confirm this is intended.
fn solve_sparse_lu_with_qr_fallback<M, Cb>(
    model: &mut M,
    x: &mut [M::Real],
    cfg: NewtonCfg<M::Real>,
    mut on_iter: Cb,
) -> SolverResult<Iterations>
where
    M: NonlinearSystem,
    M::Real: ComplexField<Real = M::Real> + Float + Zero + One + ToPrimitive,
    Cb: FnMut(&IterationStats<M::Real>) -> Control,
{
    // AssertUnwindSafe: `model`/`x` are deliberately reused after a panic
    // (see the partial-update note above).
    let lu_result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
        solve_sparse_lu(model, x, cfg, &mut on_iter)
    }));

    match lu_result {
        Ok(Ok(iterations)) => Ok(iterations),
        Ok(Err(lu_error)) => Err(lu_error),
        Err(_panic) => {
            // LU panicked (e.g. singular pattern in the factorization):
            // retry with the more robust QR path.
            solve_sparse_qr(model, x, cfg, on_iter)
        }
    }
}