use crate::error::OptimizeError;
use scirs2_core::ndarray::Array1;
/// Default minimum problem dimension (gradient length) at which the GPU two-loop
/// recursion is attempted; written as a power of two (2^12 = 4096).
pub(crate) const GPU_LBFGS_THRESHOLD: usize = 1 << 12;
/// Outcome of attempting the L-BFGS two-loop recursion on the GPU.
///
/// Derives `Debug` so dispatch outcomes can be logged and inspected in tests.
#[derive(Debug)]
pub(crate) enum GpuDispatchResult {
    /// The GPU path produced a result: the two-loop product, already negated,
    /// converted back to `f64`.
    Done(Array1<f64>),
    /// The GPU path was skipped or failed; the caller should compute on the CPU.
    FallbackToCpu,
}
/// Entry point for the GPU-accelerated L-BFGS two-loop recursion.
///
/// The GPU is only attempted when `use_gpu` is set *and* the problem dimension
/// (`gradient.len()`) is at least `threshold`. In every other case, or when the
/// GPU attempt itself declines, [`GpuDispatchResult::FallbackToCpu`] is returned.
pub(crate) fn try_gpu_two_loop_recursion(
    s_vectors: &[Array1<f64>],
    y_vectors: &[Array1<f64>],
    rho_values: &[f64],
    gradient: &Array1<f64>,
    use_gpu: bool,
    threshold: usize,
) -> GpuDispatchResult {
    // Small problems do not amortize the host<->device transfer cost.
    let large_enough = gradient.len() >= threshold;
    if use_gpu && large_enough {
        do_gpu_dispatch(s_vectors, y_vectors, rho_values, gradient)
    } else {
        GpuDispatchResult::FallbackToCpu
    }
}
/// Tries to run the two-loop recursion on the GPU, reporting
/// [`GpuDispatchResult::FallbackToCpu`] whenever the GPU path cannot be used.
///
/// With the `gpu` feature enabled, this falls back in three cases:
/// - no GPU is reported as available;
/// - no global GPU context exists;
/// - the GPU computation returns an error (the error itself is discarded).
///
/// Without the feature it always falls back.
fn do_gpu_dispatch(
    s_vectors: &[Array1<f64>],
    y_vectors: &[Array1<f64>],
    rho_values: &[f64],
    gradient: &Array1<f64>,
) -> GpuDispatchResult {
    #[cfg(feature = "gpu")]
    {
        use scirs2_core::array_protocol::gpu_ndarray::{global_context, is_gpu_available};

        // Probe availability first so the global context is only requested on capable hosts.
        if !is_gpu_available() {
            return GpuDispatchResult::FallbackToCpu;
        }
        let Some(context) = global_context() else {
            return GpuDispatchResult::FallbackToCpu;
        };
        gpu_two_loop_inner(s_vectors, y_vectors, rho_values, gradient, &context)
            .map_or(GpuDispatchResult::FallbackToCpu, GpuDispatchResult::Done)
    }
    #[cfg(not(feature = "gpu"))]
    {
        // Touch every parameter so the feature-less build has no unused-variable warnings.
        let _ = (s_vectors, y_vectors, rho_values, gradient);
        GpuDispatchResult::FallbackToCpu
    }
}
/// Runs the L-BFGS two-loop recursion on the GPU and returns the negated result.
///
/// Precision: every vector is narrowed to `f32` for the device. Dot products and the
/// alpha/beta/gamma scalars are computed in `f64` on the host (and narrowed again when
/// used as device scale factors). The downloaded result is widened to `f64` and negated.
///
/// Assumes `s_vectors[i]`, `y_vectors[i]` and `rho_values[i]` describe the same history
/// pair, ordered oldest -> newest. The last pair is used for the initial `gamma` scaling
/// (TODO: confirm ordering against the CPU implementation).
///
/// # Errors
///
/// Returns [`OptimizeError::ComputationError`] (which the dispatcher turns into a CPU
/// fallback) when:
/// - the three history slices have different lengths;
/// - any history vector's length differs from the gradient's;
/// - any GPU upload/arithmetic/download fails;
/// - the downloaded result has the wrong length;
/// - the result contains non-finite values. This can happen when a value exceeds the
///   `f32` range and overflows to infinity.
#[cfg(feature = "gpu")]
fn gpu_two_loop_inner(
    s_vectors: &[Array1<f64>],
    y_vectors: &[Array1<f64>],
    rho_values: &[f64],
    gradient: &Array1<f64>,
    ctx: &std::sync::Arc<scirs2_core::gpu::backends::WebGPUContext>,
) -> Result<Array1<f64>, OptimizeError> {
    use scirs2_core::array_protocol::gpu_ndarray::GpuNdarray;

    /// Builds a `map_err` adapter that prefixes a GPU error with the failing stage.
    fn gpu_err<E: std::fmt::Display>(stage: &'static str) -> impl FnOnce(E) -> OptimizeError {
        move |e| OptimizeError::ComputationError(format!("{stage}: {e}"))
    }

    let n = gradient.len();
    let m = s_vectors.len();

    // Validate up front. Mismatched history would otherwise panic on indexing below, or
    // feed the device vectors of the wrong length.
    if y_vectors.len() != m || rho_values.len() != m {
        return Err(OptimizeError::ComputationError(format!(
            "GPU two-loop: history length mismatch (s: {m}, y: {}, rho: {})",
            y_vectors.len(),
            rho_values.len()
        )));
    }
    if let Some(bad_len) = s_vectors
        .iter()
        .chain(y_vectors)
        .map(|v| v.len())
        .find(|&len| len != n)
    {
        return Err(OptimizeError::ComputationError(format!(
            "GPU two-loop: history vector length {bad_len} does not match gradient length {n}"
        )));
    }

    // `iter()` walks elements in logical order, so non-contiguous views upload correctly
    // and no intermediate f64 copy is needed.
    let upload = |v: &Array1<f64>| -> Result<GpuNdarray<f32>, OptimizeError> {
        let data_f32: Vec<f32> = v.iter().map(|&x| x as f32).collect();
        let len = data_f32.len();
        GpuNdarray::from_ndarray_data(&data_f32, vec![len], std::sync::Arc::clone(ctx))
            .map_err(gpu_err("GPU upload"))
    };
    let dot_gpu = |a: &GpuNdarray<f32>, b: &GpuNdarray<f32>| -> Result<f64, OptimizeError> {
        a.dot_gpu(b).map(f64::from).map_err(gpu_err("GPU dot"))
    };

    let s_gpu = s_vectors.iter().map(&upload).collect::<Result<Vec<_>, _>>()?;
    let y_gpu = y_vectors.iter().map(&upload).collect::<Result<Vec<_>, _>>()?;
    let mut q_gpu = upload(gradient)?;

    // First loop (newest -> oldest): alpha_i = rho_i * s_i.q ;  q -= alpha_i * y_i.
    // Alphas are stored at their history index so the second loop can read them directly.
    let mut alpha_values = vec![0.0_f64; m];
    for i in (0..m).rev() {
        let alpha_i = rho_values[i] * dot_gpu(&s_gpu[i], &q_gpu)?;
        alpha_values[i] = alpha_i;
        let scaled_y = y_gpu[i]
            .multiply_by_scalar_f32(alpha_i as f32)
            .map_err(gpu_err("GPU scale y"))?;
        q_gpu = q_gpu.subtract(&scaled_y).map_err(gpu_err("GPU q subtract"))?;
    }

    // Initial Hessian scaling gamma = (y.s)/(y.y) from the last pair. Skipped when either
    // product is non-positive or NaN; the comparisons are false for NaN.
    let mut r_gpu = if m > 0 {
        let ys = dot_gpu(&y_gpu[m - 1], &s_gpu[m - 1])?;
        let yy = dot_gpu(&y_gpu[m - 1], &y_gpu[m - 1])?;
        if ys > 0.0 && yy > 0.0 {
            let gamma = (ys / yy) as f32;
            q_gpu
                .multiply_by_scalar_f32(gamma)
                .map_err(gpu_err("GPU gamma scale"))?
        } else {
            q_gpu
        }
    } else {
        q_gpu
    };

    // Second loop (oldest -> newest): beta_i = rho_i * y_i.r ;  r += (alpha_i - beta_i) * s_i.
    for i in 0..m {
        let beta_i = rho_values[i] * dot_gpu(&y_gpu[i], &r_gpu)?;
        let coeff = (alpha_values[i] - beta_i) as f32;
        let scaled_s = s_gpu[i]
            .multiply_by_scalar_f32(coeff)
            .map_err(gpu_err("GPU scale s"))?;
        r_gpu = r_gpu.add(&scaled_s).map_err(gpu_err("GPU r add"))?;
    }

    let r_host = r_gpu.to_vec().map_err(gpu_err("GPU download"))?;
    if r_host.len() != n {
        return Err(OptimizeError::ComputationError(format!(
            "GPU result length mismatch: got {}, expected {n}",
            r_host.len()
        )));
    }
    // The f32 pipeline can overflow or produce NaN. Report it so the caller falls back
    // to CPU instead of silently returning a corrupted direction.
    if r_host.iter().any(|v| !v.is_finite()) {
        return Err(OptimizeError::ComputationError(
            "GPU result contains non-finite values".into(),
        ));
    }
    Ok(Array1::from_iter(r_host.into_iter().map(|v| -f64::from(v))))
}