1use crate::construction::calculate_condition_number;
2use crate::estimate::EstimationError;
3use crate::faer_ndarray::{
4 FaerArrayView, FaerLinalgError, array2_to_matmut, factorize_symmetricwith_fallback,
5};
6use crate::faer_ndarray::{FaerCholesky, FaerEigh};
7use crate::matrix::symmetrize_in_place;
8use faer::Side;
9use ndarray::{
10 Array1, Array2, Array3, ArrayBase, ArrayView1, ArrayView2, ArrayView3, Data, Dimension, Zip, s,
11};
12
/// Advances a SplitMix64-style generator and returns its next 64-bit output.
///
/// `state` is first bumped (with wrap-around) by a fixed increment; the new
/// state value is then scrambled by two xor-shift/multiply rounds and a final
/// xor-shift. The updated state is left in `state` for the next call.
#[inline]
pub(crate) const fn splitmix64(state: &mut u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    *state = state.wrapping_add(INCREMENT);
    let seed = *state;
    let round_one = (seed ^ (seed >> 30)).wrapping_mul(MIX_A);
    let round_two = (round_one ^ (round_one >> 27)).wrapping_mul(MIX_B);
    round_two ^ (round_two >> 31)
}
30
31#[inline]
35pub(crate) const fn splitmix64_hash(x: u64) -> u64 {
36 let mut state = x;
37 splitmix64(&mut state)
38}
39
40pub(crate) fn stack_offsets(blocks: &[&Array1<f64>]) -> Array1<f64> {
49 let total: usize = blocks.iter().map(|block| block.len()).sum();
50 let mut out = Array1::<f64>::zeros(total);
51 let mut row = 0usize;
52 for block in blocks {
53 let end = row + block.len();
54 out.slice_mut(ndarray::s![row..end]).assign(block);
55 row = end;
56 }
57 out
58}
59
/// Evaluates `ln(1 + exp(x))` without overflowing for large `|x|`.
///
/// The exponential is only ever taken of a non-positive argument: for
/// positive `x` the identity `ln(1 + e^x) = x + ln(1 + e^-x)` is used, and for
/// the remaining inputs `ln(1 + e^x)` is evaluated directly.
#[inline]
pub(crate) fn stable_softplus(x: f64) -> f64 {
    // `-|x|` equals `-x` when x > 0 and `x` otherwise, so one expression
    // serves both branches.
    let tail = (-x.abs()).exp().ln_1p();
    if x > 0.0 { x + tail } else { tail }
}
74
/// Evaluates the logistic function `1 / (1 + exp(-x))` without overflow.
///
/// Only `exp` of a non-positive argument is computed. For `x >= 0` the result
/// is `1 / (1 + e^-x)`; otherwise the equivalent form `e^x / (1 + e^x)` is
/// used.
#[inline]
pub(crate) fn stable_logistic(x: f64) -> f64 {
    // `-|x|` is `-x` on the non-negative branch and `x` on the other one.
    let decay = (-x.abs()).exp();
    let denom = 1.0 + decay;
    if x >= 0.0 { 1.0 / denom } else { decay / denom }
}
94
95#[inline]
97pub(crate) fn array_is_finite<S, D>(values: &ArrayBase<S, D>) -> bool
98where
99 S: Data<Elem = f64>,
100 D: Dimension,
101{
102 values.iter().all(|v| v.is_finite())
103}
104
/// Largest absolute value among `values`, starting from `0.0`.
///
/// An empty input yields `0.0`. Because `f64::max` returns the non-NaN
/// operand, NaN entries do not affect the result.
#[inline]
pub(crate) fn inf_norm<I: IntoIterator<Item = f64>>(values: I) -> f64 {
    let mut largest = 0.0_f64;
    for value in values {
        largest = largest.max(value.abs());
    }
    largest
}
113
/// Condition number that `RidgePlanner::bumpwith_matrix` compares its
/// condition estimate against when choosing the next ridge multiplier.
const HESSIAN_CONDITION_TARGET: f64 = 1e10;
/// Number of ridge bumps `StableSolver::factorize_with_ridge_plan` allows
/// before it gives up on factorizing.
const MAX_FACTORIZATION_ATTEMPTS: usize = 4;
/// Number of ridge levels tried by `StableSolver::solvevectorwithridge_retries`.
const MAX_SOLVE_RETRIES: usize = 8;
117
/// Compensated (Kahan) floating-point accumulator.
///
/// `sum` holds the running total and `c` the running compensation term that
/// captures the low-order bits lost by each addition.
#[derive(Default, Clone, Copy)]
pub(crate) struct KahanSum {
    sum: f64,
    c: f64,
}

impl KahanSum {
    /// Adds `value` to the running total, folding the previously lost
    /// low-order part back in and recording the newly lost part.
    #[inline]
    pub(crate) fn add(&mut self, value: f64) {
        let corrected = value - self.c;
        let running = self.sum + corrected;
        // (running - sum) is what was actually added; subtracting the
        // intended increment leaves the rounding error of this step.
        self.c = (running - self.sum) - corrected;
        self.sum = running;
    }

    /// Returns the accumulated total.
    #[inline]
    pub(crate) fn sum(self) -> f64 {
        self.sum
    }
}
138
139pub(crate) fn matrix_inversewith_regularization(
152 matrix: &Array2<f64>,
153 label: &str,
154) -> Option<Array2<f64>> {
155 StableSolver::new(label).inversewith_regularization(matrix)
156}
157
158pub(crate) struct StableSolver<'a> {
159 label: &'a str,
160}
161
162impl<'a> StableSolver<'a> {
163 pub(crate) fn new(label: &'a str) -> Self {
164 Self { label }
165 }
166
167 pub(crate) fn factorize(
168 &self,
169 matrix: &Array2<f64>,
170 ) -> Result<crate::faer_ndarray::FaerSymmetricFactor, FaerLinalgError> {
171 let view = FaerArrayView::new(matrix);
172 factorize_symmetricwith_fallback(view.as_ref(), Side::Lower)
173 }
174
175 pub(crate) fn factorize_any<S>(
179 &self,
180 matrix: &ArrayBase<S, ndarray::Ix2>,
181 ) -> Result<crate::faer_ndarray::FaerSymmetricFactor, FaerLinalgError>
182 where
183 S: Data<Elem = f64>,
184 {
185 let view = FaerArrayView::new(matrix);
186 factorize_symmetricwith_fallback(view.as_ref(), Side::Lower)
187 }
188
189 pub(crate) fn inversewith_regularization(&self, matrix: &Array2<f64>) -> Option<Array2<f64>> {
190 let p = matrix.nrows();
191 if p == 0 || matrix.ncols() != p {
192 return None;
193 }
194
195 let mut planner = RidgePlanner::new(matrix);
196 let (factor, _, regularized) = self.factorize_with_ridge_plan(matrix, &mut planner)?;
197 let mut inv = Array2::<f64>::eye(p);
198 let mut invview = array2_to_matmut(&mut inv);
199 factor.solve_in_place(invview.as_mut());
200
201 if !inv.iter().all(|v| v.is_finite()) {
202 log::warn!("Non-finite inverse produced for {}", self.label);
203 return None;
204 }
205
206 for i in 0..p {
208 for j in (i + 1)..p {
209 let avg = 0.5 * (inv[[i, j]] + inv[[j, i]]);
210 inv[[i, j]] = avg;
211 inv[[j, i]] = avg;
212 }
213 }
214 assert_eq!(regularized.nrows(), p);
215 Some(inv)
216 }
217
218 pub(crate) fn solvevectorwithridge_retries(
219 &self,
220 matrix: &Array2<f64>,
221 rhs: &Array1<f64>,
222 baseridge: f64,
223 ) -> Option<Array1<f64>> {
224 let p = matrix.nrows();
225 if matrix.ncols() != p || rhs.len() != p {
226 return None;
227 }
228
229 let diag_scale = max_abs_diag(matrix);
242 for retry in 0..MAX_SOLVE_RETRIES {
243 let ridge = if baseridge > 0.0 {
244 baseridge * diag_scale * 10f64.powi(retry as i32)
245 } else {
246 0.0
247 };
248 let h = addridge(matrix, ridge);
249 let factor = match self.factorize(&h) {
250 Ok(f) => f,
251 Err(_) => continue,
252 };
253 let mut out = rhs.clone();
254 let mut out_mat = crate::faer_ndarray::array1_to_col_matmut(&mut out);
255 factor.solve_in_place(out_mat.as_mut());
256 if out.iter().all(|v| v.is_finite()) {
257 return Some(out);
258 }
259 }
260 None
261 }
262
263 pub(crate) fn solve_with_pseudoinverse_fallback(
293 &self,
294 matrix: &Array2<f64>,
295 rhs: &Array1<f64>,
296 baseridge: f64,
297 rel_tol: f64,
298 rank_tol: f64,
299 ) -> Option<Array1<f64>> {
300 use crate::faer_ndarray::FaerEigh;
301 use faer::Side;
302
303 let p = matrix.nrows();
304 if matrix.ncols() != p || rhs.len() != p {
305 return None;
306 }
307
308 let delta = self.solvevectorwithridge_retries(matrix, rhs, baseridge)?;
310
311 let matrix_delta = matrix.dot(&delta);
315 let residual_inf = matrix_delta
316 .iter()
317 .zip(rhs.iter())
318 .map(|(h, r)| (h - r).abs())
319 .fold(0.0_f64, f64::max);
320 let rhs_inf = rhs.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
321 let rel = residual_inf / (1.0 + rhs_inf);
322
323 if rel.is_finite() && rel < rel_tol {
324 return Some(delta);
325 }
326
327 let (eigvals, eigvecs) = matrix.eigh(Side::Lower).ok()?;
329 let max_abs_eig = eigvals.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
330 if !max_abs_eig.is_finite() || max_abs_eig <= 0.0 {
331 return Some(delta);
332 }
333 let cutoff = rank_tol * max_abs_eig;
334
335 let mut pseudo = Array1::<f64>::zeros(p);
336 let mut excluded = 0usize;
337 for k in 0..p {
338 let lam = eigvals[k];
339 if !lam.is_finite() || lam.abs() <= cutoff {
340 excluded += 1;
341 continue;
342 }
343 let u_k = eigvecs.column(k);
344 let proj = u_k.iter().zip(rhs.iter()).map(|(u, r)| u * r).sum::<f64>();
345 let scale = proj / lam;
346 for i in 0..p {
347 pseudo[i] += scale * u_k[i];
348 }
349 }
350
351 if !pseudo.iter().all(|v| v.is_finite()) {
352 return Some(delta);
353 }
354
355 log::debug!(
356 "[{}] pseudoinverse fallback engaged: rel = {:.3e} > rel_tol = {:.3e}, \
357 excluded {} of {} eigenvalues below cutoff = {:.3e} × max |λ| = {:.3e}",
358 self.label,
359 rel,
360 rel_tol,
361 excluded,
362 p,
363 rank_tol,
364 max_abs_eig,
365 );
366
367 Some(pseudo)
368 }
369
370 fn factorize_with_ridge_plan(
371 &self,
372 matrix: &Array2<f64>,
373 planner: &mut RidgePlanner,
374 ) -> Option<(crate::faer_ndarray::FaerSymmetricFactor, f64, Array2<f64>)> {
375 loop {
376 let ridge = planner.ridge();
377 let h_eff = addridge(matrix, ridge);
378 if let Ok(factor) = self.factorize(&h_eff) {
379 return Some((factor, ridge, h_eff));
380 }
381 if planner.attempts() >= MAX_FACTORIZATION_ATTEMPTS {
382 log::warn!(
383 "Failed to factorize {} after ridge {:.3e}",
384 self.label,
385 ridge
386 );
387 return None;
388 }
389 planner.bumpwith_matrix(matrix);
390 }
391 }
392}
393
394pub(crate) fn max_abs_diag(matrix: &Array2<f64>) -> f64 {
395 matrix
396 .diag()
397 .iter()
398 .copied()
399 .map(f64::abs)
400 .fold(0.0, f64::max)
401 .max(1.0)
402}
403
/// Builds a diagnostic when the four row counts are not all equal.
///
/// Returns `None` when `w_len`, `x_rows` and `offset_len` all match `y_len`;
/// otherwise returns a message listing all four values.
pub(crate) fn row_mismatch_message(
    y_len: usize,
    w_len: usize,
    x_rows: usize,
    offset_len: usize,
) -> Option<String> {
    let consistent = [w_len, x_rows, offset_len]
        .iter()
        .all(|&len| len == y_len);
    if consistent {
        return None;
    }
    Some(format!(
        "Row mismatch: y={}, w={}, X.rows={}, offset={}",
        y_len, w_len, x_rows, offset_len
    ))
}
419
/// Checks the shapes fed to `predict_gam` and describes the first mismatch.
///
/// The column/`beta` check takes precedence over the row/offset check.
/// Returns `None` when both agree.
pub(crate) fn predict_gam_dimension_mismatch_message(
    x_rows: usize,
    x_cols: usize,
    beta_len: usize,
    offset_len: usize,
) -> Option<String> {
    if x_cols != beta_len {
        Some(format!(
            "predict_gam dimension mismatch: X has {} columns but beta has length {}",
            x_cols, beta_len
        ))
    } else if x_rows != offset_len {
        Some(format!(
            "predict_gam dimension mismatch: X has {} rows but offset has length {}",
            x_rows, offset_len
        ))
    } else {
        None
    }
}
440
441pub(crate) fn add_relative_diag_ridge(matrix: &mut Array2<f64>, scale: f64, floor: f64) -> f64 {
442 let ridge = scale
443 * matrix
444 .diag()
445 .iter()
446 .map(|&value| value.abs())
447 .fold(0.0, f64::max)
448 .max(floor);
449 for idx in 0..matrix.nrows() {
450 matrix[[idx, idx]] += ridge;
451 }
452 ridge
453}
454
455pub(crate) fn boundary_hit_indices(
456 values: ArrayView1<'_, f64>,
457 bound: f64,
458 tolerance: f64,
459) -> (Vec<usize>, Vec<usize>) {
460 let at_lower = values
461 .iter()
462 .enumerate()
463 .filter_map(|(idx, &value)| (value <= -bound + tolerance).then_some(idx))
464 .collect();
465 let at_upper = values
466 .iter()
467 .enumerate()
468 .filter_map(|(idx, &value)| (value >= bound - tolerance).then_some(idx))
469 .collect();
470 (at_lower, at_upper)
471}
472
473pub(crate) fn symmetric_spectrum_condition_number(matrix: &Array2<f64>) -> f64 {
482 matrix
483 .eigh(Side::Lower)
484 .ok()
485 .map(|(evals, _)| {
486 let min = evals
487 .iter()
488 .fold(f64::INFINITY, |acc, &value| acc.min(value));
489 let max = evals
490 .iter()
491 .fold(f64::NEG_INFINITY, |acc, &value| acc.max(value));
492 max / min.max(1e-12)
493 })
494 .unwrap_or(f64::NAN)
495}
496
497pub(crate) fn symmetric_extremes(matrix: &Array2<f64>) -> Option<(f64, f64)> {
501 let (evals, _) = matrix.eigh(Side::Lower).ok()?;
502 let mut min = f64::INFINITY;
503 let mut max = f64::NEG_INFINITY;
504 for &v in evals.iter() {
505 if v < min {
506 min = v;
507 }
508 if v > max {
509 max = v;
510 }
511 }
512 if min.is_finite() && max.is_finite() {
513 Some((min, max))
514 } else {
515 None
516 }
517}
518
519pub fn enforce_symmetry(matrix: &mut Array2<f64>) {
529 let n = matrix.nrows();
530 assert_eq!(n, matrix.ncols());
531 for i in 0..n {
532 for j in i + 1..n {
533 let avg = 0.5 * (matrix[[i, j]] + matrix[[j, i]]);
534 matrix[[i, j]] = avg;
535 matrix[[j, i]] = avg;
536 }
537 }
538}
539
540pub(crate) fn addridge(matrix: &Array2<f64>, ridge: f64) -> Array2<f64> {
541 if ridge <= 0.0 {
542 return matrix.clone();
543 }
544 let mut regularized = matrix.clone();
545 let n = regularized.nrows();
546 for i in 0..n {
547 regularized[[i, i]] += ridge;
548 }
549 regularized
550}
551
/// Step length at which a constraint slack reaches zero along a direction,
/// if that happens before `current_step_limit`.
///
/// `directional_slack_change` is the change in slack per unit step. Returns
/// `None` when any input is non-finite, when `current_step_limit` is not
/// positive, when the slack is not decreasing by more than a small
/// scale-relative tolerance, or when the hit would occur at or beyond the
/// current limit. Otherwise returns `slack / -directional_slack_change`,
/// clamped below at `0.0`.
pub(crate) fn boundary_hit_step_fraction(
    slack: f64,
    directional_slack_change: f64,
    current_step_limit: f64,
) -> Option<f64> {
    let inputs_usable = slack.is_finite()
        && directional_slack_change.is_finite()
        && current_step_limit.is_finite()
        && current_step_limit > 0.0;
    if !inputs_usable {
        return None;
    }

    // Tolerance scales with the largest magnitude involved (at least 1).
    let magnitude = 1.0_f64
        .max(slack.abs())
        .max(directional_slack_change.abs())
        .max(current_step_limit.abs());
    let directional_tol = (64.0 * f64::EPSILON * magnitude).max(1e-14);

    // Directions that are (numerically) tangential or move away from the
    // boundary never hit it.
    if directional_slack_change >= -directional_tol {
        return None;
    }

    let step = (slack / -directional_slack_change).max(0.0);
    (step.is_finite() && step < current_step_limit).then_some(step)
}
581
/// Summary statistics returned alongside a PCG solution.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PcgSolveInfo {
    /// Number of iterations performed (0 for the all-zero right-hand side).
    pub iterations: usize,
    /// Whether the solver reported convergence.
    pub converged: bool,
    /// Final residual norm divided by `max(rhs_norm, 1.0)`.
    pub relative_residual_norm: f64,
    /// First residual norm recorded by the solver diagnostics.
    pub initial_residual_norm: f64,
    /// Residual norm at the iteration where the solver stopped.
    pub final_residual_norm: f64,
    /// `final_residual_norm / initial_residual_norm`, or `0.0` when the
    /// initial norm is not positive.
    pub residual_reduction: f64,
    /// Eigenvalue-ratio estimate derived from the recorded step
    /// coefficients; `None` when it could not be formed.
    pub condition_estimate: Option<f64>,
}
592
593#[derive(Debug, Clone)]
594struct PcgDiagnostics {
595 residuals: Vec<f64>,
596 alpha: Vec<f64>,
597 beta: Vec<f64>,
598}
599
600impl PcgDiagnostics {
601 fn new(initial_residual_norm: f64) -> Self {
602 Self {
603 residuals: vec![initial_residual_norm],
604 alpha: Vec::new(),
605 beta: Vec::new(),
606 }
607 }
608
609 fn push_iteration(&mut self, alpha: f64, beta: Option<f64>, residual_norm: f64) {
610 self.alpha.push(alpha);
611 if let Some(beta) = beta {
612 self.beta.push(beta);
613 }
614 self.residuals.push(residual_norm);
615 }
616
617 fn condition_estimate(&self) -> Option<f64> {
618 let k = self.alpha.len();
631 if k == 0 || k > 256 {
632 return None;
633 }
634 let mut t = ndarray::Array2::<f64>::zeros((k, k));
635 for i in 0..k {
636 let alpha_i = self.alpha[i];
637 if !alpha_i.is_finite() || alpha_i <= 0.0 {
638 return None;
639 }
640 let mut diag = 1.0 / alpha_i;
641 if i > 0 {
642 let beta_prev = self.beta.get(i - 1).copied()?;
643 if !beta_prev.is_finite() || beta_prev < 0.0 {
644 return None;
645 }
646 diag += beta_prev / self.alpha[i - 1];
647 }
648 t[[i, i]] = diag;
649 if i + 1 < k {
650 let beta_i = self.beta.get(i).copied().unwrap_or(0.0);
651 if !beta_i.is_finite() || beta_i < 0.0 {
652 return None;
653 }
654 let off = beta_i.sqrt() / alpha_i;
655 t[[i, i + 1]] = off;
656 t[[i + 1, i]] = off;
657 }
658 }
659 let (evals, _) = t.eigh(Side::Lower).ok()?;
660 let mut lower = f64::INFINITY;
661 let mut upper = f64::NEG_INFINITY;
662 for &v in evals.iter() {
663 if !v.is_finite() {
664 return None;
665 }
666 if v < lower {
667 lower = v;
668 }
669 if v > upper {
670 upper = v;
671 }
672 }
673 if lower > 0.0 && upper > 0.0 {
674 Some(upper / lower)
675 } else {
676 None
677 }
678 }
679
680 fn info(
681 &self,
682 iterations: usize,
683 converged: bool,
684 rhs_norm: f64,
685 final_residual_norm: f64,
686 ) -> PcgSolveInfo {
687 let initial = self.residuals.first().copied().unwrap_or(rhs_norm);
688 PcgSolveInfo {
689 iterations,
690 converged,
691 relative_residual_norm: final_residual_norm / rhs_norm.max(1.0),
692 initial_residual_norm: initial,
693 final_residual_norm,
694 residual_reduction: if initial > 0.0 {
695 final_residual_norm / initial
696 } else {
697 0.0
698 },
699 condition_estimate: self.condition_estimate(),
700 }
701 }
702}
703
/// Solves `A x = rhs` by preconditioned conjugate gradients with a diagonal
/// preconditioner, returning the solution together with solve statistics.
///
/// * `apply` — computes `A v` for a vector `v`; must return a vector of the
///   same length as `rhs`.
/// * `preconditioner_diag` — diagonal of the preconditioner; each entry is
///   floored at `1e-12` before being inverted.
/// * `rel_tol` — relative tolerance; the absolute stopping threshold is
///   `max(rel_tol, 1e-12) * max(||rhs||, 1)`.
/// * `max_iter` — iteration cap.
///
/// The iteration starts from `x = 0`. A zero right-hand side short-circuits
/// to the zero solution.
///
/// Returns `None` when the inputs are empty or mismatched, `max_iter` is 0,
/// `||rhs||` is not finite, the preconditioner has a negative or non-finite
/// entry, any inner product needed by the iteration is non-finite or not
/// positive, `apply` returns the wrong length, the iteration cap is reached
/// without convergence, or the converged solution has a non-finite entry.
///
/// NOTE(review): this variant accepts zero preconditioner entries
/// (`m < 0.0` is the rejection test), whereas
/// `solve_spd_pcg_with_info_into` rejects them (`m <= 0.0`) — confirm which
/// is intended.
pub fn solve_spd_pcg_with_info<F>(
    apply: F,
    rhs: &Array1<f64>,
    preconditioner_diag: &Array1<f64>,
    rel_tol: f64,
    max_iter: usize,
) -> Option<(Array1<f64>, PcgSolveInfo)>
where
    F: Fn(&Array1<f64>) -> Array1<f64>,
{
    let p = rhs.len();
    if p == 0 || preconditioner_diag.len() != p || max_iter == 0 {
        return None;
    }
    let rhs_norm = rhs.dot(rhs).sqrt();
    if !rhs_norm.is_finite() {
        return None;
    }
    // Trivial system: the zero vector solves it exactly.
    if rhs_norm == 0.0 {
        return Some((
            Array1::<f64>::zeros(p),
            PcgSolveInfo {
                iterations: 0,
                converged: true,
                relative_residual_norm: 0.0,
                initial_residual_norm: 0.0,
                final_residual_norm: 0.0,
                residual_reduction: 0.0,
                condition_estimate: None,
            },
        ));
    }

    let tol = rel_tol.max(1e-12) * rhs_norm.max(1.0);
    let mut x = Array1::<f64>::zeros(p);
    // With x = 0 the initial residual is the right-hand side itself.
    let mut r = rhs.clone();
    let mut diagnostics = PcgDiagnostics::new(rhs_norm);

    // Invert the preconditioner diagonal, bailing out on the first entry
    // that is negative or non-finite.
    let mut inv_m = Array1::<f64>::zeros(p);
    let mut bad_diag = false;
    for (slot, &m) in inv_m.iter_mut().zip(preconditioner_diag.iter()) {
        if !m.is_finite() || m < 0.0 {
            bad_diag = true;
            break;
        }
        *slot = 1.0 / m.max(1e-12);
    }
    if bad_diag {
        log::warn!(
            "SPD PCG rejected: preconditioner diagonal contained a negative or \
             non-finite entry; caller should route to a direct factorization \
             or indefinite Krylov path."
        );
        return None;
    }

    // z = M^{-1} r
    let mut z = Array1::<f64>::zeros(p);
    Zip::from(&mut z)
        .and(&r)
        .and(&inv_m)
        .par_for_each(|zi, &ri, &im| {
            *zi = ri * im;
        });
    let mut p_dir = z.clone();
    let mut rz_old = r.dot(&z);
    if !rz_old.is_finite() || rz_old <= 0.0 {
        return None;
    }

    for iter in 0..max_iter {
        let ap = apply(&p_dir);
        if ap.len() != p {
            return None;
        }
        // Curvature along the search direction; must be positive.
        let denom = p_dir.dot(&ap);
        if !denom.is_finite() || denom <= 0.0 {
            return None;
        }
        let alpha = rz_old / denom;
        if !alpha.is_finite() {
            return None;
        }
        x.scaled_add(alpha, &p_dir);
        r.scaled_add(-alpha, &ap);
        // Every 32nd iteration, replace the recursively updated residual
        // with one recomputed from scratch as rhs - A x.
        if (iter + 1) % 32 == 0 {
            let ax = apply(&x);
            if ax.len() != p {
                return None;
            }
            r.assign(rhs);
            r.scaled_add(-1.0, &ax);
        }
        let r_norm = r.dot(&r).sqrt();
        if r_norm.is_finite() && r_norm <= tol {
            // Converged: no new direction is formed, so no beta is recorded.
            diagnostics.push_iteration(alpha, None, r_norm);
            return x
                .iter()
                .all(|v| v.is_finite())
                .then_some((x, diagnostics.info(iter + 1, true, rhs_norm, r_norm)));
        }
        // z = M^{-1} r for the updated residual.
        Zip::from(&mut z)
            .and(&r)
            .and(&inv_m)
            .par_for_each(|zi, &ri, &im| {
                *zi = ri * im;
            });
        let rz_new = r.dot(&z);
        if !rz_new.is_finite() || rz_new <= 0.0 {
            return None;
        }
        let beta = rz_new / rz_old;
        if !beta.is_finite() {
            return None;
        }
        diagnostics.push_iteration(alpha, Some(beta), r_norm);
        // p_dir = z + beta * p_dir
        Zip::from(&mut p_dir).and(&z).par_for_each(|pi, &zi| {
            *pi = zi + beta * *pi;
        });
        rz_old = rz_new;
    }
    // Iteration cap reached without meeting the tolerance.
    None
}
847
848pub fn solve_spd_pcg<F>(
849 apply: F,
850 rhs: &Array1<f64>,
851 preconditioner_diag: &Array1<f64>,
852 rel_tol: f64,
853 max_iter: usize,
854) -> Option<Array1<f64>>
855where
856 F: Fn(&Array1<f64>) -> Array1<f64>,
857{
858 solve_spd_pcg_with_info(apply, rhs, preconditioner_diag, rel_tol, max_iter)
859 .map(|(solution, _)| solution)
860}
861
/// Preconditioned conjugate gradients with a diagonal preconditioner, using
/// a write-into operator callback.
///
/// Same iteration as `solve_spd_pcg_with_info`, except that `apply(v, out)`
/// writes `A v` into a caller-provided buffer, which lets one `ap` buffer be
/// reused for every operator application.
///
/// * `preconditioner_diag` — every entry must be finite and strictly
///   positive; entries are floored at `1e-12` before being inverted.
/// * `rel_tol` — the absolute stopping threshold is
///   `max(rel_tol, 1e-12) * max(||rhs||, 1)`.
///
/// Returns `None` when the inputs are empty or mismatched, `max_iter` is 0,
/// `||rhs||` is not finite, the preconditioner has a non-positive or
/// non-finite entry, any inner product needed by the iteration is non-finite
/// or not positive, `apply` leaves `ap` with the wrong length, the iteration
/// cap is reached, or the converged solution has a non-finite entry.
///
/// NOTE(review): unlike `solve_spd_pcg_with_info`, a zero preconditioner
/// entry is rejected here and no warning is logged — confirm the difference
/// is intended.
pub fn solve_spd_pcg_with_info_into<F>(
    apply: F,
    rhs: &Array1<f64>,
    preconditioner_diag: &Array1<f64>,
    rel_tol: f64,
    max_iter: usize,
) -> Option<(Array1<f64>, PcgSolveInfo)>
where
    F: Fn(&Array1<f64>, &mut Array1<f64>),
{
    let p = rhs.len();
    if p == 0 || preconditioner_diag.len() != p || max_iter == 0 {
        return None;
    }
    let rhs_norm = rhs.dot(rhs).sqrt();
    if !rhs_norm.is_finite() {
        return None;
    }
    // Trivial system: the zero vector solves it exactly.
    if rhs_norm == 0.0 {
        return Some((
            Array1::<f64>::zeros(p),
            PcgSolveInfo {
                iterations: 0,
                converged: true,
                relative_residual_norm: 0.0,
                initial_residual_norm: 0.0,
                final_residual_norm: 0.0,
                residual_reduction: 0.0,
                condition_estimate: None,
            },
        ));
    }

    let tol = rel_tol.max(1e-12) * rhs_norm.max(1.0);
    let mut x = Array1::<f64>::zeros(p);
    // With x = 0 the initial residual is the right-hand side itself.
    let mut r = rhs.clone();
    let mut diagnostics = PcgDiagnostics::new(rhs_norm);

    // Reject the whole preconditioner if any entry is unusable.
    if preconditioner_diag
        .iter()
        .any(|&m| !m.is_finite() || m <= 0.0)
    {
        return None;
    }
    let mut inv_m = Array1::<f64>::zeros(p);
    Zip::from(&mut inv_m)
        .and(preconditioner_diag)
        .par_for_each(|inv, &m| {
            *inv = 1.0 / m.max(1e-12);
        });

    // z = M^{-1} r
    let mut z = Array1::<f64>::zeros(p);
    Zip::from(&mut z)
        .and(&r)
        .and(&inv_m)
        .par_for_each(|zi, &ri, &im| {
            *zi = ri * im;
        });
    let mut p_dir = z.clone();
    let mut rz_old = r.dot(&z);
    if !rz_old.is_finite() || rz_old <= 0.0 {
        return None;
    }

    // Single output buffer shared by every operator application below.
    let mut ap = Array1::<f64>::zeros(p);

    for iter in 0..max_iter {
        apply(&p_dir, &mut ap);
        if ap.len() != p {
            return None;
        }
        // Curvature along the search direction; must be positive.
        let denom = p_dir.dot(&ap);
        if !denom.is_finite() || denom <= 0.0 {
            return None;
        }
        let alpha = rz_old / denom;
        if !alpha.is_finite() {
            return None;
        }
        x.scaled_add(alpha, &p_dir);
        r.scaled_add(-alpha, &ap);
        // Every 32nd iteration, recompute the residual as rhs - A x. `ap` is
        // free to be overwritten here because it is refilled at the top of
        // the next iteration.
        if (iter + 1) % 32 == 0 {
            apply(&x, &mut ap);
            if ap.len() != p {
                return None;
            }
            r.assign(rhs);
            r.scaled_add(-1.0, &ap);
        }
        let r_norm = r.dot(&r).sqrt();
        if r_norm.is_finite() && r_norm <= tol {
            // Converged: no new direction is formed, so no beta is recorded.
            diagnostics.push_iteration(alpha, None, r_norm);
            return x
                .iter()
                .all(|v| v.is_finite())
                .then_some((x, diagnostics.info(iter + 1, true, rhs_norm, r_norm)));
        }
        // z = M^{-1} r for the updated residual.
        Zip::from(&mut z)
            .and(&r)
            .and(&inv_m)
            .par_for_each(|zi, &ri, &im| {
                *zi = ri * im;
            });
        let rz_new = r.dot(&z);
        if !rz_new.is_finite() || rz_new <= 0.0 {
            return None;
        }
        let beta = rz_new / rz_old;
        if !beta.is_finite() {
            return None;
        }
        diagnostics.push_iteration(alpha, Some(beta), r_norm);
        // p_dir = z + beta * p_dir
        Zip::from(&mut p_dir).and(&z).par_for_each(|pi, &zi| {
            *pi = zi + beta * *pi;
        });
        rz_old = rz_new;
    }
    // Iteration cap reached without meeting the tolerance.
    None
}
989
990#[derive(Clone)]
991pub(crate) struct RidgePlanner {
992 cond_estimate: Option<f64>,
993 ridge: f64,
994 attempts: usize,
995 scale: f64,
996}
997
998impl RidgePlanner {
999 pub(crate) fn new(matrix: &Array2<f64>) -> Self {
1000 let scale = max_abs_diag(matrix);
1001 let min_step = scale * 1e-10;
1002 Self {
1011 cond_estimate: None,
1012 ridge: min_step,
1013 attempts: 0,
1014 scale,
1015 }
1016 }
1017
1018 pub(crate) fn ridge(&self) -> f64 {
1019 self.ridge
1020 }
1021
1022 #[inline]
1023 fn estimate_conditionwithridge(&self, matrix: &Array2<f64>, ridge: f64) -> Option<f64> {
1024 let regularized = if ridge > 0.0 {
1025 addridge(matrix, ridge)
1026 } else {
1027 matrix.clone()
1028 };
1029 calculate_condition_number(®ularized)
1030 .ok()
1031 .filter(|c| c.is_finite() && *c > 0.0)
1032 }
1033
1034 pub(crate) fn bumpwith_matrix(&mut self, matrix: &Array2<f64>) {
1035 self.attempts += 1;
1036 let min_step = self.scale * 1e-10;
1037 let base = self.ridge.max(min_step);
1038
1039 let spd_floor = self.scale * 1e-8;
1046 let mut next_ridge = if let Some((lam_min, _lam_max)) = symmetric_extremes(matrix) {
1047 let deficit = (spd_floor - lam_min).max(0.0);
1051 let proposal = (1.5 * deficit).max(base * 1.5).max(min_step);
1052 proposal.min(base * 10.0)
1056 } else {
1057 f64::NAN
1058 };
1059
1060 if !next_ridge.is_finite() {
1064 let cond_now = self.estimate_conditionwithridge(matrix, base);
1065 self.cond_estimate = cond_now;
1066 next_ridge = if let Some(cond) = cond_now {
1067 let ratio = cond / HESSIAN_CONDITION_TARGET;
1068 let mut multiplier = if ratio > 1.0 {
1069 ratio.sqrt().clamp(1.5, 10.0)
1070 } else {
1071 (2.0 + self.attempts as f64).clamp(3.0, 10.0)
1072 };
1073 let mut proposal = base * multiplier;
1074 if let Some(cond_next) = self.estimate_conditionwithridge(matrix, proposal)
1075 && cond_next > cond * 0.9
1076 && ratio > 1.0
1077 {
1078 multiplier = (multiplier * 1.8).clamp(2.0, 10.0);
1079 proposal = base * multiplier;
1080 }
1081 proposal.max(min_step)
1082 } else if self.ridge <= 0.0 {
1083 min_step
1084 } else {
1085 (base * 10.0).max(min_step)
1086 };
1087 }
1088
1089 if !next_ridge.is_finite() || next_ridge <= 0.0 {
1090 next_ridge = self.scale;
1091 }
1092
1093 self.ridge = next_ridge;
1094 }
1095
1096 pub(crate) fn attempts(&self) -> usize {
1097 self.attempts
1098 }
1099}
1100
1101pub fn gaussian_weighted_ridge(
1108 x: ArrayView2<'_, f64>,
1109 y: ArrayView2<'_, f64>,
1110 penalty: ArrayView2<'_, f64>,
1111 weights: ArrayView1<'_, f64>,
1112 ridge_lambda: f64,
1113) -> Result<(Array2<f64>, Array2<f64>), String> {
1114 let n = x.nrows();
1115 let p = x.ncols();
1116 if n == 0 || p == 0 {
1117 return Err("X cannot be empty".to_string());
1118 }
1119 if y.nrows() != n {
1120 return Err(format!(
1121 "X/Y row mismatch: X has {n} rows but Y has {} rows",
1122 y.nrows()
1123 ));
1124 }
1125 if y.ncols() == 0 {
1126 return Err("Y must have at least one column".to_string());
1127 }
1128 if weights.len() != n {
1129 return Err(format!(
1130 "weights length mismatch: expected {n}, got {}",
1131 weights.len()
1132 ));
1133 }
1134 if penalty.nrows() != p || penalty.ncols() != p {
1135 return Err(format!(
1136 "penalty shape mismatch: expected {p}x{p}, got {}x{}",
1137 penalty.nrows(),
1138 penalty.ncols()
1139 ));
1140 }
1141 if !ridge_lambda.is_finite() || ridge_lambda < 0.0 {
1142 return Err(format!(
1143 "ridge_lambda must be finite and non-negative; got {ridge_lambda}"
1144 ));
1145 }
1146 if x.iter()
1147 .chain(y.iter())
1148 .chain(penalty.iter())
1149 .chain(weights.iter())
1150 .any(|value| !value.is_finite())
1151 {
1152 return Err("weighted ridge inputs must be finite".to_string());
1153 }
1154 if weights.iter().any(|value| *value < 0.0) {
1155 return Err("weights must be non-negative likelihood row weights".to_string());
1156 }
1157
1158 let mut wx = x.to_owned();
1159 let mut wy = y.to_owned();
1160 for i in 0..n {
1161 let wi = weights[i];
1162 wx.row_mut(i).iter_mut().for_each(|value| *value *= wi);
1163 wy.row_mut(i).iter_mut().for_each(|value| *value *= wi);
1164 }
1165 let mut system = x.t().dot(&wx);
1166 if ridge_lambda > 0.0 {
1167 system += &(penalty.to_owned() * ridge_lambda);
1168 }
1169 let rhs = x.t().dot(&wy);
1170 let factor =
1171 factorize_symmetricwith_fallback(FaerArrayView::new(&system).as_ref(), Side::Lower)
1172 .map_err(|err| format!("weighted ridge factorization failed: {err}"))?;
1173 let mut coefficients = rhs;
1174 let mut coefficients_view = array2_to_matmut(&mut coefficients);
1175 factor.solve_in_place(coefficients_view.as_mut());
1176 if coefficients.iter().any(|value| !value.is_finite()) {
1177 return Err("weighted ridge solve produced non-finite coefficients".to_string());
1178 }
1179 let fitted = x.dot(&coefficients);
1180 Ok((coefficients, fitted))
1181}
1182
1183pub fn gaussian_weighted_ridge_batch(
1190 x: ArrayView3<'_, f64>,
1191 y: ArrayView3<'_, f64>,
1192 penalty: ArrayView2<'_, f64>,
1193 weights: ArrayView2<'_, f64>,
1194 ridge_lambda: f64,
1195 row_counts: Option<ArrayView1<'_, usize>>,
1196) -> Result<(Array3<f64>, Array3<f64>), String> {
1197 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1198
1199 let (batch, n_max, p) = x.dim();
1200 let (y_batch, y_n_max, d) = y.dim();
1201 if batch == 0 || n_max == 0 || p == 0 {
1202 return Err("batched X must have non-empty K, N, and coefficient dimensions".to_string());
1203 }
1204 if y_batch != batch || y_n_max != n_max {
1205 return Err(format!(
1206 "batched X/Y shape mismatch: X is ({batch}, {n_max}, {p}) but Y is ({y_batch}, {y_n_max}, {d})"
1207 ));
1208 }
1209 if d == 0 {
1210 return Err("batched Y must have at least one output column".to_string());
1211 }
1212 if weights.nrows() != batch || weights.ncols() != n_max {
1213 return Err(format!(
1214 "batched weights shape mismatch: expected ({batch}, {n_max}), got ({}, {})",
1215 weights.nrows(),
1216 weights.ncols()
1217 ));
1218 }
1219 if penalty.nrows() != p || penalty.ncols() != p {
1220 return Err(format!(
1221 "penalty shape mismatch: expected {p}x{p}, got {}x{}",
1222 penalty.nrows(),
1223 penalty.ncols()
1224 ));
1225 }
1226 if !ridge_lambda.is_finite() || ridge_lambda < 0.0 {
1227 return Err(format!(
1228 "ridge_lambda must be finite and non-negative; got {ridge_lambda}"
1229 ));
1230 }
1231 if x.iter()
1232 .chain(y.iter())
1233 .chain(penalty.iter())
1234 .chain(weights.iter())
1235 .any(|value| !value.is_finite())
1236 {
1237 return Err("batched weighted ridge inputs must be finite".to_string());
1238 }
1239 if weights.iter().any(|value| *value < 0.0) {
1240 return Err("batched weights must be non-negative likelihood row weights".to_string());
1241 }
1242
1243 let active_rows: Vec<usize> = match row_counts {
1244 Some(counts) => {
1245 if counts.len() != batch {
1246 return Err(format!(
1247 "row_counts length mismatch: expected {batch}, got {}",
1248 counts.len()
1249 ));
1250 }
1251 counts.to_vec()
1252 }
1253 None => vec![n_max; batch],
1254 };
1255 for (b, &n_rows) in active_rows.iter().enumerate() {
1256 if n_rows > n_max {
1257 return Err(format!(
1258 "row_counts[{b}]={n_rows} exceeds padded row count {n_max}"
1259 ));
1260 }
1261 }
1262
1263 let results: Vec<Result<(usize, Array2<f64>, Array2<f64>), String>> = (0..batch)
1264 .into_par_iter()
1265 .map(|b| {
1266 let n_rows = active_rows[b];
1267 if n_rows == 0 {
1268 return Ok((
1269 b,
1270 Array2::<f64>::zeros((p, d)),
1271 Array2::<f64>::zeros((0, d)),
1272 ));
1273 }
1274 gaussian_weighted_ridge(
1275 x.slice(s![b, 0..n_rows, ..]),
1276 y.slice(s![b, 0..n_rows, ..]),
1277 penalty,
1278 weights.slice(s![b, 0..n_rows]),
1279 ridge_lambda,
1280 )
1281 .map(|(coefficients, fitted)| (b, coefficients, fitted))
1282 .map_err(|err| format!("batched weighted ridge fit {b} failed: {err}"))
1283 })
1284 .collect();
1285
1286 let mut coefficients = Array3::<f64>::zeros((batch, p, d));
1287 let mut fitted = Array3::<f64>::zeros((batch, n_max, d));
1288 for result in results {
1289 let (b, fit_coefficients, fit_fitted) = result?;
1290 coefficients
1291 .slice_mut(s![b, .., ..])
1292 .assign(&fit_coefficients);
1293 let n_rows = fit_fitted.nrows();
1294 if n_rows > 0 {
1295 fitted.slice_mut(s![b, 0..n_rows, ..]).assign(&fit_fitted);
1296 }
1297 }
1298 Ok((coefficients, fitted))
1299}
1300
1301pub fn block_penalty_rank_and_pinv(
1305 penalty: &Array2<f64>,
1306) -> Result<(usize, Array2<f64>), EstimationError> {
1307 let (eigs, vecs) = penalty.to_owned().eigh(Side::Lower).map_err(|_| {
1308 EstimationError::ModelIsIllConditioned {
1309 condition_number: f64::INFINITY,
1310 }
1311 })?;
1312 let max_abs = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
1313 let tol = (1.0e-10 * max_abs).max(1.0e-14);
1314 let mut rank = 0_usize;
1315 let mut scaled = Array2::<f64>::zeros(vecs.dim());
1316 for col in 0..eigs.len() {
1317 if eigs[col] > tol {
1318 rank += 1;
1319 for row in 0..vecs.nrows() {
1320 scaled[[row, col]] = vecs[[row, col]] / eigs[col];
1321 }
1322 }
1323 }
1324 Ok((rank, scaled.dot(&vecs.t())))
1325}
1326
1327pub fn invert_spd_with_ridge(
1330 matrix: &Array2<f64>,
1331 ridge_rel: f64,
1332) -> Result<Array2<f64>, EstimationError> {
1333 let n = matrix.nrows();
1334 let eye = Array2::<f64>::eye(n);
1335 let scale = (0..n).map(|i| matrix[[i, i]].abs()).fold(1.0_f64, f64::max);
1336 let ridges = [0.0, ridge_rel, 1.0e-10, 1.0e-8, 1.0e-6, 1.0e-4];
1337 for rel in ridges {
1338 let mut candidate = matrix.clone();
1339 if rel > 0.0 {
1340 for i in 0..n {
1341 candidate[[i, i]] += rel * scale;
1342 }
1343 }
1344 if let Ok(chol) = candidate.cholesky(Side::Lower) {
1345 return Ok(chol.solve_mat(&eye));
1346 }
1347 }
1348 Err(EstimationError::ModelIsIllConditioned {
1349 condition_number: f64::INFINITY,
1350 })
1351}
1352
1353pub fn solve_symmetric_vector_with_floor(
1357 matrix: &Array2<f64>,
1358 rhs: &Array1<f64>,
1359 ridge_rel: f64,
1360) -> Result<Array1<f64>, EstimationError> {
1361 let n = matrix.nrows();
1362 let mut sym = matrix.clone();
1363 symmetrize_in_place(&mut sym);
1364 let (eigs, vecs) =
1365 sym.eigh(Side::Lower)
1366 .map_err(|_| EstimationError::ModelIsIllConditioned {
1367 condition_number: f64::INFINITY,
1368 })?;
1369 let max_eig = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
1370 let floor = (ridge_rel * max_eig.max(1.0)).max(1.0e-12);
1371 let projected = vecs.t().dot(rhs);
1372 let mut scaled = Array1::<f64>::zeros(n);
1373 for i in 0..n {
1374 let denom = if eigs[i].abs() >= floor {
1375 eigs[i]
1376 } else if eigs[i].is_sign_negative() {
1377 -floor
1378 } else {
1379 floor
1380 };
1381 scaled[i] = projected[i] / denom;
1382 }
1383 let out = vecs.dot(&scaled);
1384 if out.iter().all(|value| value.is_finite()) {
1385 Ok(out)
1386 } else {
1387 Err(EstimationError::ModelIsIllConditioned {
1388 condition_number: f64::INFINITY,
1389 })
1390 }
1391}
1392
1393pub fn solve_dense_block_system(
1397 hessian: &Array2<f64>,
1398 rhs: &Array1<f64>,
1399 context: &str,
1400) -> Result<Array1<f64>, String> {
1401 let mut rhs2 = Array2::<f64>::zeros((rhs.len(), 1));
1402 for i in 0..rhs.len() {
1403 rhs2[[i, 0]] = rhs[i];
1404 }
1405 let factor =
1406 factorize_symmetricwith_fallback(FaerArrayView::new(hessian).as_ref(), Side::Lower)
1407 .map_err(|err| format!("{context} factorization failed: {err}"))?;
1408 {
1409 let mut rhs_view = array2_to_matmut(&mut rhs2);
1410 factor.solve_in_place(rhs_view.as_mut());
1411 }
1412 let mut out = Array1::<f64>::zeros(rhs.len());
1413 for i in 0..rhs.len() {
1414 out[i] = rhs2[[i, 0]];
1415 }
1416 if out.iter().any(|v| !v.is_finite()) {
1417 return Err(format!("{context} solve produced non-finite coefficients"));
1418 }
1419 Ok(out)
1420}
1421
#[cfg(test)]
mod ridge_tests {
    use super::{gaussian_weighted_ridge, gaussian_weighted_ridge_batch};
    use ndarray::{Array2, Array3, ArrayView2, array, s};

    /// Asserts that two matrices have the same shape and agree element-wise
    /// to within `tol`.
    fn assert_close(lhs: ArrayView2<'_, f64>, rhs: ArrayView2<'_, f64>, tol: f64) {
        assert_eq!(lhs.dim(), rhs.dim());
        for ((i, j), &actual) in lhs.indexed_iter() {
            let expected = rhs[[i, j]];
            let diff = (actual - expected).abs();
            assert!(
                diff <= tol,
                "matrix mismatch at ({i}, {j}): lhs={}, rhs={}, diff={diff}",
                actual,
                expected
            );
        }
    }

    /// The batched fit must reproduce the single-problem fit on each entry's
    /// active rows and leave padded fitted rows at zero.
    #[test]
    fn weighted_ridge_batch_matches_single_fit_on_active_rows() {
        const LAMBDA: f64 = 0.25;
        const TOL: f64 = 1.0e-10;

        // Second batch entry has only two active rows; its third row is
        // padding that must not influence the fit.
        let x = Array3::from_shape_vec(
            (2, 3, 2),
            vec![1.0, 0.0, 1.0, 1.0, 0.5, 1.0, 2.0, 1.0, 0.0, 1.0, 9.0, 9.0],
        )
        .unwrap();
        let y = Array3::from_shape_vec((2, 3, 1), vec![1.0, 2.0, 1.5, 2.5, -0.5, 99.0]).unwrap();
        let weights = array![[1.0, 0.5, 2.0], [1.0, 3.0, 0.0]];
        let penalty = Array2::<f64>::eye(2);
        let row_counts = array![3_usize, 2_usize];

        let (coefficients, fitted) = gaussian_weighted_ridge_batch(
            x.view(),
            y.view(),
            penalty.view(),
            weights.view(),
            LAMBDA,
            Some(row_counts.view()),
        )
        .unwrap();

        for (b, &n) in row_counts.iter().enumerate() {
            let (expected_coefficients, expected_fitted) = gaussian_weighted_ridge(
                x.slice(s![b, 0..n, ..]),
                y.slice(s![b, 0..n, ..]),
                penalty.view(),
                weights.slice(s![b, 0..n]),
                LAMBDA,
            )
            .unwrap();

            assert_close(
                coefficients.slice(s![b, .., ..]),
                expected_coefficients.view(),
                TOL,
            );
            assert_close(
                fitted.slice(s![b, 0..n, ..]),
                expected_fitted.view(),
                TOL,
            );
        }

        // Padded row of the second entry was never written.
        assert_eq!(fitted[[1, 2, 0]], 0.0);
    }
}
1486
#[cfg(test)]
mod tests {
    use super::{
        boundary_hit_step_fraction, solve_spd_pcg, solve_spd_pcg_with_info,
        solve_spd_pcg_with_info_into,
    };
    use ndarray::{Array1, array};

    /// Shared 2x2 fixture: matrix `h`, right-hand side `b`, and the diagonal
    /// of `h` used as the preconditioner argument `m`.
    fn two_by_two_system() -> (ndarray::Array2<f64>, Array1<f64>, Array1<f64>) {
        (
            array![[4.0, 1.0], [1.0, 3.0]],
            array![1.0, 2.0],
            array![4.0, 3.0],
        )
    }

    #[test]
    fn boundary_hit_step_fraction_ignores_near_tangential_direction() {
        assert_eq!(boundary_hit_step_fraction(1.0, -1e-16, 1.0), None);
    }

    #[test]
    fn boundary_hit_step_fraction_returns_first_finite_hit() {
        assert_eq!(boundary_hit_step_fraction(0.25, -0.5, 1.0), Some(0.5));
    }

    #[test]
    fn boundary_hit_step_fraction_rejects_non_finite_candidate() {
        assert_eq!(
            boundary_hit_step_fraction(1.0, f64::NEG_INFINITY, 1.0),
            None
        );
    }

    #[test]
    fn solve_spd_pcg_matches_reference_solution() {
        let (h, b, m) = two_by_two_system();
        let x = solve_spd_pcg(|v| h.dot(v), &b, &m, 1e-10, 20).expect("pcg solve");
        // Exact solution of the fixture is (1/11, 7/11).
        for (idx, &expected) in [0.0909090909, 0.6363636363].iter().enumerate() {
            assert!((x[idx] - expected).abs() < 1e-8);
        }
    }

    #[test]
    fn solve_spd_pcg_rejects_zero_iteration_budget() {
        let (h, b, m) = two_by_two_system();
        assert!(solve_spd_pcg_with_info(|v| h.dot(v), &b, &m, 1e-10, 0).is_none());
        assert!(solve_spd_pcg(|v| h.dot(v), &b, &m, 1e-10, 0).is_none());
    }

    /// Compares the matrix-free PCG solve against a dense factorization of the
    /// same 4x4 system and checks the reported solver diagnostics.
    #[test]
    fn matrix_free_qp_beta_matches_dense_reference_with_diagnostics() {
        let h = array![
            [12.0, 2.0, 0.5, 0.0],
            [2.0, 9.0, 1.25, 0.25],
            [0.5, 1.25, 7.0, 1.5],
            [0.0, 0.25, 1.5, 5.0],
        ];
        let rhs = array![1.0, -0.5, 2.0, 0.75];
        let precond = h.diag().to_owned();
        let iteration_budget = 4 * rhs.len();

        // Dense reference: solve in place on a copy of the right-hand side.
        let mut dense = rhs.clone();
        {
            let factor = super::StableSolver::new("synthetic dense reference")
                .factorize(&h)
                .expect("dense SPD reference");
            let mut dense_view = crate::faer_ndarray::array1_to_col_matmut(&mut dense);
            factor.solve_in_place(dense_view.as_mut());
        }

        let (pcg, info) = solve_spd_pcg_with_info_into(
            |v, out| out.assign(&h.dot(v)),
            &rhs,
            &precond,
            1e-12,
            iteration_budget,
        )
        .expect("matrix-free pcg");

        assert!(info.converged);
        assert!(info.iterations <= iteration_budget);
        assert!(info.final_residual_norm < info.initial_residual_norm);
        assert!(info.residual_reduction < 1e-10);
        assert!(info.condition_estimate.is_some());
        dense.iter().zip(pcg.iter()).for_each(|(reference, actual)| {
            assert!(
                (reference - actual).abs() < 1e-10,
                "dense={reference} pcg={actual}"
            );
        });
    }

    #[test]
    fn solve_spd_pcg_with_info_into_rejects_zero_iteration_budget() {
        let (h, b, m) = two_by_two_system();
        let outcome =
            solve_spd_pcg_with_info_into(|v, out| out.assign(&h.dot(v)), &b, &m, 1e-10, 0);
        assert!(outcome.is_none());
    }
}