use crate::array::Array;
use crate::error::{NumRs2Error, Result};
/// Partially sorts `array` so that the element at position `kth` ends up where it would be
/// in a fully sorted array.
///
/// For totally ordered data, every element before `kth` compares `<=` that element and every
/// element after it compares `>=`. The order within each side is unspecified. Incomparable
/// values (e.g. NaN) give no such guarantee.
///
/// # Arguments
/// * `array` - input array; it is cloned, never modified.
/// * `kth` - index of the element to place in its sorted position.
/// * `axis` - `None` partitions all elements as one flat sequence and returns a result with
///   the input's shape. `Some(ax)` partitions every 1-D lane along `ax` independently.
///
/// # Errors
/// * `NumRs2Error::DimensionMismatch` if `axis` is out of range or `kth` is out of bounds
///   (for the flat size, or for the size of `axis`).
/// * `NumRs2Error::InvalidOperation` if a mutable slice of the cloned array cannot be
///   obtained (presumably when its storage is not contiguous; TODO confirm against `Array`).
pub fn partition<T: Clone + PartialOrd>(
    array: &Array<T>,
    kth: usize,
    axis: Option<usize>,
) -> Result<Array<T>> {
    match axis {
        None => {
            // NOTE(review): relies on `to_vec` yielding elements in the same order that
            // `reshape` consumes them, since the result is reshaped back to the input shape.
            let mut data = array.to_vec();
            let n = data.len();
            if kth >= n {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "kth ({}) is out of bounds for array of size {}",
                    kth, n
                )));
            }
            quick_select(&mut data, 0, n - 1, kth);
            Ok(Array::from_vec(data).reshape(&array.shape()))
        }
        Some(axis_val) => {
            let shape = array.shape();
            if axis_val >= shape.len() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    axis_val,
                    shape.len()
                )));
            }
            let axis_size = shape[axis_val];
            // Also rejects an empty axis, which guarantees `axis_size - 1` below is valid.
            if kth >= axis_size {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "kth ({}) is out of bounds for axis {} with size {}",
                    kth, axis_val, axis_size
                )));
            }
            let mut result = array.clone();
            let result_vec = result.array_mut().as_slice_mut().ok_or_else(|| {
                NumRs2Error::InvalidOperation("Failed to get mutable slice".into())
            })?;
            // Row-major addressing: the lane for (i_pre, i_post) starts at
            // `i_pre * lane_block + i_post` and its elements are `post_axis_size` apart.
            let pre_axis_size: usize = shape.iter().take(axis_val).product();
            let post_axis_size: usize = shape.iter().skip(axis_val + 1).product();
            let lane_block = axis_size * post_axis_size;
            // One scratch buffer reused for every lane instead of allocating per lane.
            let mut lane: Vec<T> = Vec::with_capacity(axis_size);
            for i_pre in 0..pre_axis_size {
                for i_post in 0..post_axis_size {
                    let start = i_pre * lane_block + i_post;
                    lane.extend(
                        (0..axis_size)
                            .map(|i_axis| result_vec[start + i_axis * post_axis_size].clone()),
                    );
                    quick_select(&mut lane, 0, axis_size - 1, kth);
                    // Move the partitioned values back rather than cloning them a second time;
                    // `drain` also leaves `lane` empty for the next iteration.
                    for (i_axis, value) in lane.drain(..).enumerate() {
                        result_vec[start + i_axis * post_axis_size] = value;
                    }
                }
            }
            Ok(result)
        }
    }
}
/// Rearranges `arr[left..=right]` in place so that `arr[k]` holds the value it would have if
/// that range were sorted. For totally ordered data, nothing before it compares greater and
/// nothing after it compares less.
///
/// Callers must ensure `left <= k <= right < arr.len()`.
///
/// Two properties keep this safe on large inputs:
/// * It is iterative, so stack depth is constant regardless of input.
/// * A run of elements equal to the pivot is settled or discarded as a block. Without this,
///   inputs with many duplicates (e.g. all-equal data) would shrink the range by only one
///   element per step, giving O(n^2) time.
fn quick_select<T: Clone + PartialOrd>(arr: &mut [T], left: usize, right: usize, k: usize) {
    let mut lo = left;
    let mut hi = right;
    while lo < hi {
        let pivot_idx = choose_pivot(arr, lo, hi);
        let pivot_idx = partition_around_pivot(arr, lo, hi, pivot_idx);
        if k < pivot_idx {
            // `k >= lo` implies `pivot_idx >= 1`, so this cannot underflow.
            hi = pivot_idx - 1;
            continue;
        }
        if k == pivot_idx {
            return;
        }
        // After `partition_around_pivot`, nothing in `pivot_idx + 1..=hi` compared less than
        // the pivot. Pack the elements equal to it (by `==`, so NaN is never grouped) directly
        // after the pivot. Those occupy their sorted ranks, so if `k` falls among them we are
        // done; otherwise they are excluded from the next round.
        let mut eq_end = pivot_idx;
        for i in pivot_idx + 1..=hi {
            if arr[i] == arr[pivot_idx] {
                eq_end += 1;
                arr.swap(i, eq_end);
            }
        }
        if k <= eq_end {
            return;
        }
        lo = eq_end + 1;
    }
}
/// Picks a pivot index for `arr[left..=right]`.
///
/// Uses the median of the first, middle and last elements. Ranges of one or two elements
/// just use `left`.
fn choose_pivot<T: PartialOrd>(arr: &[T], left: usize, right: usize) -> usize {
    if right - left < 2 {
        return left;
    }
    let mid = left + (right - left) / 2;
    let mut indices = [left, mid, right];
    // Three compare-and-swaps sort the candidate indices by their values.
    if arr[indices[0]] > arr[indices[1]] {
        indices.swap(0, 1);
    }
    if arr[indices[1]] > arr[indices[2]] {
        indices.swap(1, 2);
    }
    if arr[indices[0]] > arr[indices[1]] {
        indices.swap(0, 1);
    }
    indices[1]
}
/// Lomuto partition of `arr[left..=right]` around the element at `pivot_idx`.
///
/// Returns the pivot's final index `p`. Afterwards, `arr[left..p]` holds exactly the elements
/// that compared `<` the pivot, and `arr[p + 1..=right]` holds all the others.
fn partition_around_pivot<T: Clone + PartialOrd>(
    arr: &mut [T],
    left: usize,
    right: usize,
    pivot_idx: usize,
) -> usize {
    // Park the pivot at `right`. Every swap in the scan uses indices strictly below `right`,
    // so the pivot can be compared in place instead of cloned.
    arr.swap(pivot_idx, right);
    let mut store_idx = left;
    for i in left..right {
        if arr[i] < arr[right] {
            arr.swap(i, store_idx);
            store_idx += 1;
        }
    }
    arr.swap(store_idx, right);
    store_idx
}
/// For every element of `v`, finds the index at which it would be inserted into the
/// ascending sequence `a` so that the sequence stays sorted.
///
/// # Arguments
/// * `a` - the sorted haystack. A multi-dimensional `a` is flattened first.
/// * `v` - values to look up; the result has the same shape as `v`.
/// * `side` - `"left"` (the default) returns the first valid index, before any equal
///   elements. `"right"` returns the last valid index, after any equal elements.
/// * `sorter` - optional 1-D array of indices into `a`'s elements (as returned by
///   `to_vec`). Reading `a` in this order must yield an ascending sequence.
///
/// # Errors
/// Returns `NumRs2Error::InvalidOperation` in any of these cases:
/// * `side` is neither `"left"` nor `"right"`;
/// * `sorter` is not 1-D, its size differs from `a`'s, or it contains an out-of-range index;
/// * the (reordered) haystack is not ascending. The check uses `<`, so incomparable values
///   such as NaN are not detected by it.
pub fn searchsorted<T: Clone + PartialOrd>(
    a: &Array<T>,
    v: &Array<T>,
    side: Option<&str>,
    sorter: Option<&Array<usize>>,
) -> Result<Array<usize>> {
    let side = side.unwrap_or("left");
    let search_right = match side {
        "left" => false,
        "right" => true,
        _ => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Side '{}' is invalid, must be 'left' or 'right'",
                side
            )));
        }
    };

    let ordered = match sorter {
        None => a.clone(),
        Some(order) => {
            if order.ndim() != 1 {
                return Err(NumRs2Error::InvalidOperation(
                    "Sorter array must be 1-dimensional".into(),
                ));
            }
            if order.size() != a.size() {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Sorter size ({}) does not match array size ({})",
                    order.size(),
                    a.size()
                )));
            }
            let source = a.to_vec();
            // Gather `a` in sorter order; the first out-of-range index aborts with an error.
            let gathered = order
                .to_vec()
                .into_iter()
                .map(|idx| {
                    source.get(idx).cloned().ok_or_else(|| {
                        NumRs2Error::InvalidOperation(format!(
                            "Sorter index {} out of range for array of size {}",
                            idx,
                            source.len()
                        ))
                    })
                })
                .collect::<Result<Vec<T>>>()?;
            Array::from_vec(gathered)
        }
    };

    let flat = if ordered.ndim() == 1 {
        ordered
    } else {
        ordered.flatten(None)
    };
    let haystack = flat.to_vec();

    if haystack.windows(2).any(|pair| pair[1] < pair[0]) {
        return Err(NumRs2Error::InvalidOperation(
            "The input array must be sorted in ascending order".into(),
        ));
    }

    let positions: Vec<usize> = v
        .to_vec()
        .iter()
        .map(|needle| {
            if search_right {
                binary_search_right(&haystack, needle)
            } else {
                binary_search_left(&haystack, needle)
            }
        })
        .collect();
    Ok(Array::from_vec(positions).reshape(&v.shape()))
}
/// Leftmost insertion point for `value` in the ascending slice `arr`.
///
/// Returns the first index whose element does not compare `<` `value`, or `arr.len()` if
/// every element does.
fn binary_search_left<T: PartialOrd>(arr: &[T], value: &T) -> usize {
    // Invariant: elements before `lo` are `< value`; elements from `hi` on are not.
    let (mut lo, mut hi) = (0, arr.len());
    loop {
        if lo >= hi {
            break lo;
        }
        let probe = lo + (hi - lo) / 2;
        if arr[probe] < *value {
            lo = probe + 1;
        } else {
            hi = probe;
        }
    }
}
/// Rightmost insertion point for `value` in the ascending slice `arr`.
///
/// Returns the first index whose element compares `>` `value`, i.e. the position just past
/// any run of equal elements, or `arr.len()` if there is none.
fn binary_search_right<T: PartialOrd>(arr: &[T], value: &T) -> usize {
    // Invariant: elements before `lo` are not `> value`; elements from `hi` on are.
    let mut lo = 0;
    let mut hi = arr.len();
    while hi > lo {
        let probe = lo + (hi - lo) / 2;
        if *value < arr[probe] {
            hi = probe;
        } else {
            lo = probe + 1;
        }
    }
    lo
}