// solvr 0.2.0-beta.2
//
// Advanced computing library for real-world problem solving: optimization, differential equations, interpolation, statistics, and more.
// Documentation
//! WebGPU implementations of the optimization algorithms: scalar root finding, scalar minimization, and multivariate minimization.

mod newton_cg;
mod trust_exact;
mod trust_krylov;
mod trust_ncg;

use numr::error::Result;
use numr::runtime::wgpu::{WgpuClient, WgpuRuntime};
use numr::tensor::Tensor;

use crate::optimize::error::OptimizeResult;
use crate::optimize::minimize::MinimizeOptions;

use super::impl_generic::{
    LbfgsOptions, TensorMinimizeResult, bfgs_impl, conjugate_gradient_impl, lbfgs_impl,
    nelder_mead_impl, powell_impl,
};
use crate::optimize::impl_generic::scalar::{
    bisect_impl, brentq_impl, minimize_scalar_brent_impl, newton_impl,
};
use crate::optimize::scalar::{MinimizeResult, RootResult, ScalarOptions};

/// `OptimizationAlgorithms` for the WebGPU client.
///
/// Each method is a thin forwarder to the matching generic `*_impl` routine:
/// the scalar methods pass only the callback(s) and numeric arguments, while
/// the tensor-based minimizers additionally pass `self` as the client.
impl crate::optimize::OptimizationAlgorithms<WgpuRuntime> for WgpuClient {
    // ---------------------------------------------------------------------
    // Scalar root finding and scalar minimization
    // ---------------------------------------------------------------------

    /// Forwards to `bisect_impl` with `f`, the endpoints `a` and `b`, and `options`.
    ///
    /// The client (`self`) is not passed to the scalar implementation.
    fn bisect<F: Fn(f64) -> f64>(
        &self,
        f: F,
        a: f64,
        b: f64,
        options: &ScalarOptions,
    ) -> OptimizeResult<RootResult> {
        bisect_impl(f, a, b, options)
    }

    /// Forwards to `brentq_impl` with `f`, the endpoints `a` and `b`, and `options`.
    ///
    /// The client (`self`) is not passed to the scalar implementation.
    fn brentq<F: Fn(f64) -> f64>(
        &self,
        f: F,
        a: f64,
        b: f64,
        options: &ScalarOptions,
    ) -> OptimizeResult<RootResult> {
        brentq_impl(f, a, b, options)
    }

    /// Forwards to `newton_impl` with the function `f`, the second callback
    /// `df` (presumably the derivative of `f` — confirm against `newton_impl`),
    /// the starting point `x0`, and `options`.
    ///
    /// The client (`self`) is not passed to the scalar implementation.
    fn newton<F: Fn(f64) -> f64, DF: Fn(f64) -> f64>(
        &self,
        f: F,
        df: DF,
        x0: f64,
        options: &ScalarOptions,
    ) -> OptimizeResult<RootResult> {
        newton_impl(f, df, x0, options)
    }

    /// Forwards to `minimize_scalar_brent_impl` with `f`, the optional
    /// three-value `bracket`, and `options`.
    ///
    /// The client (`self`) is not passed to the scalar implementation.
    fn minimize_scalar_brent<F: Fn(f64) -> f64>(
        &self,
        f: F,
        bracket: Option<(f64, f64, f64)>,
        options: &ScalarOptions,
    ) -> OptimizeResult<MinimizeResult> {
        minimize_scalar_brent_impl(f, bracket, options)
    }

    // ---------------------------------------------------------------------
    // Multivariate minimization
    // ---------------------------------------------------------------------

    /// Forwards to `bfgs_impl`, passing this client, the objective `f`, the
    /// starting tensor `x0`, and `options`.
    fn bfgs<F: Fn(&Tensor<WgpuRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<WgpuRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<WgpuRuntime>> {
        bfgs_impl(self, f, x0, options)
    }

    /// Forwards to `lbfgs_impl`, passing this client, the objective `f`, the
    /// starting tensor `x0`, and the L-BFGS specific `options`.
    fn lbfgs<F: Fn(&Tensor<WgpuRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<WgpuRuntime>,
        options: &LbfgsOptions,
    ) -> OptimizeResult<TensorMinimizeResult<WgpuRuntime>> {
        lbfgs_impl(self, f, x0, options)
    }

    /// Forwards to `nelder_mead_impl`, passing this client, the objective `f`,
    /// the starting tensor `x0`, and `options`.
    fn nelder_mead<F: Fn(&Tensor<WgpuRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<WgpuRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<WgpuRuntime>> {
        nelder_mead_impl(self, f, x0, options)
    }

    /// Forwards to `powell_impl`, passing this client, the objective `f`, the
    /// starting tensor `x0`, and `options`.
    fn powell<F: Fn(&Tensor<WgpuRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<WgpuRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<WgpuRuntime>> {
        powell_impl(self, f, x0, options)
    }

    /// Forwards to `conjugate_gradient_impl`, passing this client, the
    /// objective `f`, the starting tensor `x0`, and `options`.
    fn conjugate_gradient<F: Fn(&Tensor<WgpuRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<WgpuRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<WgpuRuntime>> {
        conjugate_gradient_impl(self, f, x0, options)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimize::OptimizationAlgorithms;
    use numr::runtime::wgpu::WgpuDevice;

    /// Builds a device for index 0 together with a client on it.
    ///
    /// Yields `None` when the client cannot be created, so callers can skip
    /// instead of failing on machines without a usable WebGPU device.
    fn setup() -> Option<(WgpuDevice, WgpuClient)> {
        let device = WgpuDevice::new(0);
        WgpuClient::new(device.clone())
            .ok()
            .map(|client| (device, client))
    }

    /// Objective for the BFGS test: the sum of squares of the tensor's
    /// elements, accumulated in `f32` and widened to `f64`.
    fn sum_of_squares(x: &Tensor<WgpuRuntime>) -> Result<f64> {
        let values: Vec<f32> = x.to_vec();
        let total = values.iter().map(|v| v * v).sum::<f32>();
        Ok(total as f64)
    }

    /// Minimizes the sum of squares from `[1, 1]` with BFGS on WebGPU.
    ///
    /// The test is skipped (returns early) when no device is available or when
    /// the solver itself reports an error.
    #[test]
    fn test_bfgs_wgpu() {
        let Some((device, client)) = setup() else {
            eprintln!("Skipping WebGPU test: no device");
            return;
        };

        let start = [1.0f32, 1.0];
        let x0 = Tensor::<WgpuRuntime>::from_slice(&start, &[2], &device);
        let options = MinimizeOptions::default();

        // Per the original note: BFGS internally uses F64 for the Hessian
        // (eye), which wgpu doesn't support, so an error here means "skip".
        let outcome = client.bfgs(sum_of_squares, &x0, &options);
        let result = match outcome {
            Err(e) => {
                eprintln!("Skipping test_bfgs_wgpu: {e}");
                return;
            }
            Ok(found) => found,
        };

        assert!(result.converged);
        assert!(result.fun < 1e-3);
    }
}