#[cfg(feature = "parallel")]
use rayon;
use std::fmt::Debug;
use crate::error::{Result, SortError};
/// Configuration for a merge sort that hands small sub-slices to insertion
/// sort and can optionally split the work across threads.
///
/// Values are set through the chained setter methods and consumed by
/// `MergeSortBuilder::sort`.
#[derive(Clone, Debug)]
pub struct MergeSortBuilder {
    /// Sub-slices of at most this many elements are sorted with insertion
    /// sort instead of being split further.
    insertion_threshold: usize,
    /// Recursion depth at which sorting gives up and reports an error.
    max_recursion_depth: usize,
    /// Whether the parallel code path may be used.
    parallel: bool,
    /// Minimum slice length for which the parallel code path is chosen.
    parallel_threshold: usize,
}
impl Default for MergeSortBuilder {
fn default() -> Self {
Self {
insertion_threshold: 16,
max_recursion_depth: 48,
parallel: false,
parallel_threshold: 1024,
}
}
}
impl MergeSortBuilder {
    /// Largest slice length accepted by [`sort`](Self::sort): 2^48 elements.
    ///
    /// The literal shift only fits when `usize` is wider than 48 bits.
    const MAX_LENGTH: usize = 1 << 48;

    /// Creates a builder with the default configuration
    /// (equivalent to `MergeSortBuilder::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the largest slice length that is sorted with insertion sort
    /// instead of being split further.
    ///
    /// A value of `0` behaves like `1`: a one-element slice cannot be split
    /// any further, so it is always treated as a base case.
    pub fn insertion_threshold(mut self, threshold: usize) -> Self {
        self.insertion_threshold = threshold;
        self
    }

    /// Sets the recursion depth at which sorting is aborted with a
    /// recursion-limit error. The top-level call runs at depth `0`.
    pub fn max_recursion_depth(mut self, depth: usize) -> Self {
        self.max_recursion_depth = depth;
        self
    }

    /// Enables or disables the parallel code path.
    ///
    /// When the crate is built without the `parallel` feature the flag is
    /// accepted but sorting always runs sequentially.
    pub fn parallel(mut self, enabled: bool) -> Self {
        self.parallel = enabled;
        self
    }

    /// Sets the minimum slice length for which the parallel code path is used.
    pub fn parallel_threshold(mut self, threshold: usize) -> Self {
        self.parallel_threshold = threshold;
        self
    }

    /// Sorts `slice` in ascending order with a stable merge sort.
    ///
    /// Slices with fewer than two elements are returned untouched.
    ///
    /// # Errors
    ///
    /// * The error built by `SortError::input_too_large` when the slice is
    ///   longer than the supported maximum.
    /// * The error built by `SortError::recursion_limit_exceeded` when the
    ///   recursion reaches the configured maximum depth. In that case the
    ///   slice may be left partially sorted; it still holds the same elements.
    pub fn sort<T>(&self, slice: &mut [T]) -> Result<()>
    where
        T: Ord + Clone + Send + Sync,
    {
        if slice.len() <= 1 {
            return Ok(());
        }
        self.validate_array_size(slice.len())?;

        // Scratch buffer for the merge steps; only its length matters here.
        let mut aux = slice.to_vec();
        if self.parallel && slice.len() >= self.parallel_threshold {
            self.sort_parallel(slice, &mut aux, 0)
        } else {
            self.sort_sequential(slice, &mut aux, 0)
        }
    }

    /// Largest slice length handled directly by insertion sort.
    ///
    /// Clamped to at least 1: with a threshold of 0 a one-element slice would
    /// otherwise be "split" into an empty slice and itself, recursing until
    /// the depth limit is hit.
    fn base_case_len(&self) -> usize {
        self.insertion_threshold.max(1)
    }

    /// Fails with a recursion-limit error once `depth` reaches the configured
    /// maximum.
    fn check_depth(&self, depth: usize) -> Result<()> {
        if depth >= self.max_recursion_depth {
            Err(SortError::recursion_limit_exceeded(
                depth,
                self.max_recursion_depth,
            ))
        } else {
            Ok(())
        }
    }

    /// Recursive single-threaded merge sort.
    ///
    /// `aux` is scratch space and must be at least as long as `slice`;
    /// `depth` is the current recursion depth.
    fn sort_sequential<T>(&self, slice: &mut [T], aux: &mut [T], depth: usize) -> Result<()>
    where
        T: Ord + Clone,
    {
        self.check_depth(depth)?;
        if slice.len() <= self.base_case_len() {
            insertion_sort(slice);
            return Ok(());
        }

        let mid = slice.len() / 2;
        self.sort_sequential(&mut slice[..mid], aux, depth + 1)?;
        self.sort_sequential(&mut slice[mid..], aux, depth + 1)?;
        merge(slice, mid, aux);
        Ok(())
    }

    /// Recursive merge sort that sorts both halves concurrently via
    /// `rayon::join` for as long as the sub-slice length is at least
    /// `parallel_threshold`; shorter sub-slices are sorted sequentially.
    ///
    /// `aux` is scratch space and must be at least as long as `slice`. It is
    /// split alongside `slice`, so each task works on its own disjoint scratch
    /// region and no extra buffers are allocated.
    #[cfg(feature = "parallel")]
    fn sort_parallel<T>(&self, slice: &mut [T], aux: &mut [T], depth: usize) -> Result<()>
    where
        T: Ord + Clone + Send + Sync,
    {
        // Below the threshold a task split is not worth its overhead.
        if slice.len() < self.parallel_threshold {
            return self.sort_sequential(slice, aux, depth);
        }
        self.check_depth(depth)?;
        if slice.len() <= self.base_case_len() {
            insertion_sort(slice);
            return Ok(());
        }

        let len = slice.len();
        let mid = len / 2;
        let (left_result, right_result) = {
            let (left, right) = slice.split_at_mut(mid);
            let (left_aux, right_aux) = aux[..len].split_at_mut(mid);
            rayon::join(
                || self.sort_parallel(left, left_aux, depth + 1),
                || self.sort_parallel(right, right_aux, depth + 1),
            )
        };
        left_result?;
        right_result?;
        merge(slice, mid, aux);
        Ok(())
    }

    /// Fallback used when the crate is built without the `parallel` feature:
    /// sorts sequentially.
    #[cfg(not(feature = "parallel"))]
    fn sort_parallel<T>(&self, slice: &mut [T], aux: &mut [T], depth: usize) -> Result<()>
    where
        T: Ord + Clone + Send + Sync,
    {
        self.sort_sequential(slice, aux, depth)
    }

    /// Checks that `size` does not exceed the maximum supported slice length.
    ///
    /// # Errors
    ///
    /// Returns the error built by `SortError::input_too_large` when `size` is
    /// larger than the maximum.
    pub fn validate_array_size(&self, size: usize) -> Result<()> {
        if size > Self::MAX_LENGTH {
            Err(SortError::input_too_large(size, Self::MAX_LENGTH))
        } else {
            Ok(())
        }
    }
}
/// Sorts `slice` in ascending order using a [`MergeSortBuilder`] with its
/// default configuration.
///
/// # Errors
///
/// Propagates any error reported by `MergeSortBuilder::sort`.
pub fn sort<T>(slice: &mut [T]) -> Result<()>
where
    T: Ord + Clone + Send + Sync + 'static,
{
    let builder = MergeSortBuilder::default();
    builder.sort(slice)
}
/// Sorts `slice` in place in ascending order with a stable insertion sort.
///
/// Each element is compared against its predecessors from right to left until
/// one that is not greater is found, then moved into that position.
fn insertion_sort<T: Ord>(slice: &mut [T]) {
    for end in 1..slice.len() {
        // Walk left past every element strictly greater than `slice[end]`;
        // the strict comparison keeps equal elements in their original order.
        let mut pos = end;
        while pos > 0 && slice[pos - 1] > slice[end] {
            pos -= 1;
        }
        // Move `slice[end]` to `pos`, shifting the skipped elements right by one.
        slice[pos..=end].rotate_right(1);
    }
}
/// Merges the two sorted runs `slice[..mid]` and `slice[mid..]` into one
/// sorted run, in place and stably (on ties the element from the left run
/// comes first).
///
/// `aux` is scratch space. Elements are moved through it by swapping, so no
/// element is cloned; afterwards `aux[..slice.len()]` holds whatever values
/// `slice` was displaced with and should be treated as unspecified.
///
/// # Panics
///
/// Panics if `aux` is shorter than `slice` or if `mid > slice.len()`.
fn merge<T>(slice: &mut [T], mid: usize, aux: &mut [T])
where
    T: Ord + Clone,
{
    let len = slice.len();
    let scratch = &mut aux[..len];

    // Park both runs in the scratch buffer; `slice` temporarily holds the
    // previous scratch contents, which are swapped back out below.
    scratch.swap_with_slice(slice);
    let (left, right) = scratch.split_at_mut(mid);

    let mut i = 0;
    let mut j = 0;
    let mut k = 0;
    while i < left.len() && j < right.len() {
        // `<=` prefers the left run on ties, which is what makes the merge stable.
        if left[i] <= right[j] {
            std::mem::swap(&mut slice[k], &mut left[i]);
            i += 1;
        } else {
            std::mem::swap(&mut slice[k], &mut right[j]);
            j += 1;
        }
        k += 1;
    }

    // At most one run still has elements, and it exactly fills the rest of `slice`.
    if i < left.len() {
        slice[k..].swap_with_slice(&mut left[i..]);
    }
    if j < right.len() {
        slice[k..].swap_with_slice(&mut right[j..]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Element whose ordering looks only at `key`, so that stability can be
    /// observed through `original_index`.
    #[derive(Debug, Clone)]
    struct Item {
        key: i32,
        original_index: usize,
    }

    // Equality is defined on `key` only so that it agrees with `Ord`.
    impl PartialEq for Item {
        fn eq(&self, other: &Self) -> bool {
            self.key == other.key
        }
    }

    impl Eq for Item {}

    impl PartialOrd for Item {
        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for Item {
        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
            self.key.cmp(&other.key)
        }
    }

    /// Builds `count` items whose keys are deterministically scrambled and
    /// heavily duplicated, so a sort has real work to do and an unstable sort
    /// would reorder equal keys.
    fn scrambled_items(count: usize) -> Vec<Item> {
        (0..count)
            .map(|i| Item {
                key: ((i * 7919) % 1000) as i32,
                original_index: i,
            })
            .collect()
    }

    /// Asserts that `items` is sorted by key and that equal keys kept their
    /// original relative order.
    fn assert_sorted_and_stable(items: &[Item]) {
        for i in 1..items.len() {
            assert!(
                items[i - 1].key <= items[i].key,
                "Order violated at indices {} and {}",
                i - 1,
                i
            );
            if items[i - 1].key == items[i].key {
                assert!(
                    items[i - 1].original_index < items[i].original_index,
                    "Stability violated at indices {} and {}",
                    i - 1,
                    i
                );
            }
        }
    }

    #[test]
    fn test_empty_slice() {
        let mut arr: Vec<i32> = vec![];
        sort(&mut arr).unwrap();
        assert_eq!(arr, Vec::<i32>::new());
    }

    #[test]
    fn test_single_element() {
        let mut arr = vec![1];
        sort(&mut arr).unwrap();
        assert_eq!(arr, vec![1]);
    }

    #[test]
    fn test_sorted_array() {
        let mut arr = vec![1, 2, 3, 4, 5];
        sort(&mut arr).unwrap();
        assert_eq!(arr, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_reverse_sorted() {
        let mut arr = vec![5, 4, 3, 2, 1];
        sort(&mut arr).unwrap();
        assert_eq!(arr, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_random_order() {
        let mut arr = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5];
        let mut expected = arr.clone();
        expected.sort();
        sort(&mut arr).unwrap();
        assert_eq!(arr, expected);
    }

    /// Large enough to exercise the merge steps, not just insertion sort.
    #[test]
    fn test_sequential_stability() {
        let mut items = scrambled_items(10_000);
        sort(&mut items).unwrap();
        assert_sorted_and_stable(&items);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_sorting() {
        let size = 10_000;
        let mut arr: Vec<i32> = (0..size).rev().collect();
        let mut expected = arr.clone();
        expected.sort();
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(4)
            .build()
            .unwrap();
        pool.install(|| {
            MergeSortBuilder::new()
                .parallel(true)
                .parallel_threshold(1000)
                .sort(&mut arr)
                .unwrap();
        });
        assert_eq!(arr, expected);
    }

    /// Sorts the same input once below and once above the parallel threshold;
    /// both paths must produce the same result.
    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_threshold() {
        let size = 10_000;
        let arr: Vec<i32> = (0..size).rev().collect();

        let mut arr1 = arr.clone();
        MergeSortBuilder::new()
            .parallel(true)
            .parallel_threshold((size * 2) as usize)
            .sort(&mut arr1)
            .unwrap();

        let mut arr2 = arr.clone();
        MergeSortBuilder::new()
            .parallel(true)
            .parallel_threshold((size / 2) as usize)
            .sort(&mut arr2)
            .unwrap();

        let mut expected = arr;
        expected.sort();
        assert_eq!(arr1, expected);
        assert_eq!(arr2, expected);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_stability() {
        let mut items = scrambled_items(10_000);
        MergeSortBuilder::new()
            .parallel(true)
            .parallel_threshold(1000)
            .sort(&mut items)
            .unwrap();
        assert_sorted_and_stable(&items);
    }

    #[test]
    fn test_recursion_limit() {
        let mut arr: Vec<i32> = (0..10_000).collect();
        let result = MergeSortBuilder::new()
            .max_recursion_depth(3)
            .sort(&mut arr);
        match result {
            Err(SortError::RecursionLimitExceeded { depth, max_depth }) => {
                assert_eq!(max_depth, 3);
                assert!(depth >= max_depth);
            }
            other => panic!("Expected RecursionLimitExceeded error, got {:?}", other),
        }
    }

    /// Uses `validate_array_size` so the limit can be checked without
    /// allocating an oversized slice.
    #[test]
    fn test_input_too_large() {
        let size = MergeSortBuilder::MAX_LENGTH + 1;
        let result = MergeSortBuilder::new().validate_array_size(size);
        match result {
            Err(SortError::InputTooLarge { length, max_length }) => {
                assert_eq!(length, size);
                assert_eq!(max_length, MergeSortBuilder::MAX_LENGTH);
            }
            other => panic!("Expected InputTooLarge error, got {:?}", other),
        }
    }
}