use crate::error::OptimizeError;
use scirs2_core::ndarray::{Array1, Array2};
/// Problem size (gradient length, 2^12) from which the GPU Newton-CG path is
/// worth attempting. Presumably passed by callers as the `threshold` argument
/// of `try_gpu_newton_cg_solve` — TODO confirm.
pub(crate) const GPU_NEWTON_THRESHOLD: usize = 1 << 12;
/// Outcome of an attempted GPU Newton-CG solve.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum GpuNewtonResult {
    /// The GPU path ran and produced this solution vector (the CG approximation
    /// of the Newton step).
    Done(Array1<f64>),
    /// The GPU path was skipped or failed; the caller is expected to run the
    /// CPU solver instead.
    FallbackToCpu,
}
/// Entry point for the GPU Newton-CG solve.
///
/// Returns [`GpuNewtonResult::FallbackToCpu`] right away when `use_gpu` is
/// false or the problem size `g.len()` is below `threshold`. Otherwise it hands
/// the work to `do_gpu_newton_dispatch` and returns that result unchanged.
pub(crate) fn try_gpu_newton_cg_solve(
    g: &Array1<f64>,
    hess: &Array2<f64>,
    tol: f64,
    use_gpu: bool,
    threshold: usize,
) -> GpuNewtonResult {
    let big_enough = g.len() >= threshold;
    if use_gpu && big_enough {
        do_gpu_newton_dispatch(g, hess, tol)
    } else {
        GpuNewtonResult::FallbackToCpu
    }
}
/// Routes a Newton-CG solve to the GPU implementation when one can be used.
///
/// Yields `GpuNewtonResult::FallbackToCpu` in any of these cases:
/// - the crate is built without the `gpu` feature;
/// - `is_gpu_available()` reports no device;
/// - `global_context()` returns `None`;
/// - `gpu_newton_cg_inner` returns an error.
fn do_gpu_newton_dispatch(g: &Array1<f64>, hess: &Array2<f64>, tol: f64) -> GpuNewtonResult {
    #[cfg(feature = "gpu")]
    {
        use scirs2_core::array_protocol::gpu_ndarray::{global_context, is_gpu_available};
        // Every GPU-side failure is mapped to "use the CPU path" rather than being
        // surfaced to the caller.
        if is_gpu_available() {
            global_context()
                .and_then(|ctx| gpu_newton_cg_inner(g, hess, tol, &ctx).ok())
                .unwrap_or(GpuNewtonResult::FallbackToCpu)
        } else {
            GpuNewtonResult::FallbackToCpu
        }
    }
    #[cfg(not(feature = "gpu"))]
    {
        // Without the feature the inputs are unused; bind them to silence warnings.
        let _ = (g, hess, tol);
        GpuNewtonResult::FallbackToCpu
    }
}
/// Approximately solves `H x = -g` with conjugate gradients. The
/// Hessian-vector products run on the GPU.
///
/// The Hessian is converted to `f32` and uploaded once. Each CG iteration:
/// - uploads the current search direction as an `(n, 1)` column;
/// - multiplies it by the Hessian on the device;
/// - downloads the product a single time.
///
/// All inner products and vector updates happen on the host in `f64`, so
/// `alpha` and `beta` are formed from consistently rounded quantities. CG stops
/// when `‖r‖ < min(0.1, ‖g‖ * tol)`, when the curvature `pᵀHp <= 1e-10`, or
/// after `2 n` iterations.
///
/// # Returns
/// * `Ok(GpuNewtonResult::Done(x))` — the (possibly truncated) CG solution. `x`
///   is the zero vector when `‖g‖² < 1e-10`.
/// * `Ok(GpuNewtonResult::FallbackToCpu)` — the first search direction already
///   has non-positive curvature. No CG step could be taken, and `x` would be a
///   useless zero direction, so the CPU solver should decide instead.
///
/// # Errors
/// Returns `OptimizeError::ComputationError` in these cases:
/// - `hess` is not `n × n` (with `n = g.len()`);
/// - any GPU upload, matmul or download fails;
/// - the GPU product contains non-finite values (for example, entries that
///   overflowed the `f32` downcast).
#[cfg(feature = "gpu")]
fn gpu_newton_cg_inner(
    g: &Array1<f64>,
    hess: &Array2<f64>,
    tol: f64,
    ctx: &std::sync::Arc<scirs2_core::gpu::backends::WebGPUContext>,
) -> Result<GpuNewtonResult, OptimizeError> {
    use scirs2_core::array_protocol::gpu_ndarray::GpuNdarray;

    let n = g.len();
    // A mis-shaped Hessian would otherwise be silently reinterpreted as n × n.
    if hess.dim() != (n, n) {
        return Err(OptimizeError::ComputationError(format!(
            "Hessian shape mismatch: got {:?}, expected ({n}, {n})",
            hess.dim()
        )));
    }

    let mut x: Array1<f64> = Array1::zeros(n);
    // Gradient is already (numerically) zero: the Newton step is zero. Checked
    // before uploading the Hessian so no GPU work is wasted.
    if g.dot(g) < 1e-10 {
        return Ok(GpuNewtonResult::Done(x));
    }

    // Upload the Hessian once; every Hessian-vector product reuses it.
    let h_gpu: GpuNdarray<f32> = {
        let data_f32: Vec<f32> = hess.iter().map(|&v| v as f32).collect();
        GpuNdarray::from_ndarray_data(&data_f32, vec![n, n], std::sync::Arc::clone(ctx))
            .map_err(|e| OptimizeError::ComputationError(format!("GPU hess upload: {e}")))?
    };

    // Computes H·p with one upload (p as an (n, 1) column) and one download.
    let hessian_vector = |p: &Array1<f64>| -> Result<Array1<f64>, OptimizeError> {
        let p_f32: Vec<f32> = p.iter().map(|&v| v as f32).collect();
        let p_col: GpuNdarray<f32> =
            GpuNdarray::from_ndarray_data(&p_f32, vec![n, 1], std::sync::Arc::clone(ctx))
                .map_err(|e| OptimizeError::ComputationError(format!("GPU col upload: {e}")))?;
        let hp_f32 = h_gpu
            .matmul(&p_col)
            .map_err(|e| OptimizeError::ComputationError(format!("GPU matmul: {e}")))?
            .to_vec()
            .map_err(|e| OptimizeError::ComputationError(format!("GPU download hp: {e}")))?;
        if hp_f32.len() != n {
            return Err(OptimizeError::ComputationError(format!(
                "GPU matmul output length mismatch: got {}, expected {n}",
                hp_f32.len()
            )));
        }
        // f32 overflow (e.g. huge Hessian entries) would poison the CG iterates.
        if hp_f32.iter().any(|v| !v.is_finite()) {
            return Err(OptimizeError::ComputationError(
                "GPU matmul produced non-finite values".to_string(),
            ));
        }
        Ok(Array1::from_vec(hp_f32.into_iter().map(f64::from).collect()))
    };

    // Standard CG from x0 = 0: residual r = -g - H x0 = -g, first direction p = r.
    let mut r: Array1<f64> = -g;
    let mut p = r.clone();
    let mut rr = r.dot(&r);
    let cg_tol = f64::min(0.1, rr.sqrt() * tol);
    let max_cg_iters = 2 * n;

    for iter in 0..max_cg_iters {
        let hp = hessian_vector(&p)?;
        let php = p.dot(&hp);
        if php <= 1e-10 {
            // Non-positive or negligible curvature along p: stop CG. On the first
            // iteration x is still zero, so let the CPU solver handle this case.
            if iter == 0 {
                return Ok(GpuNewtonResult::FallbackToCpu);
            }
            break;
        }
        let alpha = rr / php;
        x.scaled_add(alpha, &p);
        r.scaled_add(-alpha, &hp);
        let rr_new = r.dot(&r);
        if rr_new.sqrt() < cg_tol {
            break;
        }
        let beta = rr_new / rr;
        // p <- r + beta * p, done in place to avoid allocating per iteration.
        p *= beta;
        p += &r;
        rr = rr_new;
    }
    Ok(GpuNewtonResult::Done(x))
}