use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Sum of the squares of every element of `x`, copied back to the host as `f64`.
///
/// The reduction runs over *all* dimensions of `x * x`. For a 1-D tensor this is
/// the usual squared Euclidean norm; for higher ranks it is the squared
/// Frobenius norm.
///
/// # Errors
/// Propagates any error returned by `client.mul` or `client.sum`.
///
/// # Panics
/// Panics if the reduced tensor copies back to an empty vector. This assumes
/// `client.sum` over all dims with `keepdim = false` yields exactly one element
/// (TODO confirm for rank-0 inputs).
fn sum_squared<R, C>(client: &C, x: &Tensor<R>) -> Result<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let x_sq = client.mul(x, x)?;
    // Reduce over every axis. Summing only dim 0 leaves one partial sum per
    // trailing index when the rank is >= 2, and `[0]` would then silently
    // return just the first of those.
    let all_dims: Vec<usize> = (0..x_sq.shape().len()).collect();
    let sum = client.sum(&x_sq, &all_dims, false)?;
    let sum_val: Vec<f64> = sum.to_vec();
    Ok(sum_val[0])
}
/// Square root of `sum_squared(client, x)`, i.e. the L2 norm for a 1-D `x`.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn tensor_norm<R, C>(client: &C, x: &Tensor<R>) -> Result<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let squared_sum = sum_squared(client, x)?;
    Ok(f64::sqrt(squared_sum))
}
/// Inner product of `a` and `b`: the sum of every element of `a * b`, as `f64`.
///
/// For 1-D inputs this is the ordinary dot product. For higher-rank inputs it
/// is the Frobenius inner product. The shapes must be acceptable to `client.mul`.
///
/// # Errors
/// Propagates any error returned by `client.mul` or `client.sum`.
///
/// # Panics
/// Panics if the reduced tensor copies back to an empty vector. This assumes a
/// full reduction with `keepdim = false` yields exactly one element
/// (TODO confirm for rank-0 inputs).
pub fn tensor_dot<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> Result<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let prod = client.mul(a, b)?;
    // Reduce over every axis of the product. Reducing only dim 0 and then
    // reading `[0]` returns just the first column's partial sum for rank >= 2.
    let all_dims: Vec<usize> = (0..prod.shape().len()).collect();
    let sum = client.sum(&prod, &all_dims, false)?;
    let sum_val: Vec<f64> = sum.to_vec();
    Ok(sum_val[0])
}
/// Forward-difference approximation of the gradient of a scalar function `f` at `x`.
///
/// Component `i` is `(f(x + eps * e_i) - fx) / eps`, where `e_i` is row `i` of
/// an `n x n` identity. `f` is called once per component.
///
/// # Arguments
/// * `client` - runtime client used for every tensor operation.
/// * `f` - scalar function being differentiated.
/// * `x` - evaluation point. `n` is taken from `x.shape()[0]`; presumably `x`
///   is 1-D (TODO confirm).
/// * `fx` - the value `f(x)`, supplied by the caller so it is not recomputed.
/// * `eps` - step size applied to every coordinate.
///
/// # Returns
/// A tensor of shape `[n]` built with `DType::F64`. The perturbations are also
/// built as F64, so `x` is presumably expected to be F64 (NOTE(review): confirm).
///
/// # Errors
/// Propagates errors from the tensor operations and from `f`.
///
/// # Panics
/// Panics if `x` has rank 0, because `x.shape()[0]` is out of bounds.
pub fn finite_difference_gradient<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    fx: f64,
    eps: f64,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<f64>,
{
    let n = x.shape()[0];
    if n == 0 {
        // With no variables the loop produces no components. `cat` would then
        // receive an empty list with no shape or dtype to take from, so return
        // the empty gradient directly.
        return client.fill(&[0], 0.0, DType::F64);
    }
    // Row i of eps * I is the perturbation for coordinate i. Note this
    // materializes an n x n tensor.
    let identity = client.eye(n, None, DType::F64)?;
    let eps_identity = client.mul_scalar(&identity, eps)?;
    let mut grad_components: Vec<Tensor<R>> = Vec::with_capacity(n);
    for i in 0..n {
        let delta = eps_identity.narrow(0, i, 1)?.contiguous()?.reshape(&[n])?;
        let x_plus = client.add(x, &delta)?;
        let fx_plus = f(&x_plus)?;
        let grad_i = (fx_plus - fx) / eps;
        grad_components.push(client.fill(&[1], grad_i, DType::F64)?);
    }
    let refs: Vec<&Tensor<R>> = grad_components.iter().collect();
    client.cat(&refs, 0)
}
/// Element-wise `a + b`, delegated to `TensorOps::add` on `client`.
///
/// # Errors
/// Returns whatever error `TensorOps::add` reports (e.g. incompatible inputs).
pub fn tensor_add<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    <C as TensorOps<R>>::add(client, a, b)
}
/// Element-wise `a - b`, delegated to `TensorOps::sub` on `client`.
///
/// # Errors
/// Returns whatever error `TensorOps::sub` reports (e.g. incompatible inputs).
pub fn tensor_sub<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    <C as TensorOps<R>>::sub(client, a, b)
}
/// Multiplies every element of `x` by the scalar `s`, via `ScalarOps::mul_scalar`.
///
/// # Errors
/// Returns whatever error `ScalarOps::mul_scalar` reports.
pub fn tensor_scale<R, C>(client: &C, x: &Tensor<R>, s: f64) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + RuntimeClient<R>,
{
    <C as ScalarOps<R>>::mul_scalar(client, x, s)
}
/// Small magnitude threshold, `1e-12`.
///
/// NOTE(review): not referenced in this part of the file. Presumably quantities
/// below this are treated as numerically singular by callers; confirm at the
/// use sites.
pub const SINGULAR_THRESHOLD: f64 = 1.0e-12;
/// Cost of `x`, defined here as `sum_squared(client, x)`.
///
/// No `1/2` factor is applied.
///
/// # Errors
/// Propagates any error from the underlying tensor operations.
pub fn compute_cost<R, C>(client: &C, x: &Tensor<R>) -> Result<f64>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let cost = sum_squared(client, x)?;
    Ok(cost)
}
/// Forward-difference approximation of the Jacobian of a vector function `f` at `x`.
///
/// Column `j` is `(f(x + eps * e_j) - fx) / eps`, where `e_j` is row `j` of an
/// `n x n` identity. The columns are concatenated along dim 1. `f` is called
/// once per column.
///
/// # Arguments
/// * `client` - runtime client used for every tensor operation.
/// * `f` - function being differentiated.
/// * `x` - evaluation point; `x + delta` must be valid for a length-`n` delta.
/// * `fx` - the value `f(x)`, supplied by the caller so it is not recomputed.
/// * `_m` - ignored; the row count comes from `fx` / `f`'s output
///   (presumably the number of outputs, TODO confirm).
/// * `n` - number of input coordinates to perturb.
/// * `eps` - step size applied to every coordinate.
///
/// # Returns
/// A tensor whose columns are the unsqueezed difference quotients. For a 1-D
/// `fx` of length `m` its shape is `[m, n]`. The perturbations use
/// `DType::F64` (NOTE(review): inputs are presumably F64; confirm).
///
/// # Errors
/// Propagates errors from the tensor operations and from `f`.
pub fn finite_difference_jacobian<R, C, F>(
    client: &C,
    f: &F,
    x: &Tensor<R>,
    fx: &Tensor<R>,
    _m: usize,
    n: usize,
    eps: f64,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    F: Fn(&Tensor<R>) -> Result<Tensor<R>>,
{
    if n == 0 {
        // No columns would be produced, and `cat` on an empty list has no
        // shape or dtype to take from. Return an empty `[rows, 0]` matrix,
        // sizing rows from `fx` to match what the non-empty path would yield.
        let rows = fx.shape()[0];
        return client.fill(&[rows, 0], 0.0, DType::F64);
    }
    // Row j of eps * I is the perturbation for coordinate j. Note this
    // materializes an n x n tensor.
    let identity = client.eye(n, None, DType::F64)?;
    let eps_identity = client.mul_scalar(&identity, eps)?;
    let mut jac_columns: Vec<Tensor<R>> = Vec::with_capacity(n);
    for j in 0..n {
        let delta = eps_identity.narrow(0, j, 1)?.contiguous()?.reshape(&[n])?;
        let x_plus = client.add(x, &delta)?;
        let fx_plus = f(&x_plus)?;
        let diff = client.sub(&fx_plus, fx)?;
        let jac_col = client.mul_scalar(&diff, 1.0 / eps)?;
        // Add a trailing axis so columns concatenate side by side along dim 1.
        jac_columns.push(jac_col.unsqueeze(1)?);
    }
    let refs: Vec<&Tensor<R>> = jac_columns.iter().collect();
    client.cat(&refs, 1)
}