use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Float;
/// Computes the numerical gradient of `f` along one or more axes using finite differences.
///
/// Interior samples use central differences `(f[i+1] - f[i-1]) / (2h)`. The two boundary
/// samples of each 1-D lane use one-sided differences whose accuracy is set by `edge_order`:
///
/// * `1` — first-order forward/backward differences.
/// * `2` — second-order one-sided differences. When the axis has fewer than three samples,
///   this silently falls back to first order.
///
/// An axis with fewer than two samples yields an all-zero gradient.
///
/// # Arguments
///
/// * `f` - Input array.
/// * `varargs` - Sample spacing:
///   * `None` means unit spacing on every axis.
///   * `Uniform(h)` uses `h` on every axis.
///   * `PerAxis(v)` must contain exactly `f.ndim()` entries. `v` is indexed by axis
///     number, not by position within `axis`.
/// * `axis` - Axes to differentiate along, in output order. `None` means every axis.
/// * `edge_order` - Boundary accuracy; must be `1` or `2`.
///
/// # Returns
///
/// One array per requested axis, each created with the shape of `f`.
///
/// # Errors
///
/// * `NumRs2Error::ValueError` if `edge_order` is neither 1 nor 2.
/// * `NumRs2Error::DimensionMismatch` if an axis is out of bounds, or if a `PerAxis`
///   spacing list does not have `f.ndim()` entries.
/// * Any error returned by element access (`get`/`set`).
pub fn gradient<T>(
    f: &Array<T>,
    varargs: Option<GradientSpacing<T>>,
    axis: Option<Vec<usize>>,
    edge_order: usize,
) -> Result<Vec<Array<T>>>
where
    T: Float + Clone + 'static,
{
    let ndim = f.ndim();
    let shape = f.shape();

    if edge_order != 1 && edge_order != 2 {
        return Err(NumRs2Error::ValueError(
            "edge_order must be 1 or 2".to_string(),
        ));
    }

    let axes = match axis {
        Some(a) => {
            if let Some(&bad) = a.iter().find(|&&ax| ax >= ndim) {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    bad, ndim
                )));
            }
            a
        }
        None => (0..ndim).collect(),
    };

    // Spacings are looked up by axis number below, so a per-axis list must cover every
    // dimension even when only a subset of axes is requested.
    let spacings = match varargs {
        None => vec![T::one(); ndim],
        Some(GradientSpacing::Uniform(h)) => vec![h; ndim],
        Some(GradientSpacing::PerAxis(spacings)) => {
            if spacings.len() != ndim {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "spacing array length {} doesn't match array dimensions {}",
                    spacings.len(),
                    ndim
                )));
            }
            spacings
        }
    };

    let mut results = Vec::with_capacity(axes.len());
    for &ax in &axes {
        let mut grad = Array::zeros(&shape);
        let n = shape[ax];

        // With fewer than two samples there is no neighbour to difference against,
        // so the gradient along this axis stays all zeros.
        if n < 2 {
            results.push(grad);
            continue;
        }
        let h = spacings[ax];

        // Row-major strides over every dimension except `ax`. They decode a flat
        // "perpendicular" counter into the fixed coordinates of one 1-D lane along `ax`.
        // They are computed once per axis instead of once per lane and dimension.
        let mut perp_strides = vec![0usize; ndim];
        let mut total_perp = 1usize;
        for i in (0..ndim).rev().filter(|&i| i != ax) {
            perp_strides[i] = total_perp;
            total_perp *= shape[i];
        }

        let mut indices = vec![0usize; ndim];
        let mut lane = Vec::with_capacity(n);
        let mut lane_grad = vec![T::zero(); n];
        for perp_idx in 0..total_perp {
            let mut rem = perp_idx;
            for i in (0..ndim).filter(|&i| i != ax) {
                indices[i] = rem / perp_strides[i];
                rem %= perp_strides[i];
            }

            // Read the lane once, so each sample is fetched a single time.
            lane.clear();
            for k in 0..n {
                indices[ax] = k;
                lane.push(f.get(&indices)?);
            }

            lane_gradient(&lane, h, edge_order, &mut lane_grad);

            for (k, &d) in lane_grad.iter().enumerate() {
                indices[ax] = k;
                grad.set(&indices, d)?;
            }
        }
        results.push(grad);
    }
    Ok(results)
}

/// Writes the finite-difference derivative of one 1-D lane (sample spacing `h`) into `out`.
///
/// Requires `lane.len() >= 2` and `out.len() == lane.len()`.
/// * Interior points use central differences.
/// * The endpoints use one-sided differences of order `edge_order`. Second order needs
///   at least three samples; with fewer, first order is used.
fn lane_gradient<T: Float>(lane: &[T], h: T, edge_order: usize, out: &mut [T]) {
    let n = lane.len();
    debug_assert!(n >= 2 && out.len() == n);

    let half = T::from(0.5).expect("0.5 should be representable");
    let two = T::from(2.0).expect("2.0 should be representable");
    let three_halves = T::from(1.5).expect("1.5 should be representable");

    if edge_order == 1 || n < 3 {
        out[0] = (lane[1] - lane[0]) / h;
        out[n - 1] = (lane[n - 1] - lane[n - 2]) / h;
    } else {
        // Second-order one-sided stencils: (-3f0 + 4f1 - f2) / 2h and its mirror image.
        out[0] = (-lane[2] * half + lane[1] * two - lane[0] * three_halves) / h;
        out[n - 1] = (lane[n - 3] * half - lane[n - 2] * two + lane[n - 1] * three_halves) / h;
    }

    // Central difference for each interior point i, using the window [i-1, i, i+1].
    let two_h = h * two;
    for (slot, w) in out[1..n - 1].iter_mut().zip(lane.windows(3)) {
        *slot = (w[2] - w[0]) / two_h;
    }
}
/// Sample spacing passed to [`gradient`].
///
/// * `Uniform(h)` — the same spacing `h` along every axis.
/// * `PerAxis(v)` — `v[k]` is the spacing along axis `k`. `v` must have one entry per
///   array dimension.
#[derive(Debug, Clone, PartialEq)]
pub enum GradientSpacing<T> {
    /// One spacing shared by all axes.
    Uniform(T),
    /// One spacing per array dimension, indexed by axis number.
    PerAxis(Vec<T>),
}
/// Returns a boolean array that is `true` wherever an element's sign bit is set.
///
/// This is based on `Float::is_sign_negative`, so `-0.0` and negative infinity report
/// `true`. For the primitive float types, NaNs with the sign bit set also report `true`.
pub fn signbit<T: Float + Clone>(array: &Array<T>) -> Array<bool> {
    array.map(T::is_sign_negative)
}
/// Returns the element-wise reciprocal `1 / x` of `array`.
///
/// Division follows `T`'s own semantics. For IEEE floats, a zero maps to an infinity of
/// matching sign.
pub fn reciprocal<T: Float + Clone>(array: &Array<T>) -> Array<T> {
    array.map(|denominator| {
        let numerator = T::one();
        numerator / denominator
    })
}
pub fn positive<T: Clone>(array: &Array<T>) -> Array<T> {
array.clone()
}
/// Unary minus: returns a new array in which every element is negated with `T`'s
/// `Neg` implementation.
pub fn negative<T: Clone + std::ops::Neg<Output = T>>(array: &Array<T>) -> Array<T> {
    array.map(T::neg)
}
/// Rounds each element to the nearest integer, resolving exact halfway cases to the even
/// neighbour (round-half-to-even).
///
/// Examples: `0.5 -> 0`, `1.5 -> 2`, `2.5 -> 2`, `-2.5 -> -2`, `-0.5 -> -0.0`.
///
/// This is the IEEE-754 default rounding mode that `rint` conventionally denotes.
/// [`Float::round`] alone is not enough because it rounds ties away from zero
/// (`2.5 -> 3`). NaN and infinities pass through unchanged.
pub fn rint<T: Float + Clone>(array: &Array<T>) -> Array<T> {
    /// Round-half-to-even for a single value.
    fn round_half_even<F: Float>(x: F) -> F {
        let rounded = x.round();
        let truncated = x.trunc();
        let half = F::from(0.5).expect("0.5 should be representable");
        let two = F::from(2.0).expect("2.0 should be representable");
        // On an exact tie the two candidates are `trunc(x)` and `round(x)`, the latter
        // being the one away from zero. Keep whichever is even. Returning `trunc` also
        // preserves the sign of zero (e.g. -0.5 -> -0.0).
        // For NaN/inf, `x - truncated` is NaN, so the comparison is false.
        if (x - truncated).abs() == half && rounded % two != F::zero() {
            truncated
        } else {
            rounded
        }
    }
    array.map(round_half_even::<T>)
}
/// Rounds each element toward zero by truncating its fractional part.
///
/// Examples: `2.7 -> 2.0`, `-2.7 -> -2.0`.
pub fn fix<T: Float + Clone>(array: &Array<T>) -> Array<T> {
    array.map(T::trunc)
}
/// Element-wise maximum of two same-shaped arrays that ignores NaN.
///
/// If exactly one operand is NaN, the other operand is returned. If both are NaN, the
/// result is NaN. There is no broadcasting: the shapes must match exactly.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` if `x1` and `x2` have different shapes.
pub fn fmax<T: Float + Clone>(x1: &Array<T>, x2: &Array<T>) -> Result<Array<T>> {
    nan_ignoring_elementwise(x1, x2, T::max)
}

/// Element-wise minimum of two same-shaped arrays that ignores NaN.
///
/// If exactly one operand is NaN, the other operand is returned. If both are NaN, the
/// result is NaN. There is no broadcasting: the shapes must match exactly.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` if `x1` and `x2` have different shapes.
pub fn fmin<T: Float + Clone>(x1: &Array<T>, x2: &Array<T>) -> Result<Array<T>> {
    nan_ignoring_elementwise(x1, x2, T::min)
}

/// Combines `x1` and `x2` element-wise with `combine`, treating NaN as missing data.
///
/// A NaN operand is replaced by the other operand, and `combine` is only called when
/// neither is NaN. The result is reshaped to `x1`'s shape.
fn nan_ignoring_elementwise<T, F>(x1: &Array<T>, x2: &Array<T>, combine: F) -> Result<Array<T>>
where
    T: Float + Clone,
    F: Fn(T, T) -> T,
{
    let shape = x1.shape();
    if shape != x2.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: shape,
            actual: x2.shape(),
        });
    }
    let combined: Vec<T> = x1
        .to_vec()
        .into_iter()
        .zip(x2.to_vec())
        .map(|(a, b)| {
            if a.is_nan() {
                b
            } else if b.is_nan() {
                a
            } else {
                combine(a, b)
            }
        })
        .collect();
    Ok(Array::from_vec(combined).reshape(&shape))
}