use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{ArrayBase, Data, DataMut, Ix1};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::simd_ops::{AutoOptimizer, SimdUnifiedOps};
/// Selects the element that would sit at index `k` if `arr` were sorted in
/// ascending order.
///
/// The slice is reordered in place. A search window `[lo, hi]` is narrowed
/// around `k` using the index reported by `partition_simd`; the loop stops as
/// soon as that index equals `k` or the window has shrunk to a single slot,
/// and `arr[k]` is returned.
///
/// A one-element slice short-circuits and yields its only element without
/// looking at `k`.
///
/// # Panics
///
/// Panics when `arr` is empty, or when `k` is out of bounds for a slice that
/// holds more than one element.
#[allow(dead_code)]
pub fn quickselect_simd<F>(arr: &mut [F], k: usize) -> F
where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
{
    // Single-element shortcut: nothing to partition.
    if let [only] = *arr {
        return only;
    }

    let (mut lo, mut hi) = (0, arr.len() - 1);
    let optimizer = AutoOptimizer::new();

    while lo < hi {
        let split = partition_simd(arr, lo, hi, &optimizer);
        match k.cmp(&split) {
            // The partition index is exactly the slot we are looking for.
            std::cmp::Ordering::Equal => break,
            // `k < split` implies `split >= 1`, so the subtraction is safe.
            std::cmp::Ordering::Less => hi = split - 1,
            std::cmp::Ordering::Greater => lo = split + 1,
        }
    }

    arr[k]
}
/// Partitions `arr[left..=right]` around a median-of-three pivot and returns
/// the pivot's final index.
///
/// With `p` the returned index, the following holds on return (for inputs
/// without NaN):
///
/// * `left <= p <= right`,
/// * every element of `arr[left..p]` is `<= arr[p]`,
/// * every element of `arr[p + 1..=right]` is `>= arr[p]`.
///
/// `quickselect_simd` and `introsort_simd` both exclude slot `p` from further
/// work, so `arr[p]` must already be the value that belongs there. To
/// guarantee that, the pivot is parked at `left` while the rest of the range
/// is partitioned and is swapped into place at the end.
///
/// Elements outside `left..=right` are not touched.
///
/// `optimizer.should_use_simd(len)` decides whether the scans may inspect the
/// range in blocks of eight elements; both scan variants stop at the same
/// positions for NaN-free data. The block scan is ordinary scalar code, no
/// explicit SIMD intrinsics are used here.
///
/// # Panics
///
/// Panics if `left > right` or `right >= arr.len()`.
#[allow(dead_code)]
fn partition_simd<F>(arr: &mut [F], left: usize, right: usize, optimizer: &AutoOptimizer) -> usize
where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
{
    // Number of elements inspected per step of the block-wise scan.
    const CHUNK: usize = 8;

    let mid = left + (right - left) / 2;
    let pivot = median_of_three(arr[left], arr[mid], arr[right]);

    // Find which of the three sampled slots holds the pivot and park it at
    // `left`, out of the way of the two scanning cursors.
    let pivot_slot = if arr[left] == pivot {
        left
    } else if arr[mid] == pivot {
        mid
    } else {
        right
    };
    arr.swap(left, pivot_slot);
    // Re-read so the comparisons below use exactly the value stored at `left`
    // (this matters only when NaN defeated the equality checks above).
    let pivot = arr[left];

    let use_simd = optimizer.should_use_simd(right - left + 1);

    // Invariant: arr[left + 1..i] <= pivot and arr[j + 1..=right] >= pivot.
    let mut i = left + 1;
    let mut j = right;
    loop {
        // `j >= i` is checked first so that `j - i` cannot underflow once the
        // cursors have crossed.
        if use_simd && j >= i && j - i > CHUNK {
            // Advance `i` to the first element >= pivot, one block at a time.
            while i <= j {
                let width = (j - i + 1).min(CHUNK);
                match arr[i..i + width].iter().position(|&v| v >= pivot) {
                    Some(offset) => {
                        i += offset;
                        break;
                    }
                    None => i += width,
                }
            }
            // Retreat `j` to the last element <= pivot, one block at a time.
            while i <= j {
                let width = (j - i + 1).min(CHUNK);
                let start = j + 1 - width;
                match arr[start..=j].iter().rposition(|&v| v <= pivot) {
                    Some(offset) => {
                        j = start + offset;
                        break;
                    }
                    // `start >= i >= left + 1`, so this cannot underflow.
                    None => j = start - 1,
                }
            }
        } else {
            while i <= j && arr[i] < pivot {
                i += 1;
            }
            // `j >= i >= left + 1` whenever the body runs, so `j` stays >= left.
            while i <= j && arr[j] > pivot {
                j -= 1;
            }
        }

        if i >= j {
            break;
        }
        arr.swap(i, j);
        i += 1;
        j -= 1;
    }

    // `j` now marks the last slot whose value is <= pivot (or `left` itself);
    // dropping the pivot there puts it at its final sorted position.
    arr.swap(left, j);
    j
}
/// Returns the median of `a`, `b` and `c`.
///
/// The choice is driven by the three pairwise `<=` comparisons
/// `a <= b`, `b <= c` and `a <= c`; the returned value is always one of the
/// three arguments.
#[allow(dead_code)]
fn median_of_three<F: Float>(a: F, b: F, c: F) -> F {
    match (a <= b, b <= c, a <= c) {
        // a <= b <= c
        (true, true, _) => b,
        // a <= c < b
        (true, false, true) => c,
        // c < a <= b
        (true, false, false) => a,
        // b < a <= c
        (false, _, true) => a,
        // b <= c < a
        (false, true, false) => c,
        // c < b < a
        (false, false, false) => b,
    }
}
/// Computes the `q`-quantile of a one-dimensional array.
///
/// The fractional position `q * (n - 1)` is bracketed by two order
/// statistics, which are located with `quickselect_simd` and then combined
/// according to `method`:
///
/// * `"linear"` - linear interpolation between the two neighbours,
/// * `"lower"` / `"higher"` - the lower / upper neighbour,
/// * `"midpoint"` - the mean of both neighbours,
/// * `"nearest"` - the lower neighbour when the fractional part is below
///   0.5, otherwise the upper one.
///
/// `method` is only inspected when the position falls strictly between two
/// indices. For a single-element array, for `q` equal to 0 or 1, and for an
/// integral position the result is returned without looking at it.
///
/// # Side effects
///
/// When `x.as_slice_mut()` yields a slice, the selection reorders the elements
/// of `x` in place. When it does not (ndarray documents this for storage that
/// is not contiguous in standard order), the selection runs on a temporary
/// copy and `x` is left untouched.
///
/// # Errors
///
/// Returns an invalid-argument error when `x` is empty, when `q` is NaN or
/// outside `[0, 1]`, or when interpolation is required and `method` is not
/// one of the names listed above.
#[allow(dead_code)]
pub fn quantile_simd<F, D>(x: &mut ArrayBase<D, Ix1>, q: F, method: &str) -> StatsResult<F>
where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
    D: DataMut<Elem = F>,
{
    let n = x.len();
    if n == 0 {
        return Err(StatsError::invalid_argument(
            "Cannot compute quantile of empty array",
        ));
    }
    // NaN fails both ordered comparisons, so it has to be rejected explicitly;
    // otherwise it would reach the float-to-index conversion below and panic.
    if q.is_nan() || q < F::zero() || q > F::one() {
        return Err(StatsError::invalid_argument(
            "Quantile must be between 0 and 1",
        ));
    }
    if n == 1 {
        return Ok(x[0]);
    }

    // The extremes need no selection. Plain `<` / `>=` folds are used so that
    // NaN elements cannot trigger a panic the way `partial_cmp().expect(..)`
    // would; ties resolve to the first minimum / last maximum.
    if q == F::zero() {
        let smallest = x
            .iter()
            .copied()
            .fold(x[0], |best, v| if v < best { v } else { best });
        return Ok(smallest);
    }
    if q == F::one() {
        let largest = x
            .iter()
            .copied()
            .fold(x[0], |best, v| if v >= best { v } else { best });
        return Ok(largest);
    }

    // Work in place when the array exposes a slice, otherwise on a copy.
    let mut scratch: Vec<F>;
    let data: &mut [F] = match x.as_slice_mut() {
        Some(slice) => slice,
        None => {
            scratch = x.iter().copied().collect();
            scratch.as_mut_slice()
        }
    };

    // `q` is in (0, 1) here, so `pos` lies in [0, n - 1] and both index
    // conversions succeed.
    let pos = q * F::from(n - 1).expect("Failed to convert to float");
    let floor = pos.floor();
    let lower_idx = floor.to_usize().expect("Operation failed");
    let upper_idx = pos.ceil().to_usize().expect("Operation failed");

    if lower_idx == upper_idx {
        return Ok(quickselect_simd(data, lower_idx));
    }

    let fraction = pos - floor;
    let lower_val = quickselect_simd(data, lower_idx);
    let upper_val = quickselect_simd(data, upper_idx);
    match method {
        "linear" => Ok(lower_val + fraction * (upper_val - lower_val)),
        "lower" => Ok(lower_val),
        "higher" => Ok(upper_val),
        "midpoint" => Ok((lower_val + upper_val)
            / F::from(2.0).expect("Failed to convert constant to float")),
        "nearest" => {
            if fraction < F::from(0.5).expect("Failed to convert constant to float") {
                Ok(lower_val)
            } else {
                Ok(upper_val)
            }
        }
        _ => Err(StatsError::invalid_argument(format!(
            "Unknown interpolation method: {}",
            method
        ))),
    }
}
/// Computes several quantiles of a one-dimensional array in one call.
///
/// The result has one entry per element of `quantiles`, in the same order.
///
/// * No quantiles requested: an empty array is returned.
/// * Exactly one quantile: the work is delegated to `quantile_simd`.
/// * More than one: the data is sorted once with `simd_sort` and every
///   quantile is read off the sorted data with
///   `compute_quantile_from_sorted`, using `method` for interpolation.
///
/// # Side effects
///
/// With more than one quantile and an array for which `x.as_slice_mut()`
/// yields a slice, `x` is sorted in place. When no slice is available
/// (ndarray documents this for storage that is not contiguous in standard
/// order) a temporary copy is sorted instead and `x` is left untouched.
///
/// # Errors
///
/// Returns an invalid-argument error when `x` is empty or when any requested
/// quantile is NaN or outside `[0, 1]`, and forwards any error produced while
/// evaluating an individual quantile.
#[allow(dead_code)]
pub fn quantiles_simd<F, D1, D2>(
    x: &mut ArrayBase<D1, Ix1>,
    quantiles: &ArrayBase<D2, Ix1>,
    method: &str,
) -> StatsResult<scirs2_core::ndarray::Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
    D1: DataMut<Elem = F>,
    D2: Data<Elem = F>,
{
    if x.len() == 0 {
        return Err(StatsError::invalid_argument(
            "Cannot compute quantiles of empty array",
        ));
    }
    // NaN is checked explicitly because it passes both ordered comparisons.
    let any_invalid = quantiles
        .iter()
        .any(|&q| q.is_nan() || q < F::zero() || q > F::one());
    if any_invalid {
        return Err(StatsError::invalid_argument(
            "All quantiles must be between 0 and 1",
        ));
    }

    let mut results: scirs2_core::ndarray::Array1<F> =
        scirs2_core::ndarray::Array1::zeros(quantiles.len());

    match quantiles.len() {
        // Nothing requested: indexing `results[0]` here would panic.
        0 => {}
        // A single quantile is cheaper via selection than via a full sort.
        1 => results[0] = quantile_simd(x, quantiles[0], method)?,
        _ => {
            // Sort in place when the array exposes a slice, otherwise sort a copy.
            let mut scratch: Vec<F>;
            let data: &mut [F] = match x.as_slice_mut() {
                Some(slice) => slice,
                None => {
                    scratch = x.iter().copied().collect();
                    scratch.as_mut_slice()
                }
            };
            simd_sort(data);
            for (slot, &q) in results.iter_mut().zip(quantiles.iter()) {
                *slot = compute_quantile_from_sorted(data, q, method)?;
            }
        }
    }

    Ok(results)
}
/// Sorts `data` in place in ascending order.
///
/// Dispatch by length:
///
/// * 0 or 1 elements - nothing to do,
/// * up to 32 elements - `insertion_sort`,
/// * anything larger - `introsort_simd` over the whole slice with a depth
///   budget of `2 * floor(log2(len))`.
pub(crate) fn simd_sort<F>(data: &mut [F])
where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
{
    let optimizer = AutoOptimizer::new();
    match data.len() {
        0 | 1 => {}
        2..=32 => insertion_sort(data),
        len => {
            let depth_budget = (len.ilog2() * 2) as usize;
            introsort_simd(data, 0, len - 1, depth_budget, &optimizer);
        }
    }
}
/// Sorts `data` in place in ascending order with insertion sort.
///
/// For each position `end`, the elements of the already sorted prefix that
/// are strictly greater than `data[end]` are moved one slot to the right and
/// the element is placed in the gap. Equal elements keep their relative order.
#[allow(dead_code)]
fn insertion_sort<F: Float>(data: &mut [F]) {
    for end in 1..data.len() {
        let key = data[end];
        // Length of the run, directly left of `end`, that is greater than `key`.
        let shift = data[..end]
            .iter()
            .rev()
            .take_while(|&&value| value > key)
            .count();
        // Rotating that run together with `key` by one slot moves `key` in
        // front of it and every run element one position to the right.
        data[end - shift..=end].rotate_right(1);
    }
}
/// Recursively sorts `data[left..=right]` in place.
///
/// * Ranges with at most one element are left alone.
/// * Ranges of up to 16 elements are finished with `insertion_sort`.
/// * When `depth_limit` has reached zero the range is handed to `heapsort`.
/// * Otherwise the range is split with `partition_simd` and the parts on
///   either side of the returned index are sorted recursively with
///   `depth_limit - 1`; the slot at the returned index itself is not revisited.
#[allow(dead_code)]
fn introsort_simd<F>(
    data: &mut [F],
    left: usize,
    right: usize,
    depth_limit: usize,
    optimizer: &AutoOptimizer,
) where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
{
    if right <= left {
        return;
    }

    let len = right - left + 1;
    let small = len <= 16;
    if small || depth_limit == 0 {
        let window = &mut data[left..=right];
        if small {
            insertion_sort(window);
        } else {
            heapsort(window);
        }
        return;
    }

    let split = partition_simd(data, left, right, optimizer);
    let next_depth = depth_limit - 1;
    // The guards keep `split - 1` from underflowing and skip empty halves.
    if split > left {
        introsort_simd(data, left, split - 1, next_depth, optimizer);
    }
    if split < right {
        introsort_simd(data, split + 1, right, next_depth, optimizer);
    }
}
/// Sorts `data` in place in ascending order with heapsort.
///
/// Phase one calls `heapify` on every parent node from the last one down to
/// the root. Phase two repeatedly swaps the root with the last element of the
/// still-unsorted prefix and calls `heapify` on the shortened prefix.
#[allow(dead_code)]
fn heapsort<F: Float>(data: &mut [F]) {
    let len = data.len();

    // Build phase: parents live at indices 0..len / 2, visited back to front.
    let mut parent = len / 2;
    while parent > 0 {
        parent -= 1;
        heapify(data, len, parent);
    }

    // Extraction phase: `end` is the length of the prefix still to be sorted.
    let mut end = len;
    while end > 1 {
        end -= 1;
        data.swap(0, end);
        heapify(data, end, 0);
    }
}
/// Sifts the element at index `i` down within the first `n` elements of
/// `data`, treated as a binary max-heap (children of `k` at `2k + 1` and
/// `2k + 2`).
///
/// At each step the node is compared with its left child and then its right
/// child; if one of them is strictly greater, the node is swapped with the
/// greatest of the three and the walk continues from that child's slot. The
/// walk stops as soon as no child inside the first `n` elements is greater.
#[allow(dead_code)]
fn heapify<F: Float>(data: &mut [F], n: usize, i: usize) {
    let mut node = i;
    loop {
        let mut largest = node;
        // Left child first, then right, each compared against the current best.
        for child in [2 * node + 1, 2 * node + 2] {
            if child < n && data[child] > data[largest] {
                largest = child;
            }
        }
        if largest == node {
            return;
        }
        data.swap(node, largest);
        node = largest;
    }
}
/// Reads the `q`-quantile off a slice that is already sorted ascending.
///
/// The fractional position `q * (n - 1)` is bracketed by two neighbouring
/// elements which are combined according to `method`:
///
/// * `"linear"` - linear interpolation between the two neighbours,
/// * `"lower"` / `"higher"` - the lower / upper neighbour,
/// * `"midpoint"` - the mean of both neighbours,
/// * `"nearest"` - the lower neighbour when the fractional part is below
///   0.5, otherwise the upper one.
///
/// `method` is only inspected when the position falls strictly between two
/// indices; for `q` equal to 0 or 1 and for an integral position the element
/// is returned directly.
///
/// The slice is not checked for being sorted.
///
/// # Errors
///
/// Returns an invalid-argument error when `sorteddata` is empty, when `q` is
/// NaN or outside `[0, 1]`, or when interpolation is required and `method` is
/// not one of the names listed above.
#[allow(dead_code)]
fn compute_quantile_from_sorted<F>(sorteddata: &[F], q: F, method: &str) -> StatsResult<F>
where
    F: Float + NumCast + std::fmt::Display,
{
    // An empty slice has no first/last element and would underflow `n - 1`.
    let (first, last) = match (sorteddata.first(), sorteddata.last()) {
        (Some(&first), Some(&last)) => (first, last),
        _ => {
            return Err(StatsError::invalid_argument(
                "Cannot compute quantile of empty array",
            ))
        }
    };
    // NaN passes both ordered comparisons, so it is rejected explicitly;
    // without this guard the index conversion below would panic.
    if q.is_nan() || q < F::zero() || q > F::one() {
        return Err(StatsError::invalid_argument(
            "Quantile must be between 0 and 1",
        ));
    }
    if q == F::zero() {
        return Ok(first);
    }
    if q == F::one() {
        return Ok(last);
    }

    // `q` is in (0, 1) here, so `pos` lies in [0, n - 1] and both index
    // conversions succeed.
    let n = sorteddata.len();
    let pos = q * F::from(n - 1).expect("Failed to convert to float");
    let floor = pos.floor();
    let lower_idx = floor.to_usize().expect("Operation failed");
    let upper_idx = pos.ceil().to_usize().expect("Operation failed");

    if lower_idx == upper_idx {
        return Ok(sorteddata[lower_idx]);
    }

    let fraction = pos - floor;
    let lower_val = sorteddata[lower_idx];
    let upper_val = sorteddata[upper_idx];
    match method {
        "linear" => Ok(lower_val + fraction * (upper_val - lower_val)),
        "lower" => Ok(lower_val),
        "higher" => Ok(upper_val),
        "midpoint" => Ok((lower_val + upper_val)
            / F::from(2.0).expect("Failed to convert constant to float")),
        "nearest" => {
            if fraction < F::from(0.5).expect("Failed to convert constant to float") {
                Ok(lower_val)
            } else {
                Ok(upper_val)
            }
        }
        _ => Err(StatsError::invalid_argument(format!(
            "Unknown interpolation method: {}",
            method
        ))),
    }
}
/// Computes the median of a one-dimensional array.
///
/// This is `quantile_simd` evaluated at `q = 0.5` with the `"linear"`
/// interpolation method; its result, including any error and any reordering
/// of `x` it performs, is passed through unchanged.
#[allow(dead_code)]
pub fn median_simd<F, D>(x: &mut ArrayBase<D, Ix1>) -> StatsResult<F>
where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
    D: DataMut<Elem = F>,
{
    let half = F::from(0.5).expect("Failed to convert constant to float");
    quantile_simd(x, half, "linear")
}
/// Computes the `p`-th percentile of a one-dimensional array.
///
/// `p` is expressed on a 0-100 scale; after validation the call is forwarded
/// to `quantile_simd` with `q = p / 100` and the given interpolation
/// `method`. Any reordering of `x` and any error from that call are passed
/// through.
///
/// # Errors
///
/// Returns an invalid-argument error when `p` is NaN or outside `[0, 100]`,
/// and forwards every error reported by `quantile_simd`.
#[allow(dead_code)]
pub fn percentile_simd<F, D>(x: &mut ArrayBase<D, Ix1>, p: F, method: &str) -> StatsResult<F>
where
    F: Float + NumCast + SimdUnifiedOps + std::fmt::Display,
    D: DataMut<Elem = F>,
{
    let hundred = F::from(100.0).expect("Failed to convert constant to float");
    // NaN passes both ordered comparisons, so it is rejected explicitly here
    // rather than being forwarded as a NaN quantile.
    if p.is_nan() || p < F::zero() || p > hundred {
        return Err(StatsError::invalid_argument(
            "Percentile must be between 0 and 100",
        ));
    }
    quantile_simd(x, p / hundred, method)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Index 4 of the sorted digits 1..=9 is 5.
    #[test]
    fn test_quickselect_simd() {
        let mut data = vec![5.0, 3.0, 7.0, 1.0, 9.0, 2.0, 8.0, 4.0, 6.0];
        let fifth = quickselect_simd(&mut data, 4);
        assert_relative_eq!(fifth, 5.0, epsilon = 1e-10);
    }

    /// Median and both quartiles of 1..=9 with linear interpolation.
    #[test]
    fn test_quantile_simd() {
        let mut data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        for (q, expected) in [(0.5, 5.0), (0.25, 3.0), (0.75, 7.0)] {
            let actual =
                quantile_simd(&mut data.view_mut(), q, "linear").expect("Operation failed");
            assert_relative_eq!(actual, expected, epsilon = 1e-10);
        }
    }

    /// Five quantiles of 1..=10 computed in a single call.
    #[test]
    fn test_quantiles_simd() {
        let mut data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let quantiles = array![0.1, 0.25, 0.5, 0.75, 0.9];
        let results = quantiles_simd(&mut data.view_mut(), &quantiles.view(), "linear")
            .expect("Operation failed");

        let expected = [1.9, 3.25, 5.5, 7.75, 9.1];
        for (idx, want) in expected.iter().enumerate() {
            assert_relative_eq!(results[idx], *want, epsilon = 1e-10);
        }
    }

    /// A shuffled 1..=9 comes back in ascending order.
    #[test]
    fn test_simd_sort() {
        let mut data = vec![9.0, 3.0, 7.0, 1.0, 5.0, 8.0, 2.0, 6.0, 4.0];
        simd_sort(&mut data);

        let expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        for (got, want) in data.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }
}