use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::impl_generic::utils::SINGULAR_THRESHOLD;
use crate::optimize::roots::RootOptions;
use super::TensorRootResult;
use super::newton::finite_difference_jacobian_tensor;
use crate::optimize::impl_generic::utils::tensor_norm;
/// Finds a root of the square system `f(x) = 0` with a Levenberg–Marquardt iteration.
///
/// Each iteration solves the damped normal equations `(JᵀJ + λI) dx = -Jᵀf(x)`.
/// `J` is the finite-difference Jacobian of `f` at `x`, produced by
/// `finite_difference_jacobian_tensor` with step `options.eps`.
///
/// A trial step `x + dx` is accepted only if it strictly reduces the residual norm.
/// Accepted steps multiply `λ` by 0.1. Rejected steps, and damped systems that fail
/// to solve, multiply `λ` by 10. `λ` starts at 0.001 and is kept in
/// `[SINGULAR_THRESHOLD, 1e10]`.
///
/// The iteration reports convergence in two cases:
/// - the residual norm drops below `options.tol`;
/// - an accepted step has norm below `options.x_tol`.
///
/// In both cases `iterations` is the 1-based loop pass on which it happened.
/// If neither occurs within `options.max_iter` passes, the last accepted iterate
/// is returned with `converged: false`.
///
/// # Errors
/// - `InvalidInput` if `x0` is 0-d or has zero length.
/// - `InvalidInput` if the leading dimension of `f(x0)` differs from that of `x0`.
/// - `NumericalError` if evaluating `f` or any tensor operation fails.
///
/// A failed linear solve is not an error; it only increases the damping.
pub fn levenberg_marquardt_impl<R, C, F>(
    client: &C,
    f: F,
    x0: &Tensor<R>,
    options: &RootOptions,
) -> OptimizeResult<TensorRootResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
    // Reject bad shapes with an error rather than panicking on `shape()[0]`.
    let n = match x0.shape().first() {
        Some(&n) if n > 0 => n,
        Some(_) => {
            return Err(OptimizeError::InvalidInput {
                context: "levenberg_marquardt: empty initial guess".to_string(),
            });
        }
        None => {
            return Err(OptimizeError::InvalidInput {
                context: "levenberg_marquardt: initial guess must be at least 1-D".to_string(),
            });
        }
    };

    let mut x = x0.clone();
    let mut fx = f(&x).map_err(|e| OptimizeError::NumericalError {
        message: format!("levenberg_marquardt: initial evaluation - {}", e),
    })?;
    match fx.shape().first() {
        Some(&m) if m == n => {}
        Some(&m) => {
            return Err(OptimizeError::InvalidInput {
                context: format!(
                    "levenberg_marquardt: function returns {} values but input has {} dimensions",
                    m, n
                ),
            });
        }
        None => {
            return Err(OptimizeError::InvalidInput {
                context: format!(
                    "levenberg_marquardt: function returns a scalar but input has {} dimensions",
                    n
                ),
            });
        }
    }

    let mut lambda = 0.001;
    let lambda_up = 10.0;
    let lambda_down = 0.1;
    let lambda_min = SINGULAR_THRESHOLD;
    let lambda_max = 1e10;

    // Residual norm at the current `x`. After an accepted step it is taken from the
    // trial evaluation instead of being recomputed.
    let mut res_norm = tensor_norm(client, &fx).map_err(|e| OptimizeError::NumericalError {
        message: format!("levenberg_marquardt: norm - {}", e),
    })?;

    // `(JᵀJ, -Jᵀf)` at the current `x`. `x` does not move after a rejected step or
    // a failed solve, so these are kept for the retry. This avoids rebuilding the
    // finite-difference Jacobian, which re-evaluates `f`. Only `λ` changes between
    // such retries.
    let mut normal_eqs: Option<(Tensor<R>, Tensor<R>)> = None;

    for iter in 0..options.max_iter {
        if res_norm < options.tol {
            return Ok(TensorRootResult {
                x,
                fun: fx,
                iterations: iter + 1,
                residual_norm: res_norm,
                converged: true,
            });
        }

        let (jtj, neg_jtf) = match normal_eqs.take() {
            Some(eqs) => eqs,
            None => {
                let jacobian =
                    finite_difference_jacobian_tensor(client, &f, &x, &fx, options.eps)?;
                let jt = jacobian
                    .transpose(0, 1)
                    .map_err(|e| OptimizeError::NumericalError {
                        message: format!("levenberg_marquardt: transpose - {}", e),
                    })?;
                let jtj = client
                    .matmul(&jt, &jacobian)
                    .map_err(|e| OptimizeError::NumericalError {
                        message: format!("levenberg_marquardt: J^T J - {}", e),
                    })?;
                let fx_col = fx
                    .reshape(&[n, 1])
                    .map_err(|e| OptimizeError::NumericalError {
                        message: format!("levenberg_marquardt: reshape fx - {}", e),
                    })?;
                let jtf = client
                    .matmul(&jt, &fx_col)
                    .map_err(|e| OptimizeError::NumericalError {
                        message: format!("levenberg_marquardt: J^T f - {}", e),
                    })?;
                let neg_jtf = client
                    .mul_scalar(&jtf, -1.0)
                    .map_err(|e| OptimizeError::NumericalError {
                        message: format!("levenberg_marquardt: negate jtf - {}", e),
                    })?;
                (jtj, neg_jtf)
            }
        };

        let jtj_damped = add_lambda_identity(client, &jtj, lambda)?;
        let solved = LinearAlgebraAlgorithms::solve(client, &jtj_damped, &neg_jtf);
        let dx_col = match solved {
            Ok(dx) => dx,
            Err(_) => {
                // The damped system could not be solved. Keep `x` and its normal
                // equations, and retry with heavier damping.
                normal_eqs = Some((jtj, neg_jtf));
                lambda = (lambda * lambda_up).clamp(lambda_min, lambda_max);
                continue;
            }
        };

        let dx = dx_col
            .reshape(&[n])
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("levenberg_marquardt: reshape dx - {}", e),
            })?;
        let x_new = client
            .add(&x, &dx)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("levenberg_marquardt: update x - {}", e),
            })?;
        let fx_new = f(&x_new).map_err(|e| OptimizeError::NumericalError {
            message: format!("levenberg_marquardt: evaluation - {}", e),
        })?;
        let new_res_norm =
            tensor_norm(client, &fx_new).map_err(|e| OptimizeError::NumericalError {
                message: format!("levenberg_marquardt: new norm - {}", e),
            })?;

        // A NaN trial norm compares false here, so it is treated as a rejected step.
        if new_res_norm < res_norm {
            // The step is accepted. `normal_eqs` stays `None`, so the Jacobian is
            // rebuilt at the new `x` on the next pass.
            x = x_new;
            fx = fx_new;
            res_norm = new_res_norm;
            lambda *= lambda_down;
            let dx_norm = tensor_norm(client, &dx).map_err(|e| OptimizeError::NumericalError {
                message: format!("levenberg_marquardt: dx norm - {}", e),
            })?;
            if dx_norm < options.x_tol {
                return Ok(TensorRootResult {
                    x,
                    fun: fx,
                    iterations: iter + 1,
                    residual_norm: res_norm,
                    converged: true,
                });
            }
        } else {
            // The step is rejected. `x` is unchanged, so its normal equations stay valid.
            normal_eqs = Some((jtj, neg_jtf));
            lambda *= lambda_up;
        }
        lambda = lambda.clamp(lambda_min, lambda_max);
    }

    Ok(TensorRootResult {
        x,
        fun: fx,
        iterations: options.max_iter,
        residual_norm: res_norm,
        converged: false,
    })
}
/// Returns `a + lambda * I`, i.e. `a` with `lambda` added to every diagonal entry.
///
/// Only `a.shape()[0]` is read to size the identity. `a` is assumed to be a square
/// `n × n` matrix; for a non-square `a` the final `add` will presumably fail and
/// surface as a `NumericalError`.
///
/// The diagonal is created in `a`'s own dtype. This means `a` and `λI` never
/// differ in dtype (e.g. for an F32 system), so the final `add` does not mix dtypes.
///
/// # Errors
/// Returns `NumericalError` if the fill, the diagonal construction or the addition fails.
fn add_lambda_identity<R, C>(client: &C, a: &Tensor<R>, lambda: f64) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    let n = a.shape()[0];
    // Match `a`'s dtype instead of hard-coding F64.
    let lambda_vec = client
        .fill(&[n], lambda, a.dtype())
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("add_lambda_identity: fill - {}", e),
        })?;
    // Place the constant vector on the diagonal to get `λI`.
    let lambda_i = LinearAlgebraAlgorithms::diagflat(client, &lambda_vec).map_err(|e| {
        OptimizeError::NumericalError {
            message: format!("add_lambda_identity: diagflat - {}", e),
        }
    })?;
    client
        .add(a, &lambda_i)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("add_lambda_identity: add - {}", e),
        })
}