use crate::optimize::roots::impl_generic::levenberg_marquardt_impl;
use crate::optimize::roots::traits::LevenbergMarquardtAlgorithms;
use crate::optimize::roots::{RootOptions, RootTensorResult};
use numr::error::Result;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;
impl LevenbergMarquardtAlgorithms<CudaRuntime> for CudaClient {
fn levenberg_marquardt<F>(
&self,
f: F,
x0: &Tensor<CudaRuntime>,
options: &RootOptions,
) -> Result<RootTensorResult<CudaRuntime>>
where
F: Fn(&Tensor<CudaRuntime>) -> Result<Tensor<CudaRuntime>>,
{
let result = levenberg_marquardt_impl(self, f, x0, options).map_err(|e| {
numr::error::Error::backend_limitation("cuda", "levenberg_marquardt", e.to_string())
})?;
Ok(RootTensorResult {
x: result.x,
fun: result.fun,
iterations: result.iterations,
residual_norm: result.residual_norm,
converged: result.converged,
})
}
}