use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, Zero};
use std::fmt::Debug;
/// Returns `true` when `a` and `b` have the same shape and every element pair satisfies
/// [`isclose`] with a relative tolerance of `1e-7` and an absolute tolerance of zero.
///
/// This is shorthand for [`allclose_with_tol`] with those default tolerances.
///
/// # Panics
///
/// Panics if `T::from(1e-7)` cannot produce a value of type `T`.
pub fn allclose<T>(a: &Array<T>, b: &Array<T>) -> bool
where
    T: Clone + Float + Debug,
{
    let default_rtol = T::from(1e-7).expect("Failed to convert 1e-7 to type T");
    let default_atol = T::zero();
    allclose_with_tol(a, b, default_rtol, default_atol)
}
/// Returns `true` when `a` and `b` have identical shapes and every corresponding pair of
/// elements is close according to [`isclose`] with the given `rtol` and `atol`.
///
/// Arrays of different shape are never close. Comparison stops at the first pair that is
/// not close.
pub fn allclose_with_tol<T>(a: &Array<T>, b: &Array<T>, rtol: T, atol: T) -> bool
where
    T: Clone + Float + Debug,
{
    if a.shape() == b.shape() {
        let lhs = a.to_vec();
        let rhs = b.to_vec();
        lhs.iter()
            .zip(rhs.iter())
            .all(|(&x, &y)| isclose(x, y, rtol, atol))
    } else {
        false
    }
}
/// Returns `true` when `a` and `b` are equal within `atol + rtol * |b|`.
///
/// The tolerance is asymmetric: it scales with `|b|`, not with `|a|`.
///
/// Special values:
/// * two NaNs are considered close (any other comparison involving NaN is not);
/// * an infinity is close only to an identical infinity (`inf` vs `inf`, `-inf` vs `-inf`),
///   never to a finite value or to the opposite infinity.
pub fn isclose<T>(a: T, b: T, rtol: T, atol: T) -> bool
where
    T: Clone + Float + Debug,
{
    // Exact equality also covers matching infinities of the same sign.
    if a == b {
        return true;
    }
    // NaN never compares equal, so a NaN pair has to be accepted explicitly.
    if a.is_nan() && b.is_nan() {
        return true;
    }
    // Any remaining infinity cannot be close. Without this guard, an infinite `b` makes
    // `rtol * |b|` infinite, and `inf <= inf` would wrongly accept e.g. (1.0, inf) or (-inf, inf).
    if a.is_infinite() || b.is_infinite() {
        return false;
    }
    let tol = atol + rtol * b.abs();
    (a - b).abs() <= tol
}
/// Returns `true` when `a` and `b` have the same shape and identical elements.
///
/// `equal_nan` defaults to `false`. When it is `Some(true)` and the element type is `f32` or
/// `f64`, NaN elements at the same position are treated as equal. For every other element
/// type the flag has no effect and plain `PartialEq` comparison is used.
pub fn array_equal<T>(a: &Array<T>, b: &Array<T>, equal_nan: Option<bool>) -> bool
where
    T: Clone + PartialEq + Debug + 'static,
{
    if a.shape() != b.shape() {
        return false;
    }
    // `None` means "NaN handling not requested or not applicable to this element type".
    let nan_verdict = match equal_nan {
        Some(true) => array_equal_with_nan_handling(a, b),
        _ => None,
    };
    nan_verdict.unwrap_or_else(|| a.to_vec() == b.to_vec())
}
fn array_equal_with_nan_handling<T>(a: &Array<T>, b: &Array<T>) -> Option<bool>
where
T: Clone + PartialEq + Debug + 'static,
{
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_f32 = unsafe { &*(a as *const Array<T> as *const Array<f32>) };
let b_f32 = unsafe { &*(b as *const Array<T> as *const Array<f32>) };
let a_vec = a_f32.to_vec();
let b_vec = b_f32.to_vec();
if a_vec.len() != b_vec.len() {
return Some(false);
}
for i in 0..a_vec.len() {
if a_vec[i] != b_vec[i] {
if a_vec[i].is_nan() && b_vec[i].is_nan() {
continue;
}
return Some(false);
}
}
return Some(true);
}
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let a_f64 = unsafe { &*(a as *const Array<T> as *const Array<f64>) };
let b_f64 = unsafe { &*(b as *const Array<T> as *const Array<f64>) };
let a_vec = a_f64.to_vec();
let b_vec = b_f64.to_vec();
if a_vec.len() != b_vec.len() {
return Some(false);
}
for i in 0..a_vec.len() {
if a_vec[i] != b_vec[i] {
if a_vec[i].is_nan() && b_vec[i].is_nan() {
continue;
}
return Some(false);
}
}
return Some(true);
}
None
}
/// Compares two arrays for equality according to `options`.
///
/// Arrays of identical shape are compared directly. When the shapes differ, the arrays are
/// compared only if `options.allow_broadcasting` is set and
/// `crate::stride_tricks::broadcast_arrays` succeeds. Otherwise the result is `false`.
pub fn array_compare<T>(a: &Array<T>, b: &Array<T>, options: &ArrayCompareOptions) -> bool
where
    T: Clone + PartialEq + Debug + 'static,
{
    let same_shape = a.shape() == b.shape();
    if same_shape {
        return array_compare_equal_shapes(a, b, options);
    }
    if !options.allow_broadcasting {
        return false;
    }
    match crate::stride_tricks::broadcast_arrays(&[a, b]) {
        Ok(broadcast) => array_compare_equal_shapes(&broadcast[0], &broadcast[1], options),
        Err(_) => false,
    }
}
/// Element-wise equality of two same-shaped arrays, honouring `options`.
///
/// Positions listed in `options.ignore_indices` are skipped. Out-of-range indices are silently
/// dropped. When `options.equal_nan` is set and the element type is `f32`/`f64`, the NaN-aware
/// comparison decides the result. Otherwise elements are compared with `PartialEq`.
fn array_compare_equal_shapes<T>(a: &Array<T>, b: &Array<T>, options: &ArrayCompareOptions) -> bool
where
    T: Clone + PartialEq + Debug + 'static,
{
    debug_assert_eq!(a.shape(), b.shape(), "Arrays must have the same shape");
    let lhs = a.to_vec();
    let rhs = b.to_vec();

    // Flat positions (into the `to_vec` element order) that must not influence the result.
    let mut skip = vec![false; lhs.len()];
    for &idx in options.ignore_indices.iter().flatten() {
        if let Some(slot) = skip.get_mut(idx) {
            *slot = true;
        }
    }

    if options.equal_nan {
        if let Some(verdict) = array_compare_with_nan_handling(a, b, &skip) {
            return verdict;
        }
    }

    (0..lhs.len()).all(|i| skip[i] || lhs[i] == rhs[i])
}
fn array_compare_with_nan_handling<T>(
a: &Array<T>,
b: &Array<T>,
ignore_mask: &[bool],
) -> Option<bool>
where
T: Clone + PartialEq + Debug + 'static,
{
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_f32 = unsafe { &*(a as *const Array<T> as *const Array<f32>) };
let b_f32 = unsafe { &*(b as *const Array<T> as *const Array<f32>) };
let a_vec = a_f32.to_vec();
let b_vec = b_f32.to_vec();
for i in 0..a_vec.len() {
if ignore_mask[i] {
continue;
}
if a_vec[i] != b_vec[i] {
if a_vec[i].is_nan() && b_vec[i].is_nan() {
continue;
}
return Some(false);
}
}
return Some(true);
}
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let a_f64 = unsafe { &*(a as *const Array<T> as *const Array<f64>) };
let b_f64 = unsafe { &*(b as *const Array<T> as *const Array<f64>) };
let a_vec = a_f64.to_vec();
let b_vec = b_f64.to_vec();
for i in 0..a_vec.len() {
if ignore_mask[i] {
continue;
}
if a_vec[i] != b_vec[i] {
if a_vec[i].is_nan() && b_vec[i].is_nan() {
continue;
}
return Some(false);
}
}
return Some(true);
}
None
}
/// Options consumed by `array_compare`.
///
/// `Default` yields all flags `false` and all optional fields `None`, i.e. an exact,
/// same-shape-only comparison.
#[derive(Debug, Clone, Default)]
pub struct ArrayCompareOptions {
    /// When `true`, NaN elements at the same position compare equal. This only has an effect
    /// for `f32`/`f64` element types; other types use plain `PartialEq`.
    pub equal_nan: bool,
    /// When `true` and the shapes differ, both arrays are broadcast to a common shape via
    /// `crate::stride_tricks::broadcast_arrays` before comparing. If broadcasting fails, the
    /// arrays compare unequal.
    pub allow_broadcasting: bool,
    /// Flat positions (into the element order returned by `to_vec`, after any broadcasting)
    /// to skip during comparison. Out-of-range indices are ignored.
    pub ignore_indices: Option<Vec<usize>>,
    /// Relative tolerance.
    /// NOTE(review): not read by `array_compare` or its helpers in this module, so
    /// comparisons are exact regardless of this value. Confirm whether other code uses it.
    pub rtol: Option<f64>,
    /// Absolute tolerance.
    /// NOTE(review): like `rtol`, not read by `array_compare` or its helpers in this module.
    pub atol: Option<f64>,
}
/// Returns `true` if every element of `a` converts to `true`.
///
/// An empty array yields `true`. Evaluation stops at the first `false` element.
pub fn all<T>(a: &Array<T>) -> bool
where
    T: Clone + PartialEq + Debug,
    bool: From<T>,
{
    for value in a.to_vec() {
        if !bool::from(value) {
            return false;
        }
    }
    true
}
/// Returns `true` if at least one element of `a` converts to `true`.
///
/// An empty array yields `false`. Evaluation stops at the first `true` element.
pub fn any<T>(a: &Array<T>) -> bool
where
    T: Clone + PartialEq + Debug,
    bool: From<T>,
{
    a.to_vec().into_iter().map(bool::from).any(|flag| flag)
}
/// Broadcasts `a` and `b` to a common shape and applies `predicate` to each element pair.
///
/// This is the shared engine behind the element-wise comparison functions below. An operand
/// is only broadcast when its shape differs from the common shape. Otherwise its elements are
/// read directly, without cloning the whole array first.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` (with `expected` = shape of `a` and `actual` = shape
/// of `b`) when the shapes are not broadcast-compatible. Also propagates any error from
/// `broadcast_to`.
fn compare_broadcast<T, F>(a: &Array<T>, b: &Array<T>, predicate: F) -> Result<Array<bool>>
where
    T: Clone + PartialEq + Debug,
    F: Fn(&T, &T) -> bool,
{
    let broadcast_shape = Array::<T>::broadcast_shape(&a.shape(), &b.shape()).map_err(|_| {
        NumRs2Error::ShapeMismatch {
            expected: a.shape(),
            actual: b.shape(),
        }
    })?;
    let a_data = if a.shape() != broadcast_shape {
        a.broadcast_to(&broadcast_shape)?.to_vec()
    } else {
        a.to_vec()
    };
    let b_data = if b.shape() != broadcast_shape {
        b.broadcast_to(&broadcast_shape)?.to_vec()
    } else {
        b.to_vec()
    };
    let result: Vec<bool> = a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| predicate(x, y))
        .collect();
    Ok(Array::from_vec(result).reshape(&broadcast_shape))
}

/// Element-wise `a > b` with broadcasting.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the shapes are not broadcast-compatible.
pub fn greater<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<bool>>
where
    T: Clone + PartialOrd + Debug,
{
    compare_broadcast(a, b, |x, y| x > y)
}

/// Element-wise `a >= b` with broadcasting.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the shapes are not broadcast-compatible.
pub fn greater_equal<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<bool>>
where
    T: Clone + PartialOrd + Debug,
{
    compare_broadcast(a, b, |x, y| x >= y)
}

/// Element-wise `a < b` with broadcasting.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the shapes are not broadcast-compatible.
pub fn less<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<bool>>
where
    T: Clone + PartialOrd + Debug,
{
    compare_broadcast(a, b, |x, y| x < y)
}

/// Element-wise `a <= b` with broadcasting.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the shapes are not broadcast-compatible.
pub fn less_equal<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<bool>>
where
    T: Clone + PartialOrd + Debug,
{
    compare_broadcast(a, b, |x, y| x <= y)
}

/// Element-wise `a == b` with broadcasting.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the shapes are not broadcast-compatible.
pub fn equal<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<bool>>
where
    T: Clone + PartialEq + Debug,
{
    compare_broadcast(a, b, |x, y| x == y)
}

/// Element-wise `a != b` with broadcasting.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the shapes are not broadcast-compatible.
pub fn not_equal<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<bool>>
where
    T: Clone + PartialEq + Debug,
{
    compare_broadcast(a, b, |x, y| x != y)
}

/// Element-wise [`isclose`] with broadcasting, using tolerance `atol + rtol * |b|`.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the shapes are not broadcast-compatible.
pub fn isclose_array<T>(a: &Array<T>, b: &Array<T>, rtol: T, atol: T) -> Result<Array<bool>>
where
    T: Clone + Float + Debug,
{
    compare_broadcast(a, b, |x, y| isclose(*x, *y, rtol, atol))
}
/// Broadcasts two boolean arrays to a common shape and combines them element-wise with `op`.
///
/// Shared implementation of [`logical_and`], [`logical_or`] and [`logical_xor`]. Errors from
/// `broadcast_shape` and `broadcast_to` are propagated unchanged.
fn logical_broadcast<F>(x1: &Array<bool>, x2: &Array<bool>, op: F) -> Result<Array<bool>>
where
    F: Fn(bool, bool) -> bool,
{
    let broadcast_shape = Array::<bool>::broadcast_shape(&x1.shape(), &x2.shape())?;
    let lhs = x1.broadcast_to(&broadcast_shape)?.to_vec();
    let rhs = x2.broadcast_to(&broadcast_shape)?.to_vec();
    let result_data: Vec<bool> = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(&p, &q)| op(p, q))
        .collect();
    Ok(Array::from_vec(result_data).reshape(&broadcast_shape))
}

/// Element-wise logical AND of two boolean arrays, with broadcasting.
///
/// # Errors
///
/// Returns an error when the shapes are not broadcast-compatible.
pub fn logical_and(x1: &Array<bool>, x2: &Array<bool>) -> Result<Array<bool>> {
    logical_broadcast(x1, x2, |p, q| p && q)
}

/// Element-wise logical OR of two boolean arrays, with broadcasting.
///
/// # Errors
///
/// Returns an error when the shapes are not broadcast-compatible.
pub fn logical_or(x1: &Array<bool>, x2: &Array<bool>) -> Result<Array<bool>> {
    logical_broadcast(x1, x2, |p, q| p || q)
}

/// Element-wise logical NOT; the result has the same shape as `x`.
///
/// Always returns `Ok`. The `Result` return type is kept for API symmetry with the binary
/// logical functions.
pub fn logical_not(x: &Array<bool>) -> Result<Array<bool>> {
    let negated: Vec<bool> = x.to_vec().into_iter().map(|v| !v).collect();
    Ok(Array::from_vec(negated).reshape(&x.shape()))
}

/// Element-wise logical XOR of two boolean arrays, with broadcasting.
///
/// # Errors
///
/// Returns an error when the shapes are not broadcast-compatible.
pub fn logical_xor(x1: &Array<bool>, x2: &Array<bool>) -> Result<Array<bool>> {
    logical_broadcast(x1, x2, |p, q| p ^ q)
}
/// Counts the elements of `a` that differ from `T::zero()`.
///
/// * `axis = None`: counts over the whole array and returns a one-element array.
/// * `axis = Some(ax)`: counts along `ax`. The output shape is the input shape with `ax`
///   removed, or `[1]` when no dimensions remain.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` when `ax >= a.ndim()`.
pub fn count_nonzero<T>(a: &Array<T>, axis: Option<usize>) -> Result<Array<usize>>
where
    T: Clone + Zero + PartialEq,
{
    let zero = T::zero();
    let ax = match axis {
        None => {
            let total = a.to_vec().into_iter().filter(|v| *v != zero).count();
            return Ok(Array::from_vec(vec![total]));
        }
        Some(ax) => ax,
    };

    let ndim = a.ndim();
    if ax >= ndim {
        return Err(NumRs2Error::InvalidOperation(format!(
            "axis {} is out of bounds for array of dimension {}",
            ax, ndim
        )));
    }

    let dims = a.shape();
    let axis_len = dims[ax];
    let outer: usize = dims[..ax].iter().product();
    let inner: usize = dims[ax + 1..].iter().product();
    let mut out_shape: Vec<usize> = dims[..ax]
        .iter()
        .chain(dims[ax + 1..].iter())
        .copied()
        .collect();
    if out_shape.is_empty() {
        out_shape.push(1);
    }

    // The flat data is treated as `outer` blocks of `axis_len` rows, each `inner` elements
    // long. Every output slot accumulates the non-zero count down one column of a block.
    let data = a.to_vec();
    let mut counts = vec![0usize; outer * inner];
    for o in 0..outer {
        for j in 0..axis_len {
            let row_start = (o * axis_len + j) * inner;
            for (k, slot) in counts[o * inner..(o + 1) * inner].iter_mut().enumerate() {
                if data[row_start + k] != zero {
                    *slot += 1;
                }
            }
        }
    }
    Ok(Array::from_vec(counts).reshape(&out_shape))
}
/// Returns the flat indices (into the `to_vec` element order) of all elements that differ
/// from `T::zero()`, in ascending order.
///
/// Always returns `Ok`.
pub fn flatnonzero<T>(a: &Array<T>) -> Result<Array<usize>>
where
    T: Clone + Zero + PartialEq,
{
    let zero = T::zero();
    let mut positions = Vec::new();
    for (pos, value) in a.to_vec().into_iter().enumerate() {
        if value != zero {
            positions.push(pos);
        }
    }
    Ok(Array::from_vec(positions))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allclose() {
        let reference = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let near = Array::from_vec(vec![1.0000001, 2.0000002, 3.0000003]);
        let far = Array::from_vec(vec![1.001, 2.002, 3.003]);

        assert!(allclose(&reference, &near));
        assert!(!allclose(&reference, &far));
        // A looser relative tolerance (1%) accepts the larger drift.
        assert!(allclose_with_tol(&reference, &far, 1e-2, 0.0));
    }

    #[test]
    fn test_array_equal() {
        let base = Array::from_vec(vec![1, 2, 3]);
        let same = Array::from_vec(vec![1, 2, 3]);
        let differs = Array::from_vec(vec![1, 2, 4]);
        let reshaped = Array::from_vec(vec![1, 2, 3, 4]).reshape(&[2, 2]);

        assert!(array_equal(&base, &same, None));
        assert!(!array_equal(&base, &differs, None));
        // Different shapes must never compare equal.
        assert!(!array_equal(&base, &reshaped, None));
    }

    #[test]
    fn test_isclose() {
        let rtol = 1e-7;
        let atol = 0.0;

        assert!(isclose(1.0, 1.0000001, rtol, atol));
        assert!(!isclose(1.0, 1.001, rtol, atol));
        // Special values: NaN pairs and matching infinities count as close.
        assert!(isclose(f64::NAN, f64::NAN, rtol, atol));
        assert!(isclose(f64::INFINITY, f64::INFINITY, rtol, atol));
        assert!(!isclose(f64::INFINITY, 1.0, rtol, atol));
    }

    #[test]
    fn test_all_any() {
        // (values, expected all(), expected any())
        let cases = vec![
            (vec![true, true, true], true, true),
            (vec![true, false, true], false, true),
            (vec![false, false, false], false, false),
        ];
        for (values, expect_all, expect_any) in cases.iter() {
            let arr = Array::from_vec(values.clone());
            assert_eq!(all(&arr), *expect_all);
            assert_eq!(any(&arr), *expect_any);
        }
    }

    #[test]
    fn test_comparison_ops() {
        let lhs = Array::from_vec(vec![1, 2, 3]);
        let rhs = Array::from_vec(vec![0, 2, 4]);

        let gt = greater(&lhs, &rhs).expect("greater comparison should succeed");
        let ge = greater_equal(&lhs, &rhs).expect("greater_equal comparison should succeed");
        let lt = less(&lhs, &rhs).expect("less comparison should succeed");
        let le = less_equal(&lhs, &rhs).expect("less_equal comparison should succeed");
        let eq = equal(&lhs, &rhs).expect("equal comparison should succeed");
        let ne = not_equal(&lhs, &rhs).expect("not_equal comparison should succeed");

        assert_eq!(gt.to_vec(), vec![true, false, false]);
        assert_eq!(ge.to_vec(), vec![true, true, false]);
        assert_eq!(lt.to_vec(), vec![false, false, true]);
        assert_eq!(le.to_vec(), vec![false, true, true]);
        assert_eq!(eq.to_vec(), vec![false, true, false]);
        assert_eq!(ne.to_vec(), vec![true, false, true]);
    }

    #[test]
    fn test_broadcasting() {
        let row = Array::from_vec(vec![1, 2, 3]);
        let single = Array::from_vec(vec![1]).reshape(&[1]);
        let flat_mask = equal(&row, &single).expect("broadcast equal should succeed");
        assert_eq!(flat_mask.to_vec(), vec![true, false, false]);

        let grid = Array::from_vec(vec![1, 2, 3, 4]).reshape(&[2, 2]);
        let top_row = Array::from_vec(vec![1, 2]).reshape(&[1, 2]);
        let grid_mask = equal(&grid, &top_row).expect("2D broadcast equal should succeed");
        assert_eq!(grid_mask.shape(), vec![2, 2]);
        assert_eq!(grid_mask.to_vec(), vec![true, true, false, false]);
    }

    #[test]
    fn test_isclose_array() {
        let exact = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let nudged = Array::from_vec(vec![1.0000001, 2.0000002, 3.0000003]);

        let loose = isclose_array(&exact, &nudged, 1e-7, 0.0).expect("isclose_array should succeed");
        assert_eq!(loose.to_vec(), vec![true, true, true]);

        let strict = isclose_array(&exact, &nudged, 1e-10, 0.0)
            .expect("isclose_array with strict tol should succeed");
        assert_eq!(strict.to_vec(), vec![false, false, false]);
    }
}