use crate::error::OptimizeError;
use scirs2_core::ndarray::Array1;
/// Vector-length cutoff (4096 elements) for the GPU conjugate-gradient path.
///
/// NOTE(review): this constant is not referenced in the visible part of the
/// file; presumably callers pass it as the `threshold` argument of
/// `try_gpu_cg_update` — confirm at the call sites.
pub(crate) const GPU_CG_THRESHOLD: usize = 4096;
/// Outcome of an attempted GPU conjugate-gradient direction update.
pub(crate) enum GpuCgResult {
/// The GPU path completed. As constructed by `gpu_cg_update_inner`, the
/// fields are, in order:
/// 1. the squared norm of the new gradient (`g_new · g_new`),
/// 2. the `beta` coefficient (`|g_new|^2 / |g|^2`, or `0.0` when `|g|^2 < 1e-10`),
/// 3. the new search direction `-g_new + beta * p`.
Done(f64, f64, Array1<f64>),
/// The update was not computed on the GPU. Returned when the GPU path is
/// disabled, the problem is below the size threshold, no GPU/context is
/// available, the `gpu` feature is off, or a GPU operation failed.
/// Presumably the caller then performs the update on the CPU.
FallbackToCpu,
}
/// Attempts the conjugate-gradient direction update on the GPU.
///
/// # Arguments
/// * `g` - previous gradient; its length decides whether the GPU is used.
/// * `g_new` - new gradient.
/// * `p` - previous search direction.
/// * `use_gpu` - when `false`, the GPU is never attempted.
/// * `threshold` - minimum length of `g` for which the GPU is attempted.
///
/// # Returns
/// [`GpuCgResult::FallbackToCpu`] when `use_gpu` is `false` or
/// `g.len() < threshold`; otherwise whatever `do_gpu_cg_dispatch` yields.
pub(crate) fn try_gpu_cg_update(
    g: &Array1<f64>,
    g_new: &Array1<f64>,
    p: &Array1<f64>,
    use_gpu: bool,
    threshold: usize,
) -> GpuCgResult {
    let large_enough = g.len() >= threshold;
    if use_gpu && large_enough {
        do_gpu_cg_dispatch(g, g_new, p)
    } else {
        GpuCgResult::FallbackToCpu
    }
}
/// Runs the GPU update when the `gpu` feature is compiled in and a device
/// context can be obtained; every other situation yields
/// [`GpuCgResult::FallbackToCpu`].
///
/// With the `gpu` feature enabled, the fallback is returned when
/// `is_gpu_available()` is `false`, when `global_context()` is `None`, or when
/// `gpu_cg_update_inner` returns an error (the error itself is discarded).
/// Without the feature, the arguments are ignored.
fn do_gpu_cg_dispatch(g: &Array1<f64>, g_new: &Array1<f64>, p: &Array1<f64>) -> GpuCgResult {
    #[cfg(feature = "gpu")]
    {
        use scirs2_core::array_protocol::gpu_ndarray::{global_context, is_gpu_available};

        // Only ask for the shared context once availability is confirmed.
        let context = if is_gpu_available() {
            global_context()
        } else {
            None
        };
        context
            .and_then(|ctx| gpu_cg_update_inner(g, g_new, p, &ctx).ok())
            .unwrap_or(GpuCgResult::FallbackToCpu)
    }
    #[cfg(not(feature = "gpu"))]
    {
        // Silence unused-parameter warnings in builds without GPU support.
        let _ = (g, g_new, p);
        GpuCgResult::FallbackToCpu
    }
}
/// Computes the Fletcher–Reeves conjugate-gradient direction update on the GPU.
///
/// The three vectors are narrowed from `f64` to `f32`, uploaded with shape
/// `[g.len()]`, and combined as
///
/// ```text
/// beta  = 0                      if |g|^2 < 1e-10
///       = |g_new|^2 / |g|^2      otherwise
/// new_p = -g_new + beta * p
/// ```
///
/// # Arguments
/// * `g` - previous gradient.
/// * `g_new` - new gradient; must have the same length as `g`.
/// * `p` - previous search direction; must have the same length as `g`.
/// * `ctx` - shared GPU context used for every upload.
///
/// # Returns
/// `Ok(GpuCgResult::Done(|g_new|^2, beta, new_p))` on success.
///
/// # Errors
/// Returns [`OptimizeError::ComputationError`] when
/// * the input lengths disagree,
/// * any GPU upload, dot product, scaling, addition or download fails,
/// * the downloaded direction does not have `g.len()` elements, or
/// * the single-precision computation produced a non-finite value (for
///   example because an `f64` input overflowed when narrowed to `f32`).
///
/// The dispatcher maps every error to a CPU fallback, so rejecting a dubious
/// GPU result lets the update be redone in full `f64` precision.
#[cfg(feature = "gpu")]
fn gpu_cg_update_inner(
    g: &Array1<f64>,
    g_new: &Array1<f64>,
    p: &Array1<f64>,
    ctx: &std::sync::Arc<scirs2_core::gpu::backends::WebGPUContext>,
) -> Result<GpuCgResult, OptimizeError> {
    use scirs2_core::array_protocol::gpu_ndarray::GpuNdarray;

    let n = g.len();

    // Every buffer is uploaded with shape `[n]`, so all three vectors must
    // really contain `n` elements; reject mismatches before touching the device.
    if g_new.len() != n || p.len() != n {
        return Err(OptimizeError::ComputationError(format!(
            "GPU CG input length mismatch: g has {n}, g_new has {}, p has {}",
            g_new.len(),
            p.len()
        )));
    }

    // Narrow to f32 (values outside the f32 range become +/-inf) and upload.
    let upload = |arr: &Array1<f64>| -> Result<GpuNdarray<f32>, OptimizeError> {
        let data_f32: Vec<f32> = arr.iter().map(|&v| v as f32).collect();
        GpuNdarray::from_ndarray_data(&data_f32, vec![n], std::sync::Arc::clone(ctx))
            .map_err(|e| OptimizeError::ComputationError(format!("GPU upload: {e}")))
    };
    // Dot product on the device, widened back to f64 for the host-side math.
    let gpu_dot = |a: &GpuNdarray<f32>, b: &GpuNdarray<f32>| -> Result<f64, OptimizeError> {
        a.dot_gpu(b)
            .map(f64::from)
            .map_err(|e| OptimizeError::ComputationError(format!("GPU dot: {e}")))
    };

    let g_gpu = upload(g)?;
    let g_new_gpu = upload(g_new)?;
    let p_gpu = upload(p)?;

    let g_new_norm_sq = gpu_dot(&g_new_gpu, &g_new_gpu)?;
    let g_norm_sq = gpu_dot(&g_gpu, &g_gpu)?;

    // Guard against dividing by a (numerically) zero previous gradient norm.
    let beta = if g_norm_sq < 1e-10 {
        0.0_f64
    } else {
        g_new_norm_sq / g_norm_sq
    };

    // An f32 overflow or NaN in either reduction would poison the whole
    // direction; bail out so the caller can redo the update on the CPU.
    if !g_new_norm_sq.is_finite() || !beta.is_finite() {
        return Err(OptimizeError::ComputationError(format!(
            "GPU CG produced non-finite scalars: |g_new|^2 = {g_new_norm_sq}, beta = {beta}"
        )));
    }

    let scaled_p = p_gpu
        .multiply_by_scalar_f32(beta as f32)
        .map_err(|e| OptimizeError::ComputationError(format!("GPU scale p: {e}")))?;
    let neg_g_new = g_new_gpu
        .multiply_by_scalar_f32(-1.0_f32)
        .map_err(|e| OptimizeError::ComputationError(format!("GPU negate g_new: {e}")))?;
    let new_p_gpu = neg_g_new
        .add(&scaled_p)
        .map_err(|e| OptimizeError::ComputationError(format!("GPU direction add: {e}")))?;

    let new_p_host = new_p_gpu
        .to_vec()
        .map_err(|e| OptimizeError::ComputationError(format!("GPU download: {e}")))?;
    if new_p_host.len() != n {
        return Err(OptimizeError::ComputationError(format!(
            "GPU CG result length mismatch: got {}, expected {n}",
            new_p_host.len()
        )));
    }
    // `beta as f32` or `beta * p` may still overflow in single precision.
    if new_p_host.iter().any(|v| !v.is_finite()) {
        return Err(OptimizeError::ComputationError(
            "GPU CG produced a non-finite search direction".to_string(),
        ));
    }

    let new_p = Array1::from_iter(new_p_host.into_iter().map(f64::from));
    Ok(GpuCgResult::Done(g_new_norm_sq, beta, new_p))
}