/// Sorts `data` in place by delegating to [`quicksort_f32_simd`].
///
/// This wrapper exists to expose a `Result`-returning API. The body never constructs an
/// error, so it always returns `Ok(())`.
pub fn quicksort(data: &mut [f32]) -> Result<(), crate::traits::SimdError> {
quicksort_f32_simd(data);
Ok(())
}
/// Sorts `arr` in ascending order, in place, picking the widest SIMD kernel available.
///
/// On x86/x86_64:
/// - the AVX2 kernel handles slices of at least 16 elements when AVX2 is detected;
/// - otherwise the SSE2 kernel handles slices of at least 8 elements when SSE2 is detected.
///
/// Every other case uses the scalar quicksort.
pub fn quicksort_f32_simd(arr: &mut [f32]) {
    let n = arr.len();
    if n < 2 {
        // Nothing to order.
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if crate::simd_feature_detected!("avx2") && n >= 16 {
            // SAFETY: AVX2 support was confirmed at runtime just above.
            unsafe { quicksort_avx2(arr) };
            return;
        }
        if crate::simd_feature_detected!("sse2") && n >= 8 {
            // SAFETY: SSE2 support was confirmed at runtime just above.
            unsafe { quicksort_sse2(arr) };
            return;
        }
    }

    quicksort_scalar(arr);
}
/// Sorts `arr` ascending in place with a scalar quicksort (last-element Lomuto pivot).
///
/// Only the smaller partition is sorted recursively; the larger one is handled by the loop.
/// This keeps stack depth at O(log n) even when the last-element pivot degenerates, e.g. on
/// already-sorted input. Worst-case running time is still O(n^2).
fn quicksort_scalar(mut arr: &mut [f32]) {
    while arr.len() > 1 {
        let pivot_index = partition_scalar(arr);
        // Move the borrow out of `arr` so both halves keep its full lifetime and one of them
        // can be assigned back to it.
        let whole = core::mem::take(&mut arr);
        let (lower, rest) = whole.split_at_mut(pivot_index);
        // rest[0] is the pivot, already in its final position.
        let upper = &mut rest[1..];
        if lower.len() < upper.len() {
            quicksort_scalar(lower);
            arr = upper;
        } else {
            quicksort_scalar(upper);
            arr = lower;
        }
    }
}
/// Lomuto-partitions `arr` around its last element and returns the pivot's final index `p`.
///
/// Afterwards `arr[..p]` holds the elements that compared `<= pivot`, `arr[p]` is the pivot,
/// and the rest follow. `arr` must be non-empty.
fn partition_scalar(arr: &mut [f32]) -> usize {
    let last = arr.len() - 1;
    // Keep the pivot slot separate from the region being scanned.
    let (body, tail) = arr.split_at_mut(last);
    let pivot = tail[0];

    let mut boundary = 0;
    for cursor in 0..body.len() {
        if body[cursor] <= pivot {
            body.swap(boundary, cursor);
            boundary += 1;
        }
    }

    // When boundary == last the pivot is already in place; otherwise exchange it with the
    // first element that did not compare <= pivot.
    if let Some(slot) = body.get_mut(boundary) {
        core::mem::swap(slot, &mut tail[0]);
    }
    boundary
}
/// Sorts `arr` ascending in place with a quicksort whose partition step uses SSE2.
///
/// Slices of at most 8 elements are finished by insertion sort. Only the smaller partition is
/// sorted recursively and the larger one is handled by the loop, so stack depth stays
/// O(log n) even when the last-element pivot degenerates (e.g. on already-sorted input).
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn quicksort_sse2(mut arr: &mut [f32]) {
    loop {
        if arr.len() <= 8 {
            insertion_sort_simd_sse2(arr);
            return;
        }
        let pivot_index = partition_sse2(arr);
        // Move the borrow out of `arr` so both halves keep its full lifetime and one of them
        // can be assigned back to it.
        let whole = core::mem::take(&mut arr);
        let (lower, rest) = whole.split_at_mut(pivot_index);
        // rest[0] is the pivot, already in its final position.
        let upper = &mut rest[1..];
        if lower.len() < upper.len() {
            quicksort_sse2(lower);
            arr = upper;
        } else {
            quicksort_sse2(upper);
            arr = lower;
        }
    }
}
/// Partitions `arr` around its last element, using SSE2 to classify 4 elements at a time.
///
/// Returns the pivot's final index `p`. For NaN-free input:
/// - every element of `arr[..p]` is `<= pivot`;
/// - every element of `arr[p + 1..]` is `> pivot`.
///
/// NaN placement is unspecified.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2. `arr` must be non-empty (indexing panics
/// otherwise).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn partition_sse2(arr: &mut [f32]) -> usize {
    // The intrinsics live in a different module on each architecture.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = arr.len();
    let pivot = arr[len - 1];
    let pivot_vec = _mm_set1_ps(pivot);
    // Invariant:
    // - arr[..left] <= pivot;
    // - arr[right..len - 1] > pivot;
    // - arr[left..right] is not yet classified.
    let mut left = 0;
    let mut right = len - 1;
    while left + 4 <= right {
        // SAFETY: left + 4 <= right <= len - 1, so all 4 lanes are in bounds. The pointer is
        // derived from the whole slice (not a single-element reference), so it is valid for
        // the full 16-byte read.
        let left_vec = _mm_loadu_ps(arr.as_ptr().add(left));
        let mask = _mm_movemask_ps(_mm_cmple_ps(left_vec, pivot_vec));
        // Lane `lane` describes the element that is at `left` when that lane is processed.
        // Swaps only write at `left` (being classified now) or at the new `right` boundary,
        // so lanes not yet visited are never modified.
        for lane in 0..4 {
            if left >= right {
                break;
            }
            if (mask & (1 << lane)) != 0 {
                left += 1;
                continue;
            }
            // arr[left] belongs on the right: find the nearest unclassified element from the
            // right that belongs on the left and exchange the two.
            while right > left && arr[right - 1] > pivot {
                right -= 1;
            }
            if right > left {
                arr.swap(left, right - 1);
                right -= 1;
                left += 1;
            }
        }
    }
    // Classify every remaining element. Stopping only when the window is empty ensures no
    // element is left unexamined, and `right` is never decremented below `left`.
    while left < right {
        if arr[left] <= pivot {
            left += 1;
        } else {
            right -= 1;
            arr.swap(left, right);
        }
    }
    arr.swap(left, len - 1);
    left
}
/// Sorts `arr` ascending in place with a quicksort whose partition step uses AVX2.
///
/// Slices of at most 16 elements are finished by insertion sort. Only the smaller partition is
/// sorted recursively and the larger one is handled by the loop, so stack depth stays
/// O(log n) even when the last-element pivot degenerates (e.g. on already-sorted input).
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn quicksort_avx2(mut arr: &mut [f32]) {
    loop {
        if arr.len() <= 16 {
            insertion_sort_simd_avx2(arr);
            return;
        }
        let pivot_index = partition_avx2(arr);
        // Move the borrow out of `arr` so both halves keep its full lifetime and one of them
        // can be assigned back to it.
        let whole = core::mem::take(&mut arr);
        let (lower, rest) = whole.split_at_mut(pivot_index);
        // rest[0] is the pivot, already in its final position.
        let upper = &mut rest[1..];
        if lower.len() < upper.len() {
            quicksort_avx2(lower);
            arr = upper;
        } else {
            quicksort_avx2(upper);
            arr = lower;
        }
    }
}
/// Partitions `arr` around its last element, using AVX2 to classify 8 elements at a time.
///
/// Returns the pivot's final index `p`. For NaN-free input:
/// - every element of `arr[..p]` is `<= pivot`;
/// - every element of `arr[p + 1..]` is `> pivot`.
///
/// NaN placement is unspecified.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2. `arr` must be non-empty (indexing panics
/// otherwise).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn partition_avx2(arr: &mut [f32]) -> usize {
    // The intrinsics live in a different module on each architecture.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = arr.len();
    let pivot = arr[len - 1];
    let pivot_vec = _mm256_set1_ps(pivot);
    // Invariant:
    // - arr[..left] <= pivot;
    // - arr[right..len - 1] > pivot;
    // - arr[left..right] is not yet classified.
    let mut left = 0;
    let mut right = len - 1;
    while left + 8 <= right {
        // SAFETY: left + 8 <= right <= len - 1, so all 8 lanes are in bounds. The pointer is
        // derived from the whole slice (not a single-element reference), so it is valid for
        // the full 32-byte read.
        let left_vec = _mm256_loadu_ps(arr.as_ptr().add(left));
        let cmp_mask = _mm256_cmp_ps(left_vec, pivot_vec, _CMP_LE_OQ);
        let mask = _mm256_movemask_ps(cmp_mask);
        // Lane `lane` describes the element that is at `left` when that lane is processed.
        // Swaps only write at `left` (being classified now) or at the new `right` boundary,
        // so lanes not yet visited are never modified.
        for lane in 0..8 {
            if left >= right {
                break;
            }
            if (mask & (1 << lane)) != 0 {
                left += 1;
                continue;
            }
            // arr[left] belongs on the right: find the nearest unclassified element from the
            // right that belongs on the left and exchange the two.
            while right > left && arr[right - 1] > pivot {
                right -= 1;
            }
            if right > left {
                arr.swap(left, right - 1);
                right -= 1;
                left += 1;
            }
        }
    }
    // Classify every remaining element. Stopping only when the window is empty ensures no
    // element is left unexamined, and `right` is never decremented below `left`.
    while left < right {
        if arr[left] <= pivot {
            left += 1;
        } else {
            right -= 1;
            arr.swap(left, right);
        }
    }
    arr.swap(left, len - 1);
    left
}
/// Sorts `arr` ascending in place with insertion sort; used as the SSE2 quicksort base case.
///
/// When `arr[i]` is inserted, `arr[..i]` is already sorted, so comparing against `arr[i - 1]`
/// alone decides whether any shifting is needed. A vector pre-check over the preceding
/// elements would add no information, so none is performed.
///
/// # Safety
///
/// Compiled with `#[target_feature(enable = "sse2")]`; the caller must ensure SSE2 is
/// available.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn insertion_sort_simd_sse2(arr: &mut [f32]) {
    for i in 1..arr.len() {
        let key = arr[i];
        let mut j = i;
        // Shift larger elements one slot right until the insertion point is found.
        while j > 0 && arr[j - 1] > key {
            arr[j] = arr[j - 1];
            j -= 1;
        }
        arr[j] = key;
    }
}
/// Sorts `arr` ascending in place with insertion sort; used as the AVX2 quicksort base case.
///
/// When `arr[i]` is inserted, `arr[..i]` is already sorted, so comparing against `arr[i - 1]`
/// alone decides whether any shifting is needed. A vector pre-check over the preceding
/// elements would add no information, so none is performed.
///
/// # Safety
///
/// Compiled with `#[target_feature(enable = "avx2")]`; the caller must ensure AVX2 is
/// available.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn insertion_sort_simd_avx2(arr: &mut [f32]) {
    for i in 1..arr.len() {
        let key = arr[i];
        let mut j = i;
        // Shift larger elements one slot right until the insertion point is found.
        while j > 0 && arr[j - 1] > key {
            arr[j] = arr[j - 1];
            j -= 1;
        }
        arr[j] = key;
    }
}
/// Sorts `arr` in place with a bitonic sorting network, ascending or descending.
///
/// On x86/x86_64:
/// - the AVX2 kernel handles slices of at least 8 elements when AVX2 is detected;
/// - otherwise the SSE2 kernel handles slices of at least 4 elements when SSE2 is detected.
///
/// Everything else uses the scalar network.
///
/// # Panics
///
/// Panics if `arr.len()` is greater than 1 and not a power of two. Empty and single-element
/// slices are already sorted and are accepted.
pub fn bitonic_sort_f32_simd(arr: &mut [f32], ascending: bool) {
    let len = arr.len();
    // Handle the trivial case before asserting. 0 is not a power of two, so asserting first
    // would reject an empty slice that needs no work.
    if len <= 1 {
        return;
    }
    assert!(
        len.is_power_of_two(),
        "Bitonic sort requires power-of-2 length"
    );
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if crate::simd_feature_detected!("avx2") && len >= 8 {
            // SAFETY: AVX2 support was confirmed at runtime just above.
            unsafe { bitonic_sort_avx2(arr, ascending) };
            return;
        } else if crate::simd_feature_detected!("sse2") && len >= 4 {
            // SAFETY: SSE2 support was confirmed at runtime just above.
            unsafe { bitonic_sort_sse2(arr, ascending) };
            return;
        }
    }
    bitonic_sort_scalar(arr, ascending);
}
/// Scalar bitonic sort of `arr` in the requested direction.
///
/// The first half is sorted ascending and the second descending, which makes the whole slice
/// a bitonic sequence; it is then merged in the requested direction. Intended for
/// power-of-two lengths.
fn bitonic_sort_scalar(arr: &mut [f32], ascending: bool) {
    match arr.len() {
        0 | 1 => {}
        2 => {
            // A single compare-exchange in the requested direction.
            if (arr[0] > arr[1]) == ascending {
                arr.swap(0, 1);
            }
        }
        n => {
            let half = n / 2;
            let (front, back) = arr.split_at_mut(half);
            bitonic_sort_scalar(front, true);
            bitonic_sort_scalar(back, false);
            bitonic_merge_scalar(arr, ascending);
        }
    }
}
/// Scalar bitonic merge of `arr` into the requested direction.
///
/// Each element of the lower half is compare-exchanged with its partner `len / 2` positions
/// later. Then both halves are merged recursively while a half has more than one element.
fn bitonic_merge_scalar(arr: &mut [f32], ascending: bool) {
    let half = arr.len() / 2;
    if half == 0 {
        // Zero or one element: already merged.
        return;
    }
    let (lower, upper) = arr.split_at_mut(half);
    for (a, b) in lower.iter_mut().zip(upper.iter_mut()) {
        if (*a > *b) == ascending {
            core::mem::swap(a, b);
        }
    }
    if half > 1 {
        bitonic_merge_scalar(lower, ascending);
        bitonic_merge_scalar(upper, ascending);
    }
}
/// Bitonic sort built on the SSE2 kernels.
///
/// The two halves are sorted in opposite directions, forming a bitonic sequence, and then
/// merged in the requested direction. Slices of at most 4 elements go straight to the
/// 4-element kernel.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn bitonic_sort_sse2(arr: &mut [f32], ascending: bool) {
    if arr.len() > 4 {
        let half = arr.len() / 2;
        let (front, back) = arr.split_at_mut(half);
        bitonic_sort_sse2(front, true);
        bitonic_sort_sse2(back, false);
        bitonic_merge_sse2(arr, ascending);
    } else {
        bitonic_sort_4_sse2(arr, ascending);
    }
}
/// Base case of the SSE2 bitonic sort: orders exactly 4 elements in place.
///
/// Other lengths are delegated to `bitonic_sort_scalar`. The sort uses `f32::total_cmp`, so
/// NaN inputs no longer cause a panic:
/// - ascending order places positive NaNs last and negative NaNs first;
/// - `-0.0` sorts before `0.0`.
///
/// The sort runs in place with no temporary copy.
///
/// # Safety
///
/// Compiled with `#[target_feature(enable = "sse2")]`; the caller must ensure SSE2 is
/// available.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn bitonic_sort_4_sse2(arr: &mut [f32], ascending: bool) {
    if arr.len() != 4 {
        bitonic_sort_scalar(arr, ascending);
        return;
    }
    if ascending {
        arr.sort_unstable_by(f32::total_cmp);
    } else {
        arr.sort_unstable_by(|a, b| b.total_cmp(a));
    }
}
/// Bitonic merge of `arr` into the requested direction, using SSE2 for the compare-exchanges.
///
/// Element `i` of the lower half is compare-exchanged with element `i + len / 2`. This is done
/// 4 lanes at a time with a branch-free mask blend, which swaps exactly the lanes the scalar
/// rule would. Leftover lanes are handled by a scalar loop, and both halves are then merged
/// recursively. Slices of at most 4 elements use `bitonic_merge_scalar`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn bitonic_merge_sse2(arr: &mut [f32], ascending: bool) {
    // The intrinsics live in a different module on each architecture.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = arr.len();
    if len <= 4 {
        bitonic_merge_scalar(arr, ascending);
        return;
    }
    let step = len / 2;
    let mut i = 0;
    {
        // One pointer derived from the whole slice, valid for every window read or written
        // below.
        let base = arr.as_mut_ptr();
        while i + 4 <= step {
            // SAFETY: i + 4 <= step and i + step + 4 <= 2 * step <= len, so both 4-lane
            // windows are in bounds. They do not overlap because step >= 4 here.
            let a = _mm_loadu_ps(base.add(i));
            let b = _mm_loadu_ps(base.add(i + step));
            let swap = if ascending {
                _mm_cmpgt_ps(a, b)
            } else {
                _mm_cmplt_ps(a, b)
            };
            // In lanes where `swap` is set, exchange a and b; elsewhere keep them.
            let low = _mm_or_ps(_mm_and_ps(swap, b), _mm_andnot_ps(swap, a));
            let high = _mm_or_ps(_mm_and_ps(swap, a), _mm_andnot_ps(swap, b));
            _mm_storeu_ps(base.add(i), low);
            _mm_storeu_ps(base.add(i + step), high);
            i += 4;
        }
    }
    while i < step {
        if (arr[i] > arr[i + step]) == ascending {
            arr.swap(i, i + step);
        }
        i += 1;
    }
    if step > 1 {
        bitonic_merge_sse2(&mut arr[0..step], ascending);
        bitonic_merge_sse2(&mut arr[step..], ascending);
    }
}
/// Bitonic sort built on the AVX2 kernels.
///
/// The two halves are sorted in opposite directions, forming a bitonic sequence, and then
/// merged in the requested direction. Slices of at most 8 elements go straight to the
/// 8-element kernel.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn bitonic_sort_avx2(arr: &mut [f32], ascending: bool) {
    if arr.len() > 8 {
        let half = arr.len() / 2;
        let (front, back) = arr.split_at_mut(half);
        bitonic_sort_avx2(front, true);
        bitonic_sort_avx2(back, false);
        bitonic_merge_avx2(arr, ascending);
    } else {
        bitonic_sort_8_avx2(arr, ascending);
    }
}
/// Base case of the AVX2 bitonic sort: orders exactly 8 elements in place.
///
/// Other lengths are delegated to `bitonic_sort_scalar`. The sort uses `f32::total_cmp`, so
/// NaN inputs no longer cause a panic:
/// - ascending order places positive NaNs last and negative NaNs first;
/// - `-0.0` sorts before `0.0`.
///
/// The sort runs in place with no temporary copy.
///
/// # Safety
///
/// Compiled with `#[target_feature(enable = "avx2")]`; the caller must ensure AVX2 is
/// available.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn bitonic_sort_8_avx2(arr: &mut [f32], ascending: bool) {
    if arr.len() != 8 {
        bitonic_sort_scalar(arr, ascending);
        return;
    }
    if ascending {
        arr.sort_unstable_by(f32::total_cmp);
    } else {
        arr.sort_unstable_by(|a, b| b.total_cmp(a));
    }
}
/// Bitonic merge of `arr` into the requested direction, using AVX2 for the compare-exchanges.
///
/// Element `i` of the lower half is compare-exchanged with element `i + len / 2`. This is done
/// 8 lanes at a time with `_mm256_blendv_ps`, which swaps exactly the lanes the scalar rule
/// would. Leftover lanes are handled by a scalar loop, and both halves are then merged
/// recursively. Slices of at most 8 elements use `bitonic_merge_scalar`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn bitonic_merge_avx2(arr: &mut [f32], ascending: bool) {
    // The intrinsics live in a different module on each architecture.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let len = arr.len();
    if len <= 8 {
        bitonic_merge_scalar(arr, ascending);
        return;
    }
    let step = len / 2;
    let mut i = 0;
    {
        // One pointer derived from the whole slice, valid for every window read or written
        // below.
        let base = arr.as_mut_ptr();
        while i + 8 <= step {
            // SAFETY: i + 8 <= step and i + step + 8 <= 2 * step <= len, so both 8-lane
            // windows are in bounds. They do not overlap because step >= 8 here.
            let a = _mm256_loadu_ps(base.add(i));
            let b = _mm256_loadu_ps(base.add(i + step));
            let swap = if ascending {
                _mm256_cmp_ps(a, b, _CMP_GT_OQ)
            } else {
                _mm256_cmp_ps(a, b, _CMP_LT_OQ)
            };
            // blendv picks its second operand in lanes whose mask is set, i.e. it swaps there.
            let low = _mm256_blendv_ps(a, b, swap);
            let high = _mm256_blendv_ps(b, a, swap);
            _mm256_storeu_ps(base.add(i), low);
            _mm256_storeu_ps(base.add(i + step), high);
            i += 8;
        }
    }
    while i < step {
        if (arr[i] > arr[i + step]) == ascending {
            arr.swap(i, i + step);
        }
        i += 1;
    }
    if step > 1 {
        bitonic_merge_avx2(&mut arr[0..step], ascending);
        bitonic_merge_avx2(&mut arr[step..], ascending);
    }
}
/// Returns the median of `arr`, or `None` if it is empty.
///
/// `arr` is reordered in place by quickselect. For an even number of elements the result is
/// the mean of the two middle values. That mean is computed without overflowing to infinity
/// when both values are large finite numbers. Results are only meaningful for NaN-free input.
pub fn median_f32_simd(arr: &mut [f32]) -> Option<f32> {
    if arr.is_empty() {
        return None;
    }
    let len = arr.len();
    let mid = len / 2;
    if len % 2 == 1 {
        return Some(quickselect_f32_simd(arr, mid));
    }

    let left_mid = quickselect_f32_simd(arr, mid - 1);
    // Quickselect leaves every element of arr[mid..] >= arr[mid - 1], so the upper middle
    // value is the minimum of that tail. A linear scan avoids a second partitioning pass.
    // The tail is non-empty because len >= 2.
    let upper = &arr[mid..];
    let mut right_mid = upper[0];
    for &x in &upper[1..] {
        if x < right_mid {
            right_mid = x;
        }
    }

    // (a + b) / 2 overflows to infinity when both values are near f32::MAX; fall back to
    // halving each operand first in that case.
    let sum = left_mid + right_mid;
    let mean = if sum.is_finite() {
        sum / 2.0
    } else {
        left_mid / 2.0 + right_mid / 2.0
    };
    Some(mean)
}
/// Returns the `k`-th smallest element (0-based) of `arr`, partially reordering it in place.
///
/// Iterative quickselect over `partition_range`, which uses the last element of the active
/// range as the pivot. Despite the name, this path performs no SIMD operations.
///
/// For NaN-free input, on return:
/// - `arr[k]` holds the result;
/// - elements before `k` are `<=` it;
/// - elements after `k` are `>=` it.
///
/// # Panics
///
/// Panics if `k >= arr.len()`, which includes the empty-slice case.
pub fn quickselect_f32_simd(arr: &mut [f32], k: usize) -> f32 {
    assert!(k < arr.len(), "k must be less than array length");
    let mut lo = 0;
    let mut hi = arr.len() - 1;
    while lo != hi {
        let split = partition_range(arr, lo, hi);
        match k.cmp(&split) {
            core::cmp::Ordering::Equal => return arr[k],
            core::cmp::Ordering::Less => hi = split - 1,
            core::cmp::Ordering::Greater => lo = split + 1,
        }
    }
    arr[lo]
}
/// Lomuto-partitions `arr[left..=right]` around its last element `arr[right]`.
///
/// Returns the pivot's final index `p`. Afterwards:
/// - `arr[left..p]` holds the range's elements that compared `<= pivot`;
/// - the pivot sits at `arr[p]`;
/// - the rest of the range follows it.
///
/// Elements outside the range are untouched when `left <= right`.
fn partition_range(arr: &mut [f32], left: usize, right: usize) -> usize {
    let pivot = arr[right];
    let mut boundary = left;
    let mut cursor = left;
    while cursor < right {
        if arr[cursor] <= pivot {
            arr.swap(boundary, cursor);
            boundary += 1;
        }
        cursor += 1;
    }
    // Drop the pivot into the slot just past the "<= pivot" run.
    arr.swap(boundary, right);
    boundary
}
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use scirs2_core::random::prelude::*;

    /// Returns `true` when no adjacent pair of `arr` is out of order for the given direction.
    fn is_sorted(arr: &[f32], ascending: bool) -> bool {
        for i in 1..arr.len() {
            if ascending && arr[i - 1] > arr[i] {
                return false;
            }
            if !ascending && arr[i - 1] < arr[i] {
                return false;
            }
        }
        true
    }

    #[test]
    fn test_quicksort_simd() {
        let mut arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
    }

    #[test]
    fn test_quicksort_random() {
        // 100 elements is long enough to reach the SIMD partition kernels on x86.
        let mut rng = thread_rng();
        let mut arr: Vec<f32> = (0..100).map(|_| rng.random_range(0.0..100.0)).collect();
        quicksort_f32_simd(&mut arr);
        assert!(is_sorted(&arr, true));
    }

    #[test]
    fn test_bitonic_sort_small() {
        let mut arr = vec![4.0, 2.0, 7.0, 1.0];
        bitonic_sort_f32_simd(&mut arr, true);
        assert!(is_sorted(&arr, true));
        let mut arr = vec![4.0, 2.0, 7.0, 1.0];
        bitonic_sort_f32_simd(&mut arr, false);
        assert!(is_sorted(&arr, false));
    }

    #[test]
    fn test_bitonic_sort_power_of_2() {
        let mut arr = vec![8.0, 4.0, 2.0, 1.0, 3.0, 6.0, 5.0, 7.0];
        bitonic_sort_f32_simd(&mut arr, true);
        assert!(is_sorted(&arr, true));
    }

    #[test]
    fn test_median_odd() {
        let mut arr = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let median = median_f32_simd(&mut arr);
        assert_eq!(median, Some(3.0));
    }

    #[test]
    fn test_median_even() {
        // Middle values are 2.0 and 3.0, so the median is their mean.
        let mut arr = vec![3.0, 1.0, 4.0, 2.0];
        let median = median_f32_simd(&mut arr);
        assert_eq!(median, Some(2.5));
    }

    #[test]
    fn test_quickselect() {
        let mut arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        let third_smallest = quickselect_f32_simd(&mut arr, 2);
        // quickselect only permutes `arr`, so sorting the permuted copy yields the same order
        // statistics as sorting the original input.
        let mut sorted = arr.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
        assert_eq!(third_smallest, sorted[2]);
    }

    #[test]
    fn test_empty_median() {
        let mut arr: Vec<f32> = vec![];
        let median = median_f32_simd(&mut arr);
        assert_eq!(median, None);
    }

    #[test]
    fn test_single_element() {
        let mut arr = vec![42.0];
        quicksort_f32_simd(&mut arr);
        assert_eq!(arr, vec![42.0]);
        let median = median_f32_simd(&mut arr);
        assert_eq!(median, Some(42.0));
    }
}