solvr 0.2.0-beta.2

Advanced computing library for real-world problem solving - optimization, differential equations, interpolation, statistics, and more
Documentation
//! CUDA implementations of multivariate minimization algorithms.

mod newton_cg;
mod trust_exact;
mod trust_krylov;
mod trust_ncg;

use numr::error::Result;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;

use crate::optimize::error::OptimizeResult;
use crate::optimize::minimize::MinimizeOptions;

use super::impl_generic::{
    LbfgsOptions, TensorMinimizeResult, bfgs_impl, conjugate_gradient_impl, lbfgs_impl,
    nelder_mead_impl, powell_impl,
};
use crate::optimize::impl_generic::scalar::{
    bisect_impl, brentq_impl, minimize_scalar_brent_impl, newton_impl,
};
use crate::optimize::scalar::{MinimizeResult, RootResult, ScalarOptions};

/// Exposes the crate's optimization routines on the CUDA client.
///
/// Each method is a thin forwarder to the matching generic `*_impl` function.
/// The scalar routines do not use the client at all; the multivariate routines
/// pass `self` through as the client argument.
impl crate::optimize::OptimizationAlgorithms<CudaRuntime> for CudaClient {
    // ---- Scalar root finding methods ----

    /// Forwards `f`, `a`, `b` and `options` unchanged to [`bisect_impl`].
    ///
    /// `self` is not consulted.
    fn bisect<F: Fn(f64) -> f64>(
        &self,
        f: F,
        a: f64,
        b: f64,
        options: &ScalarOptions,
    ) -> OptimizeResult<RootResult> {
        bisect_impl(f, a, b, options)
    }

    /// Forwards `f`, `a`, `b` and `options` unchanged to [`brentq_impl`].
    ///
    /// `self` is not consulted.
    fn brentq<F: Fn(f64) -> f64>(
        &self,
        f: F,
        a: f64,
        b: f64,
        options: &ScalarOptions,
    ) -> OptimizeResult<RootResult> {
        brentq_impl(f, a, b, options)
    }

    /// Forwards the function `f`, the callable `df`, the starting value `x0`
    /// and `options` unchanged to [`newton_impl`].
    ///
    /// `self` is not consulted.
    fn newton<F: Fn(f64) -> f64, DF: Fn(f64) -> f64>(
        &self,
        f: F,
        df: DF,
        x0: f64,
        options: &ScalarOptions,
    ) -> OptimizeResult<RootResult> {
        newton_impl(f, df, x0, options)
    }

    /// Forwards `f`, the optional `bracket` triple and `options` unchanged to
    /// [`minimize_scalar_brent_impl`].
    ///
    /// `self` is not consulted.
    fn minimize_scalar_brent<F: Fn(f64) -> f64>(
        &self,
        f: F,
        bracket: Option<(f64, f64, f64)>,
        options: &ScalarOptions,
    ) -> OptimizeResult<MinimizeResult> {
        minimize_scalar_brent_impl(f, bracket, options)
    }

    // ---- Multivariate minimization methods ----

    /// Runs [`bfgs_impl`] with this client, the objective `f`, the starting
    /// tensor `x0` and `options`, returning its result as-is.
    fn bfgs<F: Fn(&Tensor<CudaRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<CudaRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<CudaRuntime>> {
        bfgs_impl(self, f, x0, options)
    }

    /// Runs [`lbfgs_impl`] with this client, the objective `f`, the starting
    /// tensor `x0` and the L-BFGS specific `options`, returning its result as-is.
    fn lbfgs<F: Fn(&Tensor<CudaRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<CudaRuntime>,
        options: &LbfgsOptions,
    ) -> OptimizeResult<TensorMinimizeResult<CudaRuntime>> {
        lbfgs_impl(self, f, x0, options)
    }

    /// Runs [`nelder_mead_impl`] with this client, the objective `f`, the
    /// starting tensor `x0` and `options`, returning its result as-is.
    fn nelder_mead<F: Fn(&Tensor<CudaRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<CudaRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<CudaRuntime>> {
        nelder_mead_impl(self, f, x0, options)
    }

    /// Runs [`powell_impl`] with this client, the objective `f`, the starting
    /// tensor `x0` and `options`, returning its result as-is.
    fn powell<F: Fn(&Tensor<CudaRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<CudaRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<CudaRuntime>> {
        powell_impl(self, f, x0, options)
    }

    /// Runs [`conjugate_gradient_impl`] with this client, the objective `f`,
    /// the starting tensor `x0` and `options`, returning its result as-is.
    fn conjugate_gradient<F: Fn(&Tensor<CudaRuntime>) -> Result<f64>>(
        &self,
        f: F,
        x0: &Tensor<CudaRuntime>,
        options: &MinimizeOptions,
    ) -> OptimizeResult<TensorMinimizeResult<CudaRuntime>> {
        conjugate_gradient_impl(self, f, x0, options)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimize::OptimizationAlgorithms;
    use numr::runtime::cuda::CudaDevice;

    /// Builds a device handle for index 0 and a client on top of it.
    ///
    /// Yields `None` when client construction reports an error, which lets the
    /// tests skip themselves instead of failing.
    fn setup() -> Option<(CudaDevice, CudaClient)> {
        let device = CudaDevice::new(0);
        CudaClient::new(device.clone())
            .ok()
            .map(|client| (device, client))
    }

    /// Objective used by the BFGS test: the sum of the squared components of `x`.
    fn sum_of_squares(x: &Tensor<CudaRuntime>) -> Result<f64> {
        let values: Vec<f64> = x.to_vec();
        Ok(values.iter().map(|v| v * v).sum())
    }

    /// BFGS started from `[1.0, 1.0]` on the sum-of-squares objective should
    /// report convergence with a final objective value below `1e-6`.
    #[test]
    fn test_bfgs_cuda() {
        let (device, client) = match setup() {
            Some(pair) => pair,
            None => {
                eprintln!("Skipping CUDA test: no device");
                return;
            }
        };

        let start = Tensor::<CudaRuntime>::from_slice(&[1.0, 1.0], &[2], &device);
        let defaults = MinimizeOptions::default();

        let outcome = client.bfgs(sum_of_squares, &start, &defaults).unwrap();

        assert!(outcome.converged);
        assert!(outcome.fun < 1e-6);
    }
}