use crate::{RadixResult, RadixSort, RadixSortable};
impl<T> RadixSort<T>
where
    T: RadixSortable,
{
    /// Sorts `slice` in place by radix-sorting the elements' radix keys.
    ///
    /// Every element is converted with `T::to_radix_key`, the keys go through
    /// one stable counting-sort pass per byte index in `0..T::key_size()`
    /// (index 0 first), and the reordered keys are converted back with
    /// `T::from_radix_key` and written over `slice`.
    ///
    /// The resulting order is whatever order those byte passes impose on the
    /// keys. NOTE(review): this presumably yields ascending numeric order,
    /// assuming `extract_byte(_, 0)` is the least significant byte and the key
    /// encoding is order-preserving — confirm against the `RadixSortable`
    /// implementations. This method does not consult `self` for a sort
    /// direction.
    ///
    /// Slices with zero or one element are returned untouched.
    ///
    /// # Errors
    ///
    /// None of the code paths in this method produce an `Err`; it always
    /// returns `Ok(())`.
    pub(crate) fn sort_floats(&self, slice: &mut [T]) -> RadixResult<()> {
        if slice.len() <= 1 {
            return Ok(());
        }

        let key_size = T::key_size();
        let mut keys: Vec<T::RadixKey> =
            slice.iter().map(|v| T::to_radix_key(v.clone())).collect();
        // Scratch buffer of identical length. Its initial contents are
        // irrelevant: each pass overwrites every slot before it is read.
        let mut scratch = keys.clone();

        for byte_index in 0..key_size {
            self.counting_sort_float_bytes(&keys, &mut scratch, byte_index);
            // Exchange the two buffers instead of copying elements back, so
            // the freshly written output becomes the next pass's input.
            std::mem::swap(&mut keys, &mut scratch);
        }

        for (slot, &key) in slice.iter_mut().zip(keys.iter()) {
            *slot = T::from_radix_key(key);
        }
        Ok(())
    }

    /// Performs one stable counting-sort pass keyed on a single byte.
    ///
    /// Reads every key from `input`, buckets it by
    /// `T::extract_byte(key, byte_index)`, and writes the keys into `output`
    /// ordered by that byte. Keys sharing the same byte keep their relative
    /// order from `input`. `output` must be at least as long as `input`;
    /// otherwise indexing into it panics.
    fn counting_sort_float_bytes(
        &self,
        input: &[T::RadixKey],
        output: &mut [T::RadixKey],
        byte_index: usize,
    ) {
        const RADIX: usize = 256;
        debug_assert!(output.len() >= input.len());

        // Histogram of how often each byte value occurs.
        let mut offsets = [0usize; RADIX];
        for &key in input {
            offsets[T::extract_byte(key, byte_index) as usize] += 1;
        }

        // Inclusive prefix sums: offsets[b] becomes one past the last output
        // index reserved for byte value `b`.
        let mut running = 0usize;
        for slot in offsets.iter_mut() {
            running += *slot;
            *slot = running;
        }

        // Walk the input backwards so that filling each bucket from its end
        // preserves the original relative order of equal bytes.
        for &key in input.iter().rev() {
            let bucket = T::extract_byte(key, byte_index) as usize;
            offsets[bucket] -= 1;
            output[offsets[bucket]] = key;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RadixDataType, SortDirection};

    // The literals below deliberately avoid `3.14`, which trips the
    // deny-by-default `clippy::approx_constant` lint.

    /// Mixed-sign finite `f64` values come out in ascending order.
    #[test]
    fn test_sort_f64_basic() {
        let mut values = vec![3.5f64, -1.25, 0.5, -99.9, 42.0];
        let sorter = RadixSort::<f64>::new(RadixDataType::Float, SortDirection::Ascending);
        sorter.sort(&mut values).unwrap();
        assert_eq!(values, vec![-99.9, -1.25, 0.5, 3.5, 42.0]);
    }

    /// `f64::NAN` is expected to end up after every finite value.
    #[test]
    fn test_sort_f64_with_nan() {
        let mut values = vec![3.5f64, f64::NAN, -1.25, 0.5];
        let sorter = RadixSort::<f64>::new(RadixDataType::Float, SortDirection::Ascending);
        sorter.sort(&mut values).unwrap();
        // NaN never compares equal, so it is checked separately from the rest.
        assert!(values[3].is_nan());
        assert_eq!(values[..3], [-1.25, 0.5, 3.5]);
    }

    /// Infinities bracket the finite values, which must themselves be ordered.
    #[test]
    fn test_sort_f64_with_infinity() {
        let mut values = vec![3.5f64, f64::INFINITY, -1.25, f64::NEG_INFINITY, 0.5];
        let sorter = RadixSort::<f64>::new(RadixDataType::Float, SortDirection::Ascending);
        sorter.sort(&mut values).unwrap();
        assert_eq!(
            values,
            vec![f64::NEG_INFINITY, -1.25, 0.5, 3.5, f64::INFINITY]
        );
    }

    /// Mixed-sign finite `f32` values come out in ascending order.
    #[test]
    fn test_sort_f32_basic() {
        let mut values = vec![3.5f32, -1.25, 0.5, -99.9];
        let sorter = RadixSort::<f32>::new(RadixDataType::Float, SortDirection::Ascending);
        sorter.sort(&mut values).unwrap();
        assert_eq!(values, vec![-99.9, -1.25, 0.5, 3.5]);
    }

    /// A sorter built with `SortDirection::Descending` yields largest-first order.
    #[test]
    fn test_sort_f64_descending() {
        let mut values = vec![3.5f64, -1.25, 0.5];
        let sorter = RadixSort::<f64>::new(RadixDataType::Float, SortDirection::Descending);
        sorter.sort(&mut values).unwrap();
        assert_eq!(values, vec![3.5, 0.5, -1.25]);
    }

    /// Both zeros land between the negative and positive values when sorting
    /// with the default-constructed sorter.
    #[test]
    fn test_negative_zero() {
        let mut values = vec![0.0f64, -0.0, 1.0, -1.0];
        let sorter = RadixSort::<f64>::default();
        sorter.sort(&mut values).unwrap();
        // `-0.0 == 0.0` under IEEE 754 comparison, so this pins the complete
        // ordering while accepting the two zeros in either relative order.
        assert_eq!(values, vec![-1.0, 0.0, 0.0, 1.0]);
    }
}