use crate::DType;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use super::utils::tensor_dot;
pub use crate::optimize::minimize::traits::TensorMinimizeResult;
/// Backtracking line search along `p` using the Armijo sufficient-decrease test.
///
/// Tries `x + alpha * p` for `alpha = 1, 1/2, 1/4, ...` (at most 50 trials) and accepts
/// the first point satisfying
/// `f(x + alpha * p) <= fx + c * alpha * min(grad·p, 0)` with `c = 1e-4`.
///
/// For a descent direction (`grad·p < 0`) this is the classical Armijo condition. When
/// `grad·p` is non-negative or NaN, the slope term is clamped to zero, so a step is only
/// accepted if it does not increase `f`. Without the clamp, a positive slope would make the
/// bound exceed `fx` and accept uphill steps.
///
/// # Arguments
/// * `client` - runtime client used for the tensor arithmetic.
/// * `f` - objective evaluated on plain tensors.
/// * `x` - current point.
/// * `p` - search direction.
/// * `fx` - objective value at `x` (not re-evaluated here).
/// * `grad` - gradient at `x`, used only to form the directional derivative `grad·p`.
///
/// # Returns
/// `(x_new, f(x_new), nfev)` for the accepted step. If no trial is accepted, returns
/// `(x.clone(), fx, nfev)`. `nfev` counts the calls made to `f`.
///
/// # Errors
/// `OptimizeError::NumericalError` if the dot product, a tensor operation, or `f` fails.
pub fn backtracking_line_search_tensor<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    p: &Tensor<R>,
    fx: f64,
    grad: &Tensor<R>,
) -> OptimizeResult<(Tensor<R>, f64, usize)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<f64>,
{
    // Armijo sufficient-decrease constant.
    const ARMIJO_C: f64 = 1e-4;
    // Factor applied to the step after each rejected trial.
    const RHO: f64 = 0.5;
    // Upper bound on the number of trial steps (and thus objective evaluations).
    const MAX_TRIALS: usize = 50;

    let grad_dot_p = tensor_dot(client, grad, p).map_err(|e| OptimizeError::NumericalError {
        message: format!("line_search: grad_dot_p - {}", e),
    })?;
    // `f64::min` returns the non-NaN operand, so a NaN slope also becomes 0.0 here.
    // Descent directions keep their (negative) slope unchanged.
    let slope = grad_dot_p.min(0.0);

    let mut alpha = 1.0;
    let mut nfev = 0;
    for _ in 0..MAX_TRIALS {
        let scaled_p = client
            .mul_scalar(p, alpha)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("line_search: scale p - {}", e),
            })?;
        let x_new = client
            .add(x, &scaled_p)
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("line_search: x + alpha*p - {}", e),
            })?;
        let fx_new = f(&x_new).map_err(|e| OptimizeError::NumericalError {
            message: format!("line_search: f eval - {}", e),
        })?;
        nfev += 1;
        if fx_new <= fx + ARMIJO_C * alpha * slope {
            return Ok((x_new, fx_new, nfev));
        }
        alpha *= RHO;
    }
    // No trial was accepted: stay at the current point.
    Ok((x.clone(), fx, nfev))
}
/// Derivative-free step search along `direction`, used by the Powell routine.
///
/// Evaluates `f(x + step * direction)` for positive steps only, starting at `step = 0.1`.
/// After a strict improvement over the best value seen so far, the candidate becomes the
/// incumbent and the step grows by 1.5x. Otherwise the step shrinks by 0.5x, and the search
/// stops early once the shrunk step falls below `1e-10`. At most 20 evaluations are made.
///
/// # Returns
/// `(best_x, best_fx, nfev)`. This is `(x.clone(), fx, nfev)` when no candidate improved
/// on `fx`. `nfev` counts the calls made to `f`.
///
/// # Errors
/// `OptimizeError::NumericalError` if a tensor operation or `f` fails.
pub fn line_search_tensor<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    direction: &Tensor<R>,
    fx: f64,
) -> OptimizeResult<(Tensor<R>, f64, usize)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<f64>,
{
    const INITIAL_STEP: f64 = 0.1;
    const MAX_TRIALS: usize = 20;
    const GROW: f64 = 1.5;
    const SHRINK: f64 = 0.5;
    const MIN_STEP: f64 = 1e-10;

    let mut step = INITIAL_STEP;
    let mut evaluations = 0;
    let mut incumbent = x.clone();
    let mut incumbent_f = fx;

    for _ in 0..MAX_TRIALS {
        let offset = client
            .mul_scalar(direction, step)
            .map_err(|err| OptimizeError::NumericalError {
                message: format!("powell line_search: scale - {}", err),
            })?;
        let candidate = client
            .add(x, &offset)
            .map_err(|err| OptimizeError::NumericalError {
                message: format!("powell line_search: add - {}", err),
            })?;
        let candidate_f = f(&candidate).map_err(|err| OptimizeError::NumericalError {
            message: format!("powell line_search: f eval - {}", err),
        })?;
        evaluations += 1;

        // A strict improvement (never true for NaN) moves the incumbent and widens the step.
        if candidate_f < incumbent_f {
            incumbent = candidate;
            incumbent_f = candidate_f;
            step *= GROW;
            continue;
        }

        step *= SHRINK;
        if step < MIN_STEP {
            break;
        }
    }

    Ok((incumbent, incumbent_f, evaluations))
}
/// Total-order-like comparison for `f64` that sorts every NaN after all other values.
///
/// Non-NaN values compare as `partial_cmp` does, so `-0.0` and `0.0` are `Equal`. Any NaN
/// is `Greater` than a non-NaN value, and two NaNs compare `Equal`. Suitable as a
/// `sort_by` comparator.
pub fn compare_f64_nan_safe(a: f64, b: f64) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    match (a.is_nan(), b.is_nan()) {
        // Neither is NaN, so `partial_cmp` always yields `Some`. The fallback is defensive only.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (true, true) => Ordering::Equal,
    }
}
use numr::autograd::{Var, VarGradStore, backward, backward_with_graph, var_mul, var_sum};
pub fn hvp_reverse_over_reverse<R, C>(
client: &C,
loss: &Var<R>,
x: &Var<R>,
v: &Tensor<R>,
) -> OptimizeResult<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
let grads: VarGradStore<R> =
backward_with_graph(loss, client).map_err(|e| OptimizeError::NumericalError {
message: format!("hvp: backward_with_graph failed - {}", e),
})?;
let grad_x: &Var<R> = grads
.get_var(x.id())
.ok_or_else(|| OptimizeError::NumericalError {
message: "hvp: no gradient for x (is x a leaf with requires_grad=true?)".to_string(),
})?;
let v_var = Var::new(v.clone(), false);
let grad_v: Var<R> =
var_mul(grad_x, &v_var, client).map_err(|e| OptimizeError::NumericalError {
message: format!("hvp: var_mul failed - {}", e),
})?;
let grad_v_dot: Var<R> =
var_sum(&grad_v, &[0], false, client).map_err(|e| OptimizeError::NumericalError {
message: format!("hvp: var_sum failed - {}", e),
})?;
let hvp_grads = backward(&grad_v_dot, client).map_err(|e| OptimizeError::NumericalError {
message: format!("hvp: second backward failed - {}", e),
})?;
hvp_grads
.get(x.id())
.cloned()
.ok_or_else(|| OptimizeError::NumericalError {
message: "hvp: no HVP gradient for x".to_string(),
})
}
/// Evaluates `f` at `x` and returns `(f(x), H·v)` using reverse-over-reverse autograd.
///
/// `x` is wrapped in a gradient-tracking `Var`, `f` builds the loss from it, and the
/// loss is read back as an `f64`. The HVP is then delegated to
/// [`hvp_reverse_over_reverse`].
///
/// # Errors
/// `OptimizeError::NumericalError` if `f` fails, the loss cannot be read as an `f64`
/// scalar, or the HVP computation fails.
pub fn hvp_from_fn<R, C, F>(
    client: &C,
    f: F,
    x: &Tensor<R>,
    v: &Tensor<R>,
) -> OptimizeResult<(f64, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
    F: Fn(&Var<R>, &C) -> Result<Var<R>>,
{
    // Leaf variable with gradient tracking, so derivatives can be taken with respect to it.
    let leaf = Var::new(x.clone(), true);
    let objective = f(&leaf, client).map_err(|err| OptimizeError::NumericalError {
        message: format!("hvp_from_fn: function evaluation failed - {}", err),
    })?;
    let value = objective
        .tensor()
        .item::<f64>()
        .map_err(|err| OptimizeError::NumericalError {
            message: format!("hvp_from_fn: scalar extraction - {}", err),
        })?;
    hvp_reverse_over_reverse(client, &objective, &leaf, v).map(|hvp| (value, hvp))
}
pub fn gradient_from_fn<R, C, F>(
client: &C,
f: F,
x: &Tensor<R>,
) -> OptimizeResult<(f64, Tensor<R>)>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
F: Fn(&Var<R>, &C) -> Result<Var<R>>,
{
let x_var = Var::new(x.clone(), true);
let loss = f(&x_var, client).map_err(|e| OptimizeError::NumericalError {
message: format!("gradient_from_fn: function evaluation failed - {}", e),
})?;
let loss_value: f64 =
loss.tensor()
.item::<f64>()
.map_err(|e| OptimizeError::NumericalError {
message: format!("gradient_from_fn: scalar extraction - {}", e),
})?;
let grads = backward(&loss, client).map_err(|e| OptimizeError::NumericalError {
message: format!("gradient_from_fn: backward failed - {}", e),
})?;
let grad = grads
.get(x_var.id())
.cloned()
.ok_or_else(|| OptimizeError::NumericalError {
message: "gradient_from_fn: no gradient for x".to_string(),
})?;
Ok((loss_value, grad))
}
#[cfg(test)]
mod hvp_tests {
    use super::*;
    use numr::autograd::var_sum;
    use numr::runtime::Runtime;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
    use std::cmp::Ordering;

    /// Absolute tolerance for the floating-point assertions in this module.
    const TOL: f64 = 1e-10;

    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        (device, client)
    }

    /// Autograd objective `f(x) = sum_i x_i^2`: gradient `2x`, Hessian `2I`.
    fn sum_of_squares(x_var: &Var<CpuRuntime>, c: &CpuClient) -> Result<Var<CpuRuntime>> {
        let x_sq = var_mul(x_var, x_var, c)?;
        var_sum(&x_sq, &[0], false, c)
    }

    /// The same objective on plain tensors, for the line-search routines.
    fn sum_of_squares_tensor(x: &Tensor<CpuRuntime>) -> Result<f64> {
        let vals: Vec<f64> = x.to_vec();
        Ok(vals.iter().map(|v| v * v).sum())
    }

    /// Asserts element-wise agreement within `TOL`.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < TOL, "expected {}, got {}", e, a);
        }
    }

    #[test]
    fn test_hvp_quadratic() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);
        let v = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0, 1.0], &[3], &device);
        let (fx, hvp) = hvp_from_fn(&client, sum_of_squares, &x, &v).unwrap();
        assert!((fx - 14.0).abs() < TOL);
        let hvp_vals: Vec<f64> = hvp.to_vec();
        assert_close(&hvp_vals, &[2.0, 2.0, 2.0]);
    }

    #[test]
    fn test_hvp_with_direction() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);
        let v = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 0.0, 0.0], &[3], &device);
        let (_, hvp) = hvp_from_fn(&client, sum_of_squares, &x, &v).unwrap();
        let hvp_vals: Vec<f64> = hvp.to_vec();
        assert_close(&hvp_vals, &[2.0, 0.0, 0.0]);
    }

    #[test]
    fn test_gradient_from_fn() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);
        let (fx, grad) = gradient_from_fn(&client, sum_of_squares, &x).unwrap();
        assert!((fx - 14.0).abs() < TOL);
        let grad_vals: Vec<f64> = grad.to_vec();
        assert_close(&grad_vals, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_backtracking_line_search_accepts_armijo_step() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0], &[2], &device);
        let grad = Tensor::<CpuRuntime>::from_slice(&[2.0f64, 4.0], &[2], &device);
        let p = Tensor::<CpuRuntime>::from_slice(&[-2.0f64, -4.0], &[2], &device);
        let (x_new, fx_new, nfev) =
            backtracking_line_search_tensor(&client, &sum_of_squares_tensor, &x, &p, 5.0, &grad)
                .unwrap();
        // alpha = 1 overshoots to (-1, -2) with f = 5, which is rejected.
        // alpha = 0.5 lands exactly on the minimum.
        assert_eq!(nfev, 2);
        assert!(fx_new.abs() < TOL);
        let x_vals: Vec<f64> = x_new.to_vec();
        assert_close(&x_vals, &[0.0, 0.0]);
    }

    #[test]
    fn test_powell_line_search_improves_along_descent_direction() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f64], &[1], &device);
        let direction = Tensor::<CpuRuntime>::from_slice(&[-1.0f64], &[1], &device);
        let (x_best, fx_best, nfev) =
            line_search_tensor(&client, &sum_of_squares_tensor, &x, &direction, 1.0).unwrap();
        assert!(fx_best < 1.0);
        assert!((1..=20).contains(&nfev));
        // The returned point and value must belong together.
        let x_vals: Vec<f64> = x_best.to_vec();
        assert!((x_vals[0] * x_vals[0] - fx_best).abs() < TOL);
    }

    #[test]
    fn test_compare_f64_nan_safe_orders_nan_last() {
        assert_eq!(compare_f64_nan_safe(1.0, 2.0), Ordering::Less);
        assert_eq!(compare_f64_nan_safe(f64::NAN, 1.0), Ordering::Greater);
        assert_eq!(compare_f64_nan_safe(1.0, f64::NAN), Ordering::Less);
        assert_eq!(compare_f64_nan_safe(f64::NAN, f64::NAN), Ordering::Equal);
        let mut vals = vec![3.0, f64::NAN, -1.0];
        vals.sort_by(|a, b| compare_f64_nan_safe(*a, *b));
        assert_eq!(&vals[..2], &[-1.0, 3.0]);
        assert!(vals[2].is_nan());
    }
}