use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::minimize::MinimizeOptions;
use super::helpers::{TensorMinimizeResult, line_search_tensor};
use super::utils::{SINGULAR_THRESHOLD, tensor_norm};
/// Minimize a scalar objective with Powell's direction-set method (no derivatives needed).
///
/// The direction set starts as the rows of an identity matrix (the coordinate
/// axes). Each iteration ("sweep") does the following:
/// 1. Line-search along every direction in turn, accepting each result.
/// 2. Use the net displacement `x - x_start` of the sweep as a new direction.
///    It replaces the direction that produced the largest single decrease.
///
/// # Arguments
/// * `client` - runtime client used for tensor arithmetic.
/// * `f` - objective; returns the function value at a point.
/// * `x0` - initial guess. Its first dimension is the problem size `n`
///   (presumably a 1-D tensor of length `n`; TODO confirm for callers).
/// * `options` - `max_iter` bounds the number of sweeps, and `f_tol` is the
///   relative tolerance on the decrease of `f` over one sweep.
///
/// # Returns
/// A `TensorMinimizeResult` with `converged: true` once the relative-decrease
/// test passes. Otherwise it holds the last iterate after `max_iter` sweeps,
/// with `converged: false`.
///
/// # Errors
/// * `OptimizeError::InvalidInput` if `x0` is zero-dimensional or its first
///   dimension is 0.
/// * `OptimizeError::NumericalError` if evaluating `f` or a tensor operation fails.
/// * Any error propagated from `line_search_tensor` or the direction helpers.
pub fn powell_impl<R, C, F>(
    client: &C,
    f: F,
    x0: &Tensor<R>,
    options: &MinimizeOptions,
) -> OptimizeResult<TensorMinimizeResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<f64>,
{
    // Indexing `shape()[0]` would panic on a zero-dimensional (scalar) tensor.
    // Report that case as invalid input instead of aborting.
    let n = match x0.shape().first() {
        Some(&len) => len,
        None => {
            return Err(OptimizeError::InvalidInput {
                context: "powell: initial guess must have at least one dimension".to_string(),
            });
        }
    };
    if n == 0 {
        return Err(OptimizeError::InvalidInput {
            context: "powell: empty initial guess".to_string(),
        });
    }

    let mut x = x0.clone();
    let mut fx = f(&x).map_err(|e| OptimizeError::NumericalError {
        message: format!("powell: initial evaluation - {}", e),
    })?;
    let mut nfev = 1;
    // Each row of this matrix is one search direction; initially the coordinate axes.
    let mut directions = create_identity_matrix::<R, C>(client, n)?;

    for iter in 0..options.max_iter {
        let x_start = x.clone();
        let fx_start = fx;

        // Remember which direction gave the largest single decrease in this
        // sweep; that direction is the one discarded when the set is updated.
        let mut max_decrease = 0.0;
        let mut max_decrease_idx = 0;

        for i in 0..n {
            let direction = extract_row(client, &directions, i)?;
            let (x_new, fx_new, evals) = line_search_tensor(client, &f, &x, &direction, fx)?;
            nfev += evals;
            let decrease = fx - fx_new;
            if decrease > max_decrease {
                max_decrease = decrease;
                max_decrease_idx = i;
            }
            x = x_new;
            fx = fx_new;
        }

        // Relative decrease test over the whole sweep. SINGULAR_THRESHOLD
        // keeps the right-hand side positive when f is (near) zero.
        if 2.0 * (fx_start - fx).abs()
            <= options.f_tol * (fx_start.abs() + fx.abs() + SINGULAR_THRESHOLD)
        {
            return Ok(TensorMinimizeResult {
                x,
                fun: fx,
                iterations: iter + 1,
                nfev,
                converged: true,
            });
        }

        // The net displacement of this sweep becomes a candidate new direction.
        let new_direction =
            client
                .sub(&x, &x_start)
                .map_err(|e| OptimizeError::NumericalError {
                    message: format!("powell: new direction - {}", e),
                })?;
        let new_dir_norm =
            tensor_norm(client, &new_direction).map_err(|e| OptimizeError::NumericalError {
                message: format!("powell: direction norm - {}", e),
            })?;
        // A (near-)zero displacement would add a degenerate direction, so in
        // that case the direction set is left unchanged.
        if new_dir_norm > SINGULAR_THRESHOLD {
            directions =
                update_direction_set(client, &directions, max_decrease_idx, &new_direction, n)?;
        }
    }

    Ok(TensorMinimizeResult {
        x,
        fun: fx,
        iterations: options.max_iter,
        nfev,
        converged: false,
    })
}
/// Create the initial Powell direction set: an identity matrix with `n` rows,
/// whose rows serve as the starting search directions.
///
/// The matrix is always allocated as `DType::F64`. `None` is passed for the
/// second size argument of `eye` (presumably yielding a square `n x n` matrix).
/// NOTE(review): the dtype does not follow the initial guess; confirm callers
/// never supply a non-F64 `x0`.
///
/// # Errors
/// Returns `OptimizeError::NumericalError` if the runtime fails to build the matrix.
fn create_identity_matrix<R, C>(client: &C, n: usize) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    match client.eye(n, None, DType::F64) {
        Ok(identity) => Ok(identity),
        Err(err) => Err(OptimizeError::NumericalError {
            message: format!("powell: create identity - {}", err),
        }),
    }
}
/// Copy row `row` out of a 2-D `matrix` as a contiguous 1-D tensor whose length
/// is `matrix.shape()[1]`.
///
/// `_client` is accepted but not used.
///
/// # Errors
/// Returns `OptimizeError::NumericalError` if narrowing, making the slice
/// contiguous, or reshaping fails.
fn extract_row<R, C>(_client: &C, matrix: &Tensor<R>, row: usize) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    let row_len = matrix.shape()[1];

    // Select a single row along axis 0, keeping the leading axis (shape [1, row_len]).
    let one_row = matrix
        .narrow(0, row, 1)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("powell: narrow row - {}", e),
        })?;

    // Materialize the view so it can be reshaped safely.
    let packed = one_row
        .contiguous()
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("powell: contiguous row - {}", e),
        })?;

    // Drop the leading length-1 axis.
    packed
        .reshape(&[row_len])
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("powell: reshape row - {}", e),
        })
}
/// Build the next Powell direction set from `directions`.
///
/// The first `n` rows of `directions` are kept in their original order, except
/// row `remove_idx`, which is dropped. `new_direction` is then appended as the
/// final row. If `remove_idx` is not in `0..n`, no row is dropped.
///
/// # Errors
/// Returns `OptimizeError::NumericalError` if any slicing, contiguity,
/// unsqueeze, or concatenation step fails. The first failure (in the order the
/// rows are processed) is reported.
fn update_direction_set<R, C>(
    client: &C,
    directions: &Tensor<R>,
    remove_idx: usize,
    new_direction: &Tensor<R>,
    n: usize,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    // Give the new direction a leading axis of length 1 so it can be
    // concatenated along axis 0 with the surviving rows.
    let contiguous_dir = new_direction
        .contiguous()
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("powell: contiguous new dir - {}", e),
        })?;
    let appended_row = contiguous_dir
        .unsqueeze(0)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("powell: unsqueeze new dir - {}", e),
        })?;

    // Collect every row except the one being discarded; stops at the first error.
    let mut stacked: Vec<Tensor<R>> = (0..n)
        .filter(|&idx| idx != remove_idx)
        .map(|idx| {
            let slice = directions
                .narrow(0, idx, 1)
                .map_err(|e| OptimizeError::NumericalError {
                    message: format!("powell: narrow row {} - {}", idx, e),
                })?;
            slice
                .contiguous()
                .map_err(|e| OptimizeError::NumericalError {
                    message: format!("powell: contiguous row {} - {}", idx, e),
                })
        })
        .collect::<OptimizeResult<Vec<Tensor<R>>>>()?;
    stacked.push(appended_row);

    let row_refs: Vec<&Tensor<R>> = stacked.iter().collect();
    client
        .cat(&row_refs, 0)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("powell: concat directions - {}", e),
        })
}