use crate::gpu::arrow_schur::{ArrowSchurGpuFailure, solve_arrow_newton_step};
use crate::solver::arrow_schur::{
ArrowSchurError, ArrowSchurSystem, ArrowSolveOptions, ArrowSolverMode,
};
use ndarray::Array1;
pub fn solve_arrow_newton_step_gpu(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
match solve_arrow_newton_step(sys, ridge_t, ridge_beta) {
Ok(solution) => Ok((solution.delta_t, solution.delta_beta)),
Err(ArrowSchurGpuFailure::Unavailable) => {
sys.solve_with_options(ridge_t, ridge_beta, &ArrowSolveOptions::automatic(sys.k))
.map(|(dt, db, _diag)| (dt, db))
}
Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem { .. }) => {
let mut opts = ArrowSolveOptions::automatic(sys.k);
opts.mode = ArrowSolverMode::InexactPCG;
sys.solve_with_options(ridge_t, ridge_beta, &opts)
.map(|(dt, db, _diag)| (dt, db))
}
Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
Err(ArrowSchurError::PerRowFactorFailed {
row,
reason: format!("GPU Cholesky factor failed; suggested ridge bump {bump:.3e}"),
})
}
Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
Err(ArrowSchurError::SchurFactorFailed { reason })
}
}
}