use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
/// Tuple returned by `unique_with_options` and `unique_axis`, in order:
/// `(unique values, first-occurrence indices, inverse indices, counts)`.
///
/// Each of the three `usize` arrays is built from an empty `Vec` when the
/// matching `return_*` flag passed to the producing function is `false`.
type UniqueResult<T> = (Array<T>, Array<usize>, Array<usize>, Array<usize>);
/// Returns the sorted, de-duplicated values present in both `ar1` and `ar2`.
///
/// Each input is read through `to_vec` and collected into a `HashSet`, so
/// repeated values in either input appear at most once in the output. The
/// result is sorted ascending by `Ord` and wrapped with `Array::from_vec`.
///
/// This function always returns `Ok`.
pub fn intersect1d<T>(ar1: &Array<T>, ar2: &Array<T>) -> Result<Array<T>>
where
    T: Clone + Eq + Hash + Ord,
{
    let lhs = ar1.to_vec().into_iter().collect::<HashSet<T>>();
    let rhs = ar2.to_vec().into_iter().collect::<HashSet<T>>();

    // Hash iteration order is arbitrary, hence the explicit sort afterwards.
    let mut common: Vec<T> = Vec::new();
    common.extend(lhs.intersection(&rhs).cloned());
    common.sort();

    Ok(Array::from_vec(common))
}
/// Returns the sorted, de-duplicated values present in `ar1`, `ar2`, or both.
///
/// The elements of both inputs (read through `to_vec`) are gathered into a
/// `HashSet`; the distinct values are then sorted ascending by `Ord` and
/// wrapped with `Array::from_vec`.
///
/// This function always returns `Ok`.
pub fn union1d<T>(ar1: &Array<T>, ar2: &Array<T>) -> Result<Array<T>>
where
    T: Clone + Eq + Hash + Ord,
{
    // Seed the set with the first input, then fold the second one in; a value
    // already present is kept and the later duplicate is discarded.
    let mut merged = ar1.to_vec().into_iter().collect::<HashSet<T>>();
    merged.extend(ar2.to_vec());

    let mut combined: Vec<T> = merged.into_iter().collect();
    combined.sort();

    Ok(Array::from_vec(combined))
}
/// Returns the sorted, de-duplicated values of `ar1` that do not occur in `ar2`.
///
/// Both inputs are read through `to_vec` and collected into `HashSet`s. The
/// surviving values are sorted ascending by `Ord` and wrapped with
/// `Array::from_vec`.
///
/// This function always returns `Ok`.
pub fn setdiff1d<T>(ar1: &Array<T>, ar2: &Array<T>) -> Result<Array<T>>
where
    T: Clone + Eq + Hash + Ord,
{
    let candidates = ar1.to_vec().into_iter().collect::<HashSet<T>>();
    let excluded = ar2.to_vec().into_iter().collect::<HashSet<T>>();

    // Keep only the distinct values of `ar1` that `ar2` does not contain.
    let mut remaining: Vec<T> = candidates
        .into_iter()
        .filter(|value| !excluded.contains(value))
        .collect();
    remaining.sort();

    Ok(Array::from_vec(remaining))
}
/// Returns the sorted, de-duplicated values that occur in exactly one of
/// `ar1` and `ar2`.
///
/// Both inputs are read through `to_vec` and collected into `HashSet`s. The
/// values exclusive to either side are sorted ascending by `Ord` and wrapped
/// with `Array::from_vec`.
///
/// This function always returns `Ok`.
pub fn setxor1d<T>(ar1: &Array<T>, ar2: &Array<T>) -> Result<Array<T>>
where
    T: Clone + Eq + Hash + Ord,
{
    let lhs = ar1.to_vec().into_iter().collect::<HashSet<T>>();
    let rhs = ar2.to_vec().into_iter().collect::<HashSet<T>>();

    // Symmetric difference = (lhs \ rhs) followed by (rhs \ lhs).
    let only_left = lhs.difference(&rhs);
    let only_right = rhs.difference(&lhs);
    let mut exclusive: Vec<T> = only_left.chain(only_right).cloned().collect();
    exclusive.sort();

    Ok(Array::from_vec(exclusive))
}
/// Tests, for every element of `ar1`, whether it occurs anywhere in `ar2`.
///
/// The output has one `bool` per element of `ar1.to_vec()`, in the same
/// order, and is built with `Array::from_vec` (no reshape is applied).
///
/// This function always returns `Ok`.
pub fn in1d<T>(ar1: &Array<T>, ar2: &Array<T>) -> Result<Array<bool>>
where
    T: Clone + Eq + Hash,
{
    // Build the lookup set once so each membership test is a hash probe.
    let lookup = ar2.to_vec().into_iter().collect::<HashSet<T>>();

    let membership: Vec<bool> = ar1
        .to_vec()
        .into_iter()
        .map(|candidate| lookup.contains(&candidate))
        .collect();

    Ok(Array::from_vec(membership))
}
pub fn isin<T: Clone + Eq + Hash>(
element: &Array<T>,
test_elements: &Array<T>,
assume_unique: bool,
invert: bool,
) -> Result<Array<bool>> {
let test_set: HashSet<T> = if assume_unique {
test_elements.to_vec().into_iter().collect()
} else {
test_elements.to_vec().into_iter().collect()
};
let result: Vec<bool> = element
.to_vec()
.iter()
.map(|x| {
let contains = test_set.contains(x);
if invert {
!contains
} else {
contains
}
})
.collect();
let result_array = Array::from_vec(result);
Ok(result_array.reshape(&element.shape()))
}
/// Finds the sorted distinct values of `ar`, with optional bookkeeping arrays.
///
/// The elements are read through `ar.to_vec()`; all indices below refer to
/// positions in that vector.
///
/// # Arguments
///
/// * `ar` - input array.
/// * `return_index` - when `true`, also return, for each unique value, the
///   index of its first occurrence in the input.
/// * `return_inverse` - when `true`, also return, for each input element, the
///   position of its value within the sorted unique array.
/// * `return_counts` - when `true`, also return how many times each unique
///   value occurs in the input.
///
/// # Returns
///
/// `(unique, indices, inverse, counts)`. `unique` is sorted ascending by
/// `Ord`. Each optional array is empty when its flag is `false`. This
/// function always returns `Ok`.
pub fn unique_with_options<T: Clone + Eq + Hash + Ord + Debug>(
    ar: &Array<T>,
    return_index: bool,
    return_inverse: bool,
    return_counts: bool,
) -> Result<UniqueResult<T>> {
    let data = ar.to_vec();

    // Single pass: count every value and remember where each one first appears.
    // The map borrows its keys from `data`, so only first occurrences are cloned.
    let mut occurrences: std::collections::HashMap<&T, usize> = std::collections::HashMap::new();
    let mut first_seen: Vec<(T, usize)> = Vec::new();
    for (i, value) in data.iter().enumerate() {
        if let Some(count) = occurrences.get_mut(value) {
            *count += 1;
        } else {
            occurrences.insert(value, 1);
            first_seen.push((value.clone(), i));
        }
    }

    // Sort values together with their first index so both stay aligned.
    first_seen.sort_by(|a, b| a.0.cmp(&b.0));
    let (sorted_unique, sorted_first_indices): (Vec<T>, Vec<usize>) =
        first_seen.into_iter().unzip();

    let indices_array = if return_index {
        Array::from_vec(sorted_first_indices)
    } else {
        Array::from_vec(vec![])
    };

    let inverse_array = if return_inverse {
        let value_to_pos: std::collections::HashMap<&T, usize> = sorted_unique
            .iter()
            .enumerate()
            .map(|(pos, val)| (val, pos))
            .collect();
        let inverse: Vec<usize> = data
            .iter()
            .map(|val| {
                *value_to_pos
                    .get(val)
                    .expect("value must exist in value_to_pos map")
            })
            .collect();
        Array::from_vec(inverse)
    } else {
        Array::from_vec(vec![])
    };

    let counts_array = if return_counts {
        let sorted_counts: Vec<usize> = sorted_unique
            .iter()
            .map(|val| {
                *occurrences
                    .get(val)
                    .expect("value must exist in seen map")
            })
            .collect();
        Array::from_vec(sorted_counts)
    } else {
        Array::from_vec(vec![])
    };

    // Built last so the sorted values can be moved instead of cloned.
    let unique_array = Array::from_vec(sorted_unique);

    Ok((unique_array, indices_array, inverse_array, counts_array))
}
/// Finds the unique values of `ar`, or its unique sub-arrays along `axis`.
///
/// With `axis == None` this delegates to `unique_with_options`. With
/// `axis == Some(ax)`, one slice is taken per index along `ax` (every other
/// dimension selected with `IndexSpec::All`). Two slices are treated as equal
/// when their `to_vec()` contents are equal, and the distinct slices are
/// ordered by comparing those vectors element by element.
///
/// # Arguments
///
/// * `ar` - input array.
/// * `axis` - axis along which to look for duplicate slices, or `None`.
/// * `return_index` - when `true`, also return the index along `ax` of the
///   first occurrence of each unique slice.
/// * `return_inverse` - when `true`, also return, for each index along `ax`,
///   the position of that slice among the sorted unique slices.
/// * `return_counts` - when `true`, also return how often each unique slice
///   occurs.
///
/// # Returns
///
/// `(unique, indices, inverse, counts)`. For `Some(ax)`, `unique` is the
/// sorted unique slices combined with `array_ops::stack` along `ax`, or
/// `Array::zeros` with extent 0 on `ax` when there are no slices. Each
/// optional array is empty when its flag is `false`.
///
/// # Errors
///
/// Returns `NumRs2Error::DimensionMismatch` when `ax >= ar.ndim()`, and
/// propagates any error from `ar.index` or `array_ops::stack`.
pub fn unique_axis<T: Clone + Eq + Hash + Ord + Debug + num_traits::Zero>(
    ar: &Array<T>,
    axis: Option<usize>,
    return_index: bool,
    return_inverse: bool,
    return_counts: bool,
) -> Result<UniqueResult<T>> {
    let ax = match axis {
        None => return unique_with_options(ar, return_index, return_inverse, return_counts),
        Some(ax) => ax,
    };

    if ax >= ar.ndim() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} is out of bounds for array with {} dimensions",
            ax,
            ar.ndim()
        )));
    }

    let shape = ar.shape();
    let axis_dim = shape[ax];

    // One sub-array per position along `ax`.
    let mut slices = Vec::with_capacity(axis_dim);
    for i in 0..axis_dim {
        let mut index_specs = vec![crate::indexing::IndexSpec::All; ar.ndim()];
        index_specs[ax] = crate::indexing::IndexSpec::Index(i);
        slices.push(ar.index(&index_specs)?);
    }

    // Element vector of every slice, computed once and reused for hashing,
    // ordering and all lookups below.
    let slice_vecs: Vec<Vec<T>> = slices.iter().map(|s| s.to_vec()).collect();

    // Count each distinct slice and record the index of its first occurrence.
    // Keys borrow from `slice_vecs`, so no element vector is cloned.
    let mut occurrences: std::collections::HashMap<&Vec<T>, usize> =
        std::collections::HashMap::new();
    let mut first_indices: Vec<usize> = Vec::new();
    for (i, slice_vec) in slice_vecs.iter().enumerate() {
        if let Some(count) = occurrences.get_mut(slice_vec) {
            *count += 1;
        } else {
            occurrences.insert(slice_vec, 1);
            first_indices.push(i);
        }
    }

    // Order the first-occurrence indices by the content of the slice they
    // point at; afterwards the vector is the sorted first-index list.
    first_indices.sort_by(|&a, &b| slice_vecs[a].cmp(&slice_vecs[b]));

    let unique_array = if first_indices.is_empty() {
        let mut empty_shape = shape.clone();
        empty_shape[ax] = 0;
        Array::zeros(&empty_shape)
    } else {
        let sorted_unique_slices: Vec<_> = first_indices.iter().map(|&i| &slices[i]).collect();
        crate::array_ops::stack(&sorted_unique_slices, ax)?
    };

    let inverse_array = if return_inverse {
        let slice_to_pos: std::collections::HashMap<&Vec<T>, usize> = first_indices
            .iter()
            .enumerate()
            .map(|(pos, &first)| (&slice_vecs[first], pos))
            .collect();
        let inverse: Vec<usize> = slice_vecs
            .iter()
            .map(|slice_vec| {
                *slice_to_pos
                    .get(slice_vec)
                    .expect("slice_vec must exist in slice_to_pos map")
            })
            .collect();
        Array::from_vec(inverse)
    } else {
        Array::from_vec(vec![])
    };

    let counts_array = if return_counts {
        let sorted_counts: Vec<usize> = first_indices
            .iter()
            .map(|&first| {
                *occurrences
                    .get(&slice_vecs[first])
                    .expect("slice_vec must exist in seen map")
            })
            .collect();
        Array::from_vec(sorted_counts)
    } else {
        Array::from_vec(vec![])
    };

    // Built last so the sorted index list can be moved rather than copied.
    let indices_array = if return_index {
        Array::from_vec(first_indices)
    } else {
        Array::from_vec(vec![])
    };

    Ok((unique_array, indices_array, inverse_array, counts_array))
}
/// Computes the differences between consecutive elements of `ary`.
///
/// The elements are read through `ary.to_vec()`; entry `i` of the difference
/// run is `data[i + 1] - data[i]`. Inputs with fewer than two elements
/// contribute no differences.
///
/// # Arguments
///
/// * `ary` - input array.
/// * `to_end` - optional values appended after the differences.
/// * `to_begin` - optional values placed before the differences.
///
/// # Returns
///
/// An array laid out as `to_begin ++ differences ++ to_end`. This function
/// always returns `Ok`.
pub fn ediff1d<T>(
    ary: &Array<T>,
    to_end: Option<&Array<T>>,
    to_begin: Option<&Array<T>>,
) -> Result<Array<T>>
where
    T: Clone + std::ops::Sub<Output = T>,
{
    let data = ary.to_vec();
    let prefix: Vec<T> = to_begin.map(|a| a.to_vec()).unwrap_or_default();
    let suffix: Vec<T> = to_end.map(|a| a.to_vec()).unwrap_or_default();

    // `saturating_sub` keeps the size estimate valid for empty input.
    let diff_count = data.len().saturating_sub(1);
    let mut differences = Vec::with_capacity(prefix.len() + diff_count + suffix.len());

    differences.extend(prefix);
    // `windows(2)` yields nothing for fewer than two elements, so short
    // inputs need no separate branch.
    differences.extend(
        data.windows(2)
            .map(|pair| pair[1].clone() - pair[0].clone()),
    );
    differences.extend(suffix);

    Ok(Array::from_vec(differences))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_intersect1d() {
        let a = Array::from_vec(vec![1, 3, 4, 3]);
        let b = Array::from_vec(vec![3, 1, 2, 1]);
        let result = intersect1d(&a, &b).expect("intersect1d should succeed");
        assert_eq!(result.to_vec(), vec![1, 3]);
    }

    #[test]
    fn test_union1d() {
        let a = Array::from_vec(vec![1, 2, 3, 2, 4]);
        let b = Array::from_vec(vec![2, 3, 5, 7, 5]);
        let result = union1d(&a, &b).expect("union1d should succeed");
        assert_eq!(result.to_vec(), vec![1, 2, 3, 4, 5, 7]);
    }

    #[test]
    fn test_setdiff1d() {
        let a = Array::from_vec(vec![1, 2, 3, 2, 4, 1]);
        let b = Array::from_vec(vec![3, 4, 5, 6]);
        let result = setdiff1d(&a, &b).expect("setdiff1d should succeed");
        assert_eq!(result.to_vec(), vec![1, 2]);
    }

    #[test]
    fn test_setxor1d() {
        let a = Array::from_vec(vec![1, 2, 3, 2, 4]);
        let b = Array::from_vec(vec![2, 3, 5, 7, 5]);
        let result = setxor1d(&a, &b).expect("setxor1d should succeed");
        assert_eq!(result.to_vec(), vec![1, 4, 5, 7]);
    }

    #[test]
    fn test_in1d() {
        let a = Array::from_vec(vec![1, 2, 3, 4, 5, 6]);
        let b = Array::from_vec(vec![2, 4, 6]);
        let result = in1d(&a, &b).expect("in1d should succeed");
        assert_eq!(result.to_vec(), vec![false, true, false, true, false, true]);
    }

    #[test]
    fn test_isin() {
        let element = Array::from_vec(vec![0, 1, 2, 5, 0]);
        let test_elements = Array::from_vec(vec![0, 2, 5, 7, 9]);
        let result = isin(&element, &test_elements, false, false).expect("isin should succeed");
        assert_eq!(result.to_vec(), vec![true, false, true, true, true]);
        let result_inv =
            isin(&element, &test_elements, false, true).expect("isin with invert should succeed");
        assert_eq!(result_inv.to_vec(), vec![false, true, false, false, false]);
    }

    // `assume_unique` must not change the outcome, with or without `invert`.
    #[test]
    fn test_isin_assume_unique() {
        let element = Array::from_vec(vec![0, 1, 2, 5, 0]);
        let test_elements = Array::from_vec(vec![0, 2, 5, 7, 9]);
        let result = isin(&element, &test_elements, true, false)
            .expect("isin with assume_unique should succeed");
        assert_eq!(result.to_vec(), vec![true, false, true, true, true]);
        let result_inv = isin(&element, &test_elements, true, true)
            .expect("isin with assume_unique and invert should succeed");
        assert_eq!(result_inv.to_vec(), vec![false, true, false, false, false]);
    }

    #[test]
    fn test_unique_with_options() {
        let a = Array::from_vec(vec![1, 1, 2, 2, 3, 3]);
        let (unique, indices, inverse, counts) =
            unique_with_options(&a, true, true, true).expect("unique_with_options should succeed");
        assert_eq!(unique.to_vec(), vec![1, 2, 3]);
        assert_eq!(indices.to_vec(), vec![0, 2, 4]);
        assert_eq!(inverse.to_vec(), vec![0, 0, 1, 1, 2, 2]);
        assert_eq!(counts.to_vec(), vec![2, 2, 2]);
    }

    // Unsorted input: first-occurrence indices must follow the sorted values.
    #[test]
    fn test_unique_with_options_unsorted() {
        let a = Array::from_vec(vec![3, 1, 3, 2, 1]);
        let (unique, indices, inverse, counts) =
            unique_with_options(&a, true, true, true).expect("unique_with_options should succeed");
        assert_eq!(unique.to_vec(), vec![1, 2, 3]);
        assert_eq!(indices.to_vec(), vec![1, 3, 0]);
        assert_eq!(inverse.to_vec(), vec![2, 0, 2, 1, 0]);
        assert_eq!(counts.to_vec(), vec![2, 1, 2]);
    }

    // With every flag off, only the unique values are populated.
    #[test]
    fn test_unique_with_options_flags_off() {
        let a = Array::from_vec(vec![3, 1, 3, 2, 1]);
        let (unique, indices, inverse, counts) = unique_with_options(&a, false, false, false)
            .expect("unique_with_options should succeed");
        assert_eq!(unique.to_vec(), vec![1, 2, 3]);
        assert_eq!(indices.to_vec(), Vec::<usize>::new());
        assert_eq!(inverse.to_vec(), Vec::<usize>::new());
        assert_eq!(counts.to_vec(), Vec::<usize>::new());
    }

    // `axis == None` delegates to `unique_with_options`.
    #[test]
    fn test_unique_axis_none() {
        let a = Array::from_vec(vec![3, 1, 3, 2, 1]);
        let (unique, indices, inverse, counts) =
            unique_axis(&a, None, true, true, true).expect("unique_axis should succeed");
        assert_eq!(unique.to_vec(), vec![1, 2, 3]);
        assert_eq!(indices.to_vec(), vec![1, 3, 0]);
        assert_eq!(inverse.to_vec(), vec![2, 0, 2, 1, 0]);
        assert_eq!(counts.to_vec(), vec![2, 1, 2]);
    }

    #[test]
    fn test_ediff1d() {
        let a = Array::from_vec(vec![1, 2, 4, 7, 0]);
        let result = ediff1d(&a, None, None).expect("ediff1d should succeed");
        assert_eq!(result.to_vec(), vec![1, 2, 3, -7]);
        let begin = Array::from_vec(vec![-99]);
        let end = Array::from_vec(vec![99]);
        let result_full =
            ediff1d(&a, Some(&end), Some(&begin)).expect("ediff1d with begin/end should succeed");
        assert_eq!(result_full.to_vec(), vec![-99, 1, 2, 3, -7, 99]);
        let single = Array::from_vec(vec![42]);
        let result_single =
            ediff1d(&single, None, None).expect("ediff1d with single element should succeed");
        assert_eq!(result_single.to_vec(), Vec::<i32>::new());
        let empty = Array::from_vec(Vec::<i32>::new());
        let result_empty =
            ediff1d(&empty, None, None).expect("ediff1d with empty array should succeed");
        assert_eq!(result_empty.to_vec(), Vec::<i32>::new());
    }

    // Inputs too short to produce a difference still get begin/end attached.
    #[test]
    fn test_ediff1d_short_input_with_begin_end() {
        let begin = Array::from_vec(vec![-99]);
        let end = Array::from_vec(vec![99]);
        let single = Array::from_vec(vec![42]);
        let result_single = ediff1d(&single, Some(&end), Some(&begin))
            .expect("ediff1d with single element and begin/end should succeed");
        assert_eq!(result_single.to_vec(), vec![-99, 99]);
        let empty = Array::from_vec(Vec::<i32>::new());
        let result_empty = ediff1d(&empty, Some(&end), Some(&begin))
            .expect("ediff1d with empty array and begin/end should succeed");
        assert_eq!(result_empty.to_vec(), vec![-99, 99]);
    }
}