use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::impl_generic::utils::{
SINGULAR_THRESHOLD, compute_cost as utils_compute_cost,
finite_difference_jacobian as utils_finite_difference_jacobian,
tensor_norm as utils_tensor_norm,
};
use crate::optimize::least_squares::traits::LeastSquaresOptions;
use super::TensorLeastSquaresResult;
/// Solves a nonlinear least-squares problem with a Levenberg–Marquardt iteration.
///
/// Starting from `x0`, each iteration linearises the residual function `f`
/// with a finite-difference Jacobian `J` and solves the damped normal
/// equations `(JᵀJ + λ·D) dx = -Jᵀf`. Here `D` is the diagonal of `JᵀJ`,
/// floored at `SINGULAR_THRESHOLD` (see `add_scaled_diagonal`).
///
/// - A step that lowers the cost is accepted and shrinks `λ` by 10×.
/// - A step that does not lower the cost, or a damped system that cannot be
///   solved, grows `λ` by 10× and retries from the same point.
/// - `λ` starts at `1e-3` and is kept in `[SINGULAR_THRESHOLD, 1e10]`.
///
/// The linearisation depends only on the current `x`. It is therefore cached
/// across rejected steps and rebuilt only after an accepted step, so a
/// rejected step costs a single residual evaluation.
///
/// # Convergence
/// The result has `converged: true` when any of these holds:
/// - the cost is below `options.f_tol`;
/// - `‖Jᵀf‖` is below `options.g_tol`;
/// - an accepted step has norm below `options.x_tol`.
///
/// Otherwise the current iterate is returned after `options.max_iter`
/// iterations with `converged: false`.
///
/// # Errors
/// - `OptimizeError::InvalidInput` if `x0` or `f(x0)` is 0-D or empty.
/// - `OptimizeError::NumericalError` if evaluating `f` or any tensor operation fails.
pub fn leastsq_impl<R, C, F>(
    client: &C,
    f: F,
    x0: &Tensor<R>,
    options: &LeastSquaresOptions,
) -> OptimizeResult<TensorLeastSquaresResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
    // `shape()[0]` would panic on a 0-D tensor; report it as bad input instead.
    let Some(&n) = x0.shape().first() else {
        return Err(OptimizeError::InvalidInput {
            context: "leastsq: initial guess must be a 1-D tensor".to_string(),
        });
    };
    if n == 0 {
        return Err(OptimizeError::InvalidInput {
            context: "leastsq: empty initial guess".to_string(),
        });
    }

    let mut x = x0.clone();
    let mut fx = f(&x).map_err(|e| OptimizeError::NumericalError {
        message: format!("leastsq: initial evaluation - {}", e),
    })?;
    let Some(&m) = fx.shape().first() else {
        return Err(OptimizeError::InvalidInput {
            context: "leastsq: residual function must return a 1-D tensor".to_string(),
        });
    };
    if m == 0 {
        return Err(OptimizeError::InvalidInput {
            context: "leastsq: residual function returns empty vector".to_string(),
        });
    }
    let mut nfev = 1;
    let mut cost = compute_cost(client, &fx)?;

    let mut lambda = 0.001;
    let lambda_up = 10.0;
    let lambda_down = 0.1;
    let lambda_min = SINGULAR_THRESHOLD;
    let lambda_max = 1e10;

    // (JᵀJ, Jᵀf as an n×1 column, ‖Jᵀf‖) at the current `x`. It only depends
    // on `x`, so it is put back after a rejected step rather than spending
    // another `n` evaluations on an identical Jacobian. It is left empty
    // whenever `x` moves.
    let mut cached: Option<(Tensor<R>, Tensor<R>, f64)> = None;

    for iter in 0..options.max_iter {
        if cost < options.f_tol {
            return Ok(TensorLeastSquaresResult {
                x,
                residuals: fx,
                cost,
                iterations: iter + 1,
                nfev,
                converged: true,
            });
        }

        let (jtj, jtf, grad_norm) = match cached.take() {
            Some(lin) => lin,
            None => {
                let lin = linearize(client, &f, &x, &fx, m, n, options.eps)?;
                // Mirrors the evaluation count charged for one finite-difference Jacobian.
                nfev += n;
                lin
            }
        };

        if grad_norm < options.g_tol {
            return Ok(TensorLeastSquaresResult {
                x,
                residuals: fx,
                cost,
                iterations: iter + 1,
                nfev,
                converged: true,
            });
        }

        let jtj_damped = add_scaled_diagonal(client, &jtj, lambda, n)?;
        let neg_jtf = client
            .mul_scalar(&jtf, -1.0)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("leastsq: negate jtf - {}", e),
            })?;
        let dx_col = match LinearAlgebraAlgorithms::solve(client, &jtj_damped, &neg_jtf) {
            Ok(dx) => dx,
            Err(_) => {
                // Damped system not solvable: increase damping and retry from the
                // same point, reusing the still-valid linearisation.
                lambda *= lambda_up;
                lambda = lambda.clamp(lambda_min, lambda_max);
                cached = Some((jtj, jtf, grad_norm));
                continue;
            }
        };
        let dx = dx_col
            .reshape(&[n])
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("leastsq: reshape dx - {}", e),
            })?;

        let x_new = client
            .add(&x, &dx)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("leastsq: update x - {}", e),
            })?;
        let fx_new = f(&x_new).map_err(|e| OptimizeError::NumericalError {
            message: format!("leastsq: evaluation - {}", e),
        })?;
        nfev += 1;
        let cost_new = compute_cost(client, &fx_new)?;

        if cost_new < cost {
            let dx_norm = tensor_norm(client, &dx)?;
            if dx_norm < options.x_tol {
                return Ok(TensorLeastSquaresResult {
                    x: x_new,
                    residuals: fx_new,
                    cost: cost_new,
                    iterations: iter + 1,
                    nfev,
                    converged: true,
                });
            }
            // Accepted: `x` moves, so `cached` stays empty and the next
            // iteration relinearises at the new point.
            x = x_new;
            fx = fx_new;
            cost = cost_new;
            lambda *= lambda_down;
        } else {
            // Rejected (also covers a NaN cost): `x` is unchanged, keep the linearisation.
            lambda *= lambda_up;
            cached = Some((jtj, jtf, grad_norm));
        }
        lambda = lambda.clamp(lambda_min, lambda_max);
    }

    Ok(TensorLeastSquaresResult {
        x,
        residuals: fx,
        cost,
        iterations: options.max_iter,
        nfev,
        converged: false,
    })
}

/// Linearises the residual function at `x`.
///
/// Returns `(JᵀJ, Jᵀf, ‖Jᵀf‖)`, where `J` is the `m × n` finite-difference
/// Jacobian of `f` at `x` and `Jᵀf` is an `n × 1` column.
fn linearize<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    fx: &Tensor<R>,
    m: usize,
    n: usize,
    eps: f64,
) -> OptimizeResult<(Tensor<R>, Tensor<R>, f64)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
    let jacobian = finite_difference_jacobian(client, f, x, fx, m, n, eps)?;
    let jt = jacobian
        .transpose(0, 1)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("leastsq: transpose - {}", e),
        })?;
    let jtj = client
        .matmul(&jt, &jacobian)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("leastsq: J^T J - {}", e),
        })?;
    let fx_col = fx
        .reshape(&[m, 1])
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("leastsq: reshape fx - {}", e),
        })?;
    let jtf = client
        .matmul(&jt, &fx_col)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("leastsq: J^T f - {}", e),
        })?;
    let jtf_vec = jtf
        .reshape(&[n])
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("leastsq: reshape jtf - {}", e),
        })?;
    let grad_norm = tensor_norm(client, &jtf_vec)?;
    Ok((jtj, jtf, grad_norm))
}
/// Scalar cost of the residual tensor `fx`, as computed by the shared
/// `utils::compute_cost` helper.
///
/// # Errors
/// Any failure from the helper is re-reported as
/// `OptimizeError::NumericalError` with a `compute_cost:` prefix.
fn compute_cost<R, C>(client: &C, fx: &Tensor<R>) -> OptimizeResult<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let outcome = utils_compute_cost(client, fx);
    outcome.map_err(|source| {
        let message = format!("compute_cost: {}", source);
        OptimizeError::NumericalError { message }
    })
}
/// Norm of the tensor `v`, as computed by the shared `utils::tensor_norm`
/// helper.
///
/// # Errors
/// Any failure from the helper is re-reported as
/// `OptimizeError::NumericalError` with a `tensor_norm:` prefix.
fn tensor_norm<R, C>(client: &C, v: &Tensor<R>) -> OptimizeResult<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let norm = utils_tensor_norm(client, v);
    norm.map_err(|source| {
        let message = format!("tensor_norm: {}", source);
        OptimizeError::NumericalError { message }
    })
}
/// Finite-difference Jacobian of `f` at `x`.
///
/// This is a thin adapter over `utils::finite_difference_jacobian`. It
/// forwards every argument unchanged: `fx` is the already-computed residual at
/// `x`, `m`/`n` are the residual and parameter counts, and `eps` is the step
/// size.
///
/// # Errors
/// A failure is re-reported as `OptimizeError::NumericalError` tagged with this
/// function's name.
fn finite_difference_jacobian<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    fx: &Tensor<R>,
    m: usize,
    n: usize,
    eps: f64,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
    let outcome = utils_finite_difference_jacobian(client, f, x, fx, m, n, eps);
    outcome.map_err(|source| {
        let message = format!("finite_difference_jacobian: {}", source);
        OptimizeError::NumericalError { message }
    })
}
/// Adds Marquardt-style damping to `a`.
///
/// Returns `a + diagflat(lambda * max(|diag(a)|, SINGULAR_THRESHOLD))`. The
/// damping on each diagonal entry is proportional to that entry's own
/// magnitude. The `SINGULAR_THRESHOLD` floor ensures that an entry which is
/// zero (or tiny) still receives some damping.
///
/// `a` is expected to be a square `n × n` matrix: `n` is used as the length of
/// its diagonal, and the `diagflat` result added back to `a` is `n × n`.
///
/// # Errors
/// `OptimizeError::NumericalError` if any underlying tensor operation fails.
fn add_scaled_diagonal<R, C>(
    client: &C,
    a: &Tensor<R>,
    lambda: f64,
    n: usize,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    let diag_vec =
        LinearAlgebraAlgorithms::diag(client, a).map_err(|e| OptimizeError::NumericalError {
            message: format!("add_scaled_diagonal: diag - {}", e),
        })?;
    let abs_diag = client
        .abs(&diag_vec)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("add_scaled_diagonal: abs - {}", e),
        })?;
    // Build the floor in the diagonal's own dtype. A hard-coded F64 tensor would
    // not match the diagonal when the problem is solved in F32, so the
    // element-wise `maximum` below would be fed mismatched dtypes.
    let threshold = client
        .fill(&[n], SINGULAR_THRESHOLD, abs_diag.dtype())
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("add_scaled_diagonal: threshold - {}", e),
        })?;
    let clamped_diag =
        client
            .maximum(&abs_diag, &threshold)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("add_scaled_diagonal: max - {}", e),
            })?;
    let scaled_diag =
        client
            .mul_scalar(&clamped_diag, lambda)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("add_scaled_diagonal: scale - {}", e),
            })?;
    let diag_matrix = LinearAlgebraAlgorithms::diagflat(client, &scaled_diag).map_err(|e| {
        OptimizeError::NumericalError {
            message: format!("add_scaled_diagonal: diagflat - {}", e),
        }
    })?;
    client
        .add(a, &diag_matrix)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("add_scaled_diagonal: add - {}", e),
        })
}