use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{NumCast, ToPrimitive, Zero};
use scirs2_core::Complex;
/// Partially sorts `array` so that the element at position `kth` ends up where it
/// would be in a fully sorted sequence.
///
/// Every element before position `kth` is not greater than it and every element
/// after it is not less than it; the order inside each side is unspecified.
///
/// * `axis == None`: all elements are partitioned as one sequence (in the order
///   produced by `to_vec`) and the result is reshaped to the input shape.
/// * `axis == Some(a)`: each 1-D lane along axis `a` is partitioned independently.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `axis` is not a valid axis of `array`, or
///   if `kth` is not smaller than the number of elements being partitioned.
/// * [`NumRs2Error::InvalidOperation`] if the copied array's storage cannot be
///   borrowed as a single mutable slice.
pub fn partition<T: Clone + PartialOrd>(
    array: &Array<T>,
    kth: usize,
    axis: Option<usize>,
) -> Result<Array<T>> {
    match axis {
        None => {
            let mut data = array.to_vec();
            let n = data.len();
            if kth >= n {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "kth ({}) is out of bounds for array of size {}",
                    kth, n
                )));
            }
            quick_select(&mut data, 0, n - 1, kth);
            Ok(Array::from_vec(data).reshape(&array.shape()))
        }
        Some(axis_val) => {
            let shape = array.shape();
            if axis_val >= shape.len() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    axis_val,
                    shape.len()
                )));
            }
            let axis_size = shape[axis_val];
            if kth >= axis_size {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "kth ({}) is out of bounds for axis {} with size {}",
                    kth, axis_val, axis_size
                )));
            }
            let mut result = array.clone();
            let result_vec = result.array_mut().as_slice_mut().ok_or_else(|| {
                NumRs2Error::InvalidOperation("Failed to get mutable slice".into())
            })?;
            // The flat buffer is addressed as [pre][axis][post]: consecutive elements of
            // one lane are `post_axis_size` apart, and each `pre` block spans
            // `axis_size * post_axis_size` elements.
            let pre_axis_size: usize = shape.iter().take(axis_val).product();
            let post_axis_size: usize = shape.iter().skip(axis_val + 1).product();
            let block_len = axis_size * post_axis_size;
            // One scratch buffer reused for every lane instead of allocating per lane.
            let mut lane: Vec<T> = Vec::with_capacity(axis_size);
            for i_pre in 0..pre_axis_size {
                for i_post in 0..post_axis_size {
                    let base = i_pre * block_len + i_post;
                    lane.extend(
                        (0..axis_size)
                            .map(|i_axis| result_vec[base + i_axis * post_axis_size].clone()),
                    );
                    quick_select(&mut lane, 0, axis_size - 1, kth);
                    // `drain` moves the values back without cloning and leaves `lane`
                    // empty for the next lane.
                    for (i_axis, value) in lane.drain(..).enumerate() {
                        result_vec[base + i_axis * post_axis_size] = value;
                    }
                }
            }
            Ok(result)
        }
    }
}
/// Reorders `arr[left..=right]` in place so that `arr[k]` holds the value it would
/// have if that range were sorted, with no greater element before it and no smaller
/// element after it (quickselect).
///
/// The search loops instead of recursing, so stack usage is constant. Runs of values
/// equal to the pivot are settled in one pass, so inputs with many duplicates do not
/// degrade to one element of progress per partition step.
///
/// Callers must ensure `left <= k <= right < arr.len()`.
fn quick_select<T: Clone + PartialOrd>(arr: &mut [T], left: usize, right: usize, k: usize) {
    let (mut left, mut right) = (left, right);
    while left < right {
        let pivot_idx = choose_pivot(arr, left, right);
        let pivot_idx = partition_around_pivot(arr, left, right, pivot_idx);
        match k.cmp(&pivot_idx) {
            std::cmp::Ordering::Equal => return,
            // `left <= k < pivot_idx`, so `pivot_idx >= 1` and this cannot underflow.
            std::cmp::Ordering::Less => right = pivot_idx - 1,
            std::cmp::Ordering::Greater => {
                // Everything in `(pivot_idx, right]` is not less than the pivot. Gather
                // the elements equal to it directly after the pivot so a whole run of
                // duplicates is placed at once.
                let mut eq_end = pivot_idx + 1;
                for i in pivot_idx + 1..=right {
                    if arr[i] == arr[pivot_idx] {
                        arr.swap(i, eq_end);
                        eq_end += 1;
                    }
                }
                if k < eq_end {
                    // `arr[k]` equals the pivot, which is already in sorted position.
                    return;
                }
                left = eq_end;
            }
        }
    }
}

/// Median-of-three pivot choice: returns whichever of `left`, the midpoint, or `right`
/// holds the middle value of the three. Ranges of fewer than three elements use `left`.
fn choose_pivot<T: PartialOrd>(arr: &[T], left: usize, right: usize) -> usize {
    if right - left < 2 {
        return left;
    }
    let mid = left + (right - left) / 2;
    let mut indices = [left, mid, right];
    if arr[indices[0]] > arr[indices[1]] {
        indices.swap(0, 1);
    }
    if arr[indices[1]] > arr[indices[2]] {
        indices.swap(1, 2);
    }
    if arr[indices[0]] > arr[indices[1]] {
        indices.swap(0, 1);
    }
    indices[1]
}

/// Lomuto partition of `arr[left..=right]` around the value at `pivot_idx`.
///
/// Elements strictly less than the pivot are moved to the front of the range, the
/// pivot is placed right after them, and its final index is returned. Everything after
/// that index is not less than the pivot.
fn partition_around_pivot<T: Clone + PartialOrd>(
    arr: &mut [T],
    left: usize,
    right: usize,
    pivot_idx: usize,
) -> usize {
    let pivot_value = arr[pivot_idx].clone();
    // Park the pivot at the end so the scan below never moves it.
    arr.swap(pivot_idx, right);
    let mut store_idx = left;
    for i in left..right {
        if arr[i] < pivot_value {
            arr.swap(i, store_idx);
            store_idx += 1;
        }
    }
    arr.swap(store_idx, right);
    store_idx
}
/// Returns, for every element of `v`, the index at which it would be inserted into the
/// ascending sequence `a` to keep that sequence sorted.
///
/// * `side`: `"left"` (the default) gives the first valid insertion index, `"right"`
///   the last one.
/// * `sorter`: optional 1-D array of indices into `a`; when given, `a` is searched in
///   the order `a[sorter[0]], a[sorter[1]], ...` instead of its storage order.
///
/// Without a `sorter`, a multi-dimensional `a` is searched as its flattened sequence.
/// The result has the same shape as `v`.
///
/// # Errors
///
/// [`NumRs2Error::InvalidOperation`] if `side` is neither `"left"` nor `"right"`, if
/// `sorter` is not 1-D, differs in size from `a`, or holds an out-of-range index, or if
/// the searched sequence is not in ascending order.
pub fn searchsorted<T: Clone + PartialOrd>(
    a: &Array<T>,
    v: &Array<T>,
    side: Option<&str>,
    sorter: Option<&Array<usize>>,
) -> Result<Array<usize>> {
    let side = side.unwrap_or("left");
    let search_right = match side {
        "left" => false,
        "right" => true,
        invalid => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Side '{}' is invalid, must be 'left' or 'right'",
                invalid
            )));
        }
    };

    // The sequence actually searched, in ascending order if the input is valid.
    let haystack: Vec<T> = match sorter {
        Some(order) => {
            if order.ndim() != 1 {
                return Err(NumRs2Error::InvalidOperation(
                    "Sorter array must be 1-dimensional".into(),
                ));
            }
            if order.size() != a.size() {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Sorter size ({}) does not match array size ({})",
                    order.size(),
                    a.size()
                )));
            }
            let source = a.to_vec();
            order
                .to_vec()
                .into_iter()
                .map(|idx| {
                    source.get(idx).cloned().ok_or_else(|| {
                        NumRs2Error::InvalidOperation(format!(
                            "Sorter index {} out of range for array of size {}",
                            idx,
                            source.len()
                        ))
                    })
                })
                .collect::<Result<Vec<T>>>()?
        }
        None if a.ndim() == 1 => a.to_vec(),
        None => a.flatten(None).to_vec(),
    };

    if haystack.windows(2).any(|pair| pair[1] < pair[0]) {
        return Err(NumRs2Error::InvalidOperation(
            "The input array must be sorted in ascending order".into(),
        ));
    }

    let positions: Vec<usize> = v
        .to_vec()
        .iter()
        .map(|needle| {
            if search_right {
                binary_search_right(&haystack, needle)
            } else {
                binary_search_left(&haystack, needle)
            }
        })
        .collect();
    Ok(Array::from_vec(positions).reshape(&v.shape()))
}
/// Lower-bound binary search over `arr`, which is expected to be sorted ascending.
///
/// Returns the leftmost index at which `value` could be inserted while keeping `arr`
/// sorted, i.e. the number of elements that compare strictly less than `value`.
fn binary_search_left<T: PartialOrd>(arr: &[T], value: &T) -> usize {
    let mut lo = 0usize;
    let mut hi = arr.len();
    while hi > lo {
        let half = (hi - lo) / 2;
        let probe = lo + half;
        if arr[probe] < *value {
            lo = probe + 1;
        } else {
            hi = probe;
        }
    }
    lo
}
/// Upper-bound binary search over `arr`, which is expected to be sorted ascending.
///
/// Returns the rightmost index at which `value` could be inserted while keeping `arr`
/// sorted, i.e. the number of elements that do not compare greater than `value`.
fn binary_search_right<T: PartialOrd>(arr: &[T], value: &T) -> usize {
    let (mut lo, mut hi) = (0usize, arr.len());
    while lo < hi {
        let probe = lo + (hi - lo) / 2;
        if *value < arr[probe] {
            hi = probe;
        } else {
            lo = probe + 1;
        }
    }
    lo
}
/// Returns a copy of `array` with all of its elements sorted in ascending order.
///
/// The elements are sorted as one flat sequence (in `to_vec` order) and the result is
/// reshaped to the shape of `array`.
///
/// `kind` selects the algorithm:
/// * `"quicksort"` (default): `slice::sort_by` with a total order in which values that
///   are unordered even with themselves (e.g. NaN) sort after all other values;
/// * `"mergesort"`: delegates to [`msort`];
/// * `"heapsort"`: an in-place heap sort.
///
/// # Errors
///
/// [`NumRs2Error::InvalidOperation`] if `kind` is not one of the names above, plus any
/// error returned by [`msort`].
pub fn sort<T: Clone + PartialOrd>(array: &Array<T>, kind: Option<&str>) -> Result<Array<T>> {
    let sort_kind = kind.unwrap_or("quicksort");
    match sort_kind {
        "mergesort" => msort(array),
        "quicksort" => {
            let mut data = array.to_vec();
            // `partial_cmp(..).unwrap_or(Equal)` is not a total order when NaN-like values
            // are present, and `sort_by` may panic on such comparators. Sorting
            // self-unordered values last keeps the comparison total.
            data.sort_by(|a, b| {
                a.partial_cmp(b).unwrap_or_else(|| {
                    let a_unordered = a.partial_cmp(a).is_none();
                    let b_unordered = b.partial_cmp(b).is_none();
                    a_unordered.cmp(&b_unordered)
                })
            });
            Ok(Array::from_vec(data).reshape(&array.shape()))
        }
        "heapsort" => {
            let mut data = array.to_vec();
            heap_sort(&mut data);
            Ok(Array::from_vec(data).reshape(&array.shape()))
        }
        _ => Err(NumRs2Error::InvalidOperation(format!(
            "Unknown sort kind: {}. Must be 'quicksort', 'mergesort', or 'heapsort'",
            sort_kind
        ))),
    }
}
/// Sorts `arr` in place in ascending order using heap sort (not stable).
///
/// First builds a max-heap over the whole slice, then repeatedly swaps the maximum to
/// the end of the shrinking heap region and restores the heap property.
fn heap_sort<T: Clone + PartialOrd>(arr: &mut [T]) {
    let n = arr.len();
    if n < 2 {
        return;
    }
    // Build the heap bottom-up, starting from the last node that has a child.
    let mut node = n / 2;
    while node > 0 {
        node -= 1;
        heapify(arr, n, node);
    }
    // Move the current maximum behind the heap and shrink the heap by one each time.
    for end in (1..n).rev() {
        arr.swap(0, end);
        heapify(arr, end, 0);
    }
}
/// Sifts the element at index `i` down through the max-heap stored in `arr[..n]`.
///
/// At each step the node is swapped with its larger child (left checked before right)
/// until neither child within the first `n` elements compares greater.
fn heapify<T: Clone + PartialOrd>(arr: &mut [T], n: usize, i: usize) {
    let mut root = i;
    loop {
        let left_child = 2 * root + 1;
        let right_child = left_child + 1;
        let mut top = root;
        if left_child < n && arr[left_child] > arr[top] {
            top = left_child;
        }
        if right_child < n && arr[right_child] > arr[top] {
            top = right_child;
        }
        if top == root {
            break;
        }
        arr.swap(root, top);
        root = top;
    }
}
/// Returns a copy of the complex array `array` with its elements sorted by magnitude
/// (`norm`), ties broken by argument (`arg`), reshaped to the input shape.
///
/// Keys that are NaN sort after every number, so the ordering stays total even when
/// elements contain NaN components.
///
/// NOTE(review): NumPy's `sort_complex` orders by real part, then imaginary part; this
/// function orders by magnitude then argument. Confirm which convention is intended.
///
/// # Errors
///
/// This implementation always returns `Ok`.
pub fn sort_complex<T>(array: &Array<Complex<T>>) -> Result<Array<Complex<T>>>
where
    T: Clone + PartialOrd + num_traits::Float,
{
    // Total order on a real key: NaN is greater than every number and equal to NaN.
    // A plain `partial_cmp(..).unwrap_or(Equal)` is not total, and `sort_by` may panic
    // on non-total comparators.
    fn nan_last<F: num_traits::Float>(x: F, y: F) -> std::cmp::Ordering {
        match (x.is_nan(), y.is_nan()) {
            (false, false) => x.partial_cmp(&y).unwrap_or(std::cmp::Ordering::Equal),
            (x_nan, y_nan) => x_nan.cmp(&y_nan),
        }
    }

    let flattened = if array.ndim() == 1 {
        array.clone()
    } else {
        array.flatten(None)
    };
    let mut data = flattened.to_vec();
    data.sort_by(|a, b| nan_last(a.norm(), b.norm()).then_with(|| nan_last(a.arg(), b.arg())));
    Ok(Array::from_vec(data).reshape(&array.shape()))
}
/// Counts occurrences of each non-negative integer value in the 1-D array `x`.
///
/// The result has `max(x) + 1` bins (no bins for an empty `x`), or `minlength` bins if
/// that is larger. This matches NumPy's `bincount`. Bin `i` receives one count, or the
/// element's weight when `weights` is given, for every element whose `to_usize()`
/// conversion equals `i`. Elements whose conversion fails are ignored.
///
/// # Errors
///
/// * [`NumRs2Error::InvalidOperation`] if `x` is not 1-D or contains a negative value.
/// * [`NumRs2Error::ShapeMismatch`] if `weights` does not have the same shape as `x`.
pub fn bincount<T, W>(
    x: &Array<T>,
    weights: Option<&Array<W>>,
    minlength: Option<usize>,
) -> Result<Array<W>>
where
    T: Clone + ToPrimitive + PartialOrd + Zero,
    W: Clone + Zero + std::ops::AddAssign + NumCast,
{
    if x.shape().len() != 1 {
        return Err(NumRs2Error::InvalidOperation(
            "bincount requires 1D input array".to_string(),
        ));
    }
    let x_data = x.to_vec();
    if x_data.iter().any(|val| *val < T::zero()) {
        return Err(NumRs2Error::InvalidOperation(
            "bincount requires non-negative integers".to_string(),
        ));
    }
    // Validate the weights before allocating the output.
    let w_data = match weights {
        Some(w) if w.shape() != x.shape() => {
            return Err(NumRs2Error::ShapeMismatch {
                expected: x.shape().to_vec(),
                actual: w.shape().to_vec(),
            });
        }
        Some(w) => Some(w.to_vec()),
        None => None,
    };
    // One bin per value in `0..=max(x)`. An empty (or fully unconvertible) `x`
    // contributes no bins, so only `minlength` applies.
    let natural_len = x_data
        .iter()
        .filter_map(|v| v.to_usize())
        .max()
        .map_or(0, |max_val| max_val + 1);
    let output_len = natural_len.max(minlength.unwrap_or(0));
    let mut counts = vec![W::zero(); output_len];
    // Every successful `to_usize()` below is <= the maximum used for `natural_len`,
    // so indexing `counts` cannot go out of bounds.
    match w_data {
        Some(w_vals) => {
            for (val, weight) in x_data.iter().zip(w_vals) {
                if let Some(idx) = val.to_usize() {
                    counts[idx] += weight;
                }
            }
        }
        None => {
            // The unit increment is the same for every element; compute it once.
            let one = <W as NumCast>::from(1).unwrap_or_else(W::zero);
            for val in &x_data {
                if let Some(idx) = val.to_usize() {
                    counts[idx] += one.clone();
                }
            }
        }
    }
    Ok(Array::from_vec(counts))
}
/// Returns, for every element of `x`, the index of the bin it falls into.
///
/// `bins` must be a non-empty 1-D array that is monotonic, either non-strictly
/// increasing or non-strictly decreasing. Repeated neighbouring edges are allowed, and
/// all-equal edges are treated as increasing. For an element `v` the returned index `i`
/// satisfies:
///
/// | bins       | `right` | condition                   |
/// |------------|---------|-----------------------------|
/// | increasing | `false` | `bins[i-1] <  v <= bins[i]` |
/// | increasing | `true`  | `bins[i-1] <= v <  bins[i]` |
/// | decreasing | `false` | `bins[i-1] >  v >= bins[i]` |
/// | decreasing | `true`  | `bins[i-1] >= v >  bins[i]` |
///
/// Values beyond the first or last edge map to `0` or `bins.len()`. The result has the
/// shape of `x`.
///
/// NOTE(review): for increasing bins the meaning of `right` is the reverse of NumPy's
/// `digitize` (NumPy: `right == false` means `bins[i-1] <= v < bins[i]`), whereas the
/// decreasing rows match NumPy. The module's tests pin the current behaviour; confirm
/// which convention is intended.
///
/// # Errors
///
/// [`NumRs2Error::InvalidOperation`] if `bins` is not 1-D, is empty, or is not monotonic.
pub fn digitize<T>(x: &Array<T>, bins: &Array<T>, right: bool) -> Result<Array<usize>>
where
    T: Clone + PartialOrd,
{
    if bins.shape().len() != 1 {
        return Err(NumRs2Error::InvalidOperation(
            "bins must be 1-dimensional".to_string(),
        ));
    }
    let bins_data = bins.to_vec();
    if bins_data.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "bins cannot be empty".to_string(),
        ));
    }
    // Non-strict monotonicity (repeated edges allowed), as NumPy's `digitize` accepts.
    let increasing = !bins_data.windows(2).any(|pair| pair[1] < pair[0]);
    let decreasing = !bins_data.windows(2).any(|pair| pair[1] > pair[0]);
    if !increasing && !decreasing {
        return Err(NumRs2Error::InvalidOperation(
            "bins must be monotonically increasing or decreasing".to_string(),
        ));
    }
    let x_data = x.to_vec();
    let indices: Vec<usize> = if increasing {
        x_data
            .iter()
            .map(|val| {
                if right {
                    binary_search_right(&bins_data, val)
                } else {
                    binary_search_left(&bins_data, val)
                }
            })
            .collect()
    } else {
        // Search the reversed (ascending) edges and mirror the position back. The
        // reversed copy is built once rather than once per element of `x`.
        let n = bins_data.len();
        let ascending: Vec<T> = bins_data.iter().rev().cloned().collect();
        x_data
            .iter()
            .map(|val| {
                if right {
                    n - binary_search_left(&ascending, val)
                } else {
                    n - binary_search_right(&ascending, val)
                }
            })
            .collect()
    };
    Ok(Array::from_vec(indices).reshape(&x.shape()))
}
/// Returns the permutation of column indices that sorts the columns of the 2-D array
/// `keys` lexicographically.
///
/// Each row of `keys` is one sort key. The last row is the primary key, the
/// second-to-last the secondary key, and so on. Columns that tie on every key keep
/// their original relative order. Within a key, values that are unordered even with
/// themselves (e.g. NaN) sort after all other values.
///
/// # Errors
///
/// [`NumRs2Error::DimensionMismatch`] if `keys` is not 2-D.
pub fn lexsort<T: Clone + PartialOrd + Zero>(keys: &Array<T>) -> Result<Array<usize>> {
    let shape = keys.shape();
    if shape.len() != 2 {
        return Err(NumRs2Error::DimensionMismatch(
            "lexsort requires a 2D array of keys".to_string(),
        ));
    }
    let n_keys = shape[0];
    let n_items = shape[1];
    if n_items == 0 {
        return Ok(Array::from_vec(vec![]));
    }
    let mut indices: Vec<usize> = (0..n_items).collect();
    // `sort_by` is stable, so sorting by each key in turn (first row to last) leaves the
    // last row as the primary key and earlier rows as successive tie-breakers.
    for key_idx in 0..n_keys {
        let key_row_data: Vec<T> = (0..n_items)
            .map(|i| {
                keys.get(&[key_idx, i])
                    .expect("key_idx and i should be within bounds as validated by shape")
            })
            .collect();
        indices.sort_by(|&a, &b| {
            let (ka, kb) = (&key_row_data[a], &key_row_data[b]);
            // `unwrap_or(Equal)` would not be a total order with NaN-like values, and
            // `sort_by` may panic on such comparators; sort them last instead.
            ka.partial_cmp(kb).unwrap_or_else(|| {
                let ka_unordered = ka.partial_cmp(ka).is_none();
                let kb_unordered = kb.partial_cmp(kb).is_none();
                ka_unordered.cmp(&kb_unordered)
            })
        });
    }
    Ok(Array::from_vec(indices))
}
/// Returns a copy of `array` with all of its elements sorted by a stable merge sort.
///
/// The elements are sorted as one flattened sequence and the result is reshaped to the
/// shape of `array`.
///
/// # Errors
///
/// This implementation always returns `Ok`.
pub fn msort<T: Clone + PartialOrd>(array: &Array<T>) -> Result<Array<T>> {
    let mut elements = match array.ndim() {
        1 => array.to_vec(),
        _ => array.flatten(None).to_vec(),
    };
    merge_sort(&mut elements);
    let sorted = Array::from_vec(elements);
    Ok(sorted.reshape(&array.shape()))
}
/// Stable top-down merge sort of `arr` in place.
///
/// Each half is sorted recursively and the halves are merged into a temporary buffer.
/// On ties (`left <= right`) the element from the left half is taken first, which
/// keeps the sort stable. The buffer is then moved back into `arr`.
fn merge_sort<T: Clone + PartialOrd>(arr: &mut [T]) {
    if arr.len() < 2 {
        return;
    }
    let half = arr.len() / 2;
    {
        let (lower, upper) = arr.split_at_mut(half);
        merge_sort(lower);
        merge_sort(upper);
    }
    let merged: Vec<T> = {
        let (lower, upper) = arr.split_at(half);
        let mut out = Vec::with_capacity(arr.len());
        let (mut i, mut j) = (0, 0);
        while i < lower.len() && j < upper.len() {
            if lower[i] <= upper[j] {
                out.push(lower[i].clone());
                i += 1;
            } else {
                out.push(upper[j].clone());
                j += 1;
            }
        }
        // At most one of these tails is non-empty.
        out.extend_from_slice(&lower[i..]);
        out.extend_from_slice(&upper[j..]);
        out
    };
    for (slot, item) in arr.iter_mut().zip(merged) {
        *slot = item;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Array;

    #[test]
    fn test_partition_1d() {
        let a = Array::from_vec(vec![9, 4, 1, 7, 5, 3, 8, 2, 6]);
        let partitioned = partition(&a, 3, None).expect("operation should succeed");
        let kth_element = partitioned.get(&[3]).expect("operation should succeed");
        // The 4th smallest of 1..=9.
        assert_eq!(kth_element, 4);
        for i in 0..3 {
            assert!(partitioned.get(&[i]).expect("operation should succeed") <= kth_element);
        }
        for i in 4..9 {
            assert!(partitioned.get(&[i]).expect("operation should succeed") >= kth_element);
        }
    }

    #[test]
    fn test_partition_along_axis() {
        let a = Array::from_vec(vec![3, 1, 2, 9, 7, 8]).reshape(&[2, 3]);
        let partitioned = partition(&a, 1, Some(1)).expect("operation should succeed");
        assert_eq!(partitioned.shape(), vec![2, 3]);
        assert_eq!(partitioned.get(&[0, 1]).expect("index should be in bounds"), 2);
        assert_eq!(partitioned.get(&[1, 1]).expect("index should be in bounds"), 8);
    }

    #[test]
    fn test_partition_rejects_invalid_arguments() {
        let a = Array::from_vec(vec![3, 1, 2]);
        assert!(partition(&a, 3, None).is_err());
        assert!(partition(&a, 0, Some(1)).is_err());
        assert!(partition(&a, 3, Some(0)).is_err());
    }

    #[test]
    fn test_searchsorted_left() {
        let a = Array::from_vec(vec![1, 3, 5, 7, 9]);
        let v = Array::from_vec(vec![0, 1, 2, 4, 8, 10]);
        let indices = searchsorted(&a, &v, Some("left"), None).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![0, 0, 1, 2, 4, 5]);
    }

    #[test]
    fn test_searchsorted_right() {
        let a = Array::from_vec(vec![1, 3, 5, 7, 9]);
        let v = Array::from_vec(vec![0, 1, 2, 4, 8, 10]);
        let indices =
            searchsorted(&a, &v, Some("right"), None).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![0, 1, 1, 2, 4, 5]);
    }

    #[test]
    fn test_searchsorted_duplicates() {
        let a = Array::from_vec(vec![1, 1, 1, 3, 3, 5]);
        let v = Array::from_vec(vec![1, 3]);
        let indices_left =
            searchsorted(&a, &v, Some("left"), None).expect("operation should succeed");
        assert_eq!(indices_left.to_vec(), vec![0, 3]);
        let indices_right =
            searchsorted(&a, &v, Some("right"), None).expect("operation should succeed");
        assert_eq!(indices_right.to_vec(), vec![3, 5]);
    }

    #[test]
    fn test_searchsorted_with_sorter() {
        let a = Array::from_vec(vec![5, 1, 3]);
        let sorter = Array::from_vec(vec![1usize, 2, 0]);
        let v = Array::from_vec(vec![2, 5]);
        let indices =
            searchsorted(&a, &v, None, Some(&sorter)).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![1, 2]);
        let bad_sorter = Array::from_vec(vec![0usize, 1, 7]);
        assert!(searchsorted(&a, &v, None, Some(&bad_sorter)).is_err());
    }

    #[test]
    fn test_searchsorted_errors() {
        let a = Array::from_vec(vec![1, 3, 5]);
        let v = Array::from_vec(vec![2]);
        assert!(searchsorted(&a, &v, Some("middle"), None).is_err());
        let unsorted = Array::from_vec(vec![3, 1, 2]);
        assert!(searchsorted(&unsorted, &v, None, None).is_err());
    }

    #[test]
    fn test_binary_search_functions() {
        let arr = vec![1, 3, 5, 7, 9];
        assert_eq!(binary_search_left(&arr, &0), 0);
        assert_eq!(binary_search_left(&arr, &1), 0);
        assert_eq!(binary_search_left(&arr, &2), 1);
        assert_eq!(binary_search_left(&arr, &10), 5);
        assert_eq!(binary_search_right(&arr, &0), 0);
        assert_eq!(binary_search_right(&arr, &1), 1);
        assert_eq!(binary_search_right(&arr, &2), 1);
        assert_eq!(binary_search_right(&arr, &10), 5);
    }

    #[test]
    fn test_bincount() {
        let x = Array::from_vec(vec![0, 1, 1, 3, 2, 1, 7]);
        let counts: Array<i32> = bincount(&x, None, None).expect("operation should succeed");
        assert_eq!(counts.shape(), vec![8]);
        assert_eq!(counts.to_vec(), vec![1, 3, 1, 1, 0, 0, 0, 1]);

        let counts: Array<i32> = bincount(&x, None, Some(10)).expect("operation should succeed");
        assert_eq!(counts.shape(), vec![10]);
        assert_eq!(counts.to_vec(), vec![1, 3, 1, 1, 0, 0, 0, 1, 0, 0]);

        let weights = Array::from_vec(vec![0.5, 0.5, 0.5, 1.0, 1.0, 0.5, 2.0]);
        let weighted_counts: Array<f64> =
            bincount(&x, Some(&weights), None).expect("operation should succeed");
        assert_eq!(weighted_counts.shape(), vec![8]);
        assert_eq!(
            weighted_counts.to_vec(),
            vec![0.5, 1.5, 1.0, 1.0, 0.0, 0.0, 0.0, 2.0]
        );

        let bad_weights = Array::from_vec(vec![1.0, 2.0]);
        assert!(bincount(&x, Some(&bad_weights), None).is_err());

        let empty: Array<i32> = Array::from_vec(vec![]);
        let counts: Array<i32> =
            bincount(&empty, None::<&Array<i32>>, Some(5)).expect("operation should succeed");
        assert_eq!(counts.to_vec(), vec![0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_digitize() {
        let x = Array::from_vec(vec![0.2, 6.4, 3.0, 1.6]);
        let bins = Array::from_vec(vec![0.0, 1.0, 2.5, 4.0, 10.0]);
        let indices = digitize(&x, &bins, false).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![1, 4, 3, 2]);

        let x = Array::from_vec(vec![0.0, 1.0, 2.5, 4.0, 10.0]);
        let indices = digitize(&x, &bins, false).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![0, 1, 2, 3, 4]);
        let indices = digitize(&x, &bins, true).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![1, 2, 3, 4, 5]);

        let x = Array::from_vec(vec![-1.0, 15.0]);
        let indices = digitize(&x, &bins, false).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![0, 5]);

        let bins_dec = Array::from_vec(vec![10.0, 4.0, 2.5, 1.0, 0.0]);
        let x = Array::from_vec(vec![0.2, 6.4, 3.0, 1.6]);
        let indices = digitize(&x, &bins_dec, false).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![4, 1, 2, 3]);

        let x = Array::from_vec(vec![0.2, 6.4, 3.0, 1.6]).reshape(&[2, 2]);
        let bins = Array::from_vec(vec![0.0, 1.0, 2.5, 4.0, 10.0]);
        let indices = digitize(&x, &bins, false).expect("operation should succeed");
        assert_eq!(indices.shape(), vec![2, 2]);
        assert_eq!(indices.to_vec(), vec![1, 4, 3, 2]);

        let bad_bins = Array::from_vec(vec![0.0, 2.0, 1.0]);
        assert!(digitize(&Array::from_vec(vec![0.5]), &bad_bins, false).is_err());
    }

    #[test]
    fn test_msort() {
        let a = Array::from_vec(vec![3, 1, 4, 1, 5, 9, 2, 6]);
        let sorted = msort(&a).expect("operation should succeed");
        assert_eq!(sorted.to_vec(), vec![1, 1, 2, 3, 4, 5, 6, 9]);

        let empty: Array<i32> = Array::from_vec(vec![]);
        let sorted_empty = msort(&empty).expect("operation should succeed");
        assert_eq!(sorted_empty.to_vec(), Vec::<i32>::new());

        let single = Array::from_vec(vec![42]);
        let sorted_single = msort(&single).expect("operation should succeed");
        assert_eq!(sorted_single.to_vec(), vec![42]);

        let float_arr = Array::from_vec(vec![3.14, 2.71, 1.41, 1.73]);
        let sorted_float = msort(&float_arr).expect("operation should succeed");
        assert_eq!(sorted_float.to_vec(), vec![1.41, 1.73, 2.71, 3.14]);
    }

    #[test]
    fn test_sort_complex() {
        let a = Array::from_vec(vec![
            Complex::new(3.0, 4.0),
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(2.0, 0.0),
        ]);
        let sorted = sort_complex(&a).expect("operation should succeed");
        // Ordered by magnitude; the two unit-magnitude values are ordered by argument.
        assert_eq!(
            sorted.to_vec(),
            vec![
                Complex::new(1.0, 0.0),
                Complex::new(0.0, 1.0),
                Complex::new(2.0, 0.0),
                Complex::new(3.0, 4.0),
            ]
        );

        let b = Array::from_vec(vec![
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(-1.0, 0.0),
            Complex::new(0.0, -1.0),
        ]);
        let sorted_b = sort_complex(&b).expect("operation should succeed");
        // All magnitudes are exactly 1, so the order is by argument in (-pi, pi].
        assert_eq!(
            sorted_b.to_vec(),
            vec![
                Complex::new(0.0, -1.0),
                Complex::new(1.0, 0.0),
                Complex::new(0.0, 1.0),
                Complex::new(-1.0, 0.0),
            ]
        );
    }

    #[test]
    fn test_generic_sort() {
        let a = Array::from_vec(vec![3, 1, 4, 1, 5, 9, 2, 6]);
        let sorted = sort(&a, Some("quicksort")).expect("operation should succeed");
        assert_eq!(sorted.to_vec(), vec![1, 1, 2, 3, 4, 5, 6, 9]);
        let sorted_merge = sort(&a, Some("mergesort")).expect("operation should succeed");
        assert_eq!(sorted_merge.to_vec(), vec![1, 1, 2, 3, 4, 5, 6, 9]);
        let sorted_heap = sort(&a, Some("heapsort")).expect("operation should succeed");
        assert_eq!(sorted_heap.to_vec(), vec![1, 1, 2, 3, 4, 5, 6, 9]);
        let sorted_default = sort(&a, None).expect("operation should succeed");
        assert_eq!(sorted_default.to_vec(), vec![1, 1, 2, 3, 4, 5, 6, 9]);
        let result = sort(&a, Some("invalid"));
        assert!(result.is_err());

        let m = Array::from_vec(vec![4, 3, 2, 1]).reshape(&[2, 2]);
        let sorted_m = sort(&m, None).expect("operation should succeed");
        assert_eq!(sorted_m.shape(), vec![2, 2]);
        assert_eq!(sorted_m.to_vec(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_lexsort_basic() {
        let keys = Array::from_vec(vec![
            1, 0, 1, 0, 1, // secondary key
            2, 2, 1, 1, 3, // primary key
        ])
        .reshape(&[2, 5]);
        let indices = lexsort(&keys).expect("operation should succeed");
        assert_eq!(indices.to_vec(), vec![3, 2, 1, 0, 4]);
    }
}