use crate::algorithms::{Algorithm, AlgorithmStats};
use crate::error::Result;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::time::Instant;
/// Tuning knobs for [`MultiWayMerge`].
#[derive(Debug, Clone)]
pub struct MultiWayMergeConfig {
    /// Requests parallel execution.
    ///
    /// NOTE(review): nothing in this file reads this flag — confirm where it
    /// is consumed.
    pub use_parallel: bool,
    /// Buffer size (defaults to `64 * 1024`).
    ///
    /// NOTE(review): not read in this file; the unit is presumably bytes —
    /// confirm against the consumer.
    pub buffer_size: usize,
    /// Largest number of sources `MultiWayMerge::merge` feeds to a single
    /// pass; above this count it takes its hierarchical path instead.
    pub max_merge_ways: usize,
    /// Makes `MultiWayMerge::merge` use its tournament strategy when it is
    /// given more than 8 sources.
    pub use_tournament_tree: bool,
}

impl Default for MultiWayMergeConfig {
    /// Parallelism requested, a 64 KiB-sized buffer value, up to 1024 merge
    /// ways, and the tournament strategy switched off.
    fn default() -> Self {
        const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
        const DEFAULT_MAX_MERGE_WAYS: usize = 1024;

        MultiWayMergeConfig {
            use_parallel: true,
            buffer_size: DEFAULT_BUFFER_SIZE,
            max_merge_ways: DEFAULT_MAX_MERGE_WAYS,
            use_tournament_tree: false,
        }
    }
}
/// A cursor over a sequence of items that can be fed into a merge.
///
/// The merge routines in this file repeatedly take the smallest available
/// head, so the output is only sorted if every source yields its items in
/// ascending order.
pub trait MergeSource<T> {
    /// Removes and returns the next item, or `None` once the source is
    /// exhausted.
    fn next(&mut self) -> Option<T>;

    /// Borrows the item that `next` would return, without consuming it.
    fn peek(&self) -> Option<&T>;

    /// Reports whether the source has no more items to yield.
    fn is_empty(&self) -> bool;

    /// Number of items still to come, when the source knows it.
    ///
    /// The default implementation reports no knowledge.
    fn remaining_hint(&self) -> Option<usize> {
        None
    }
}
/// A read cursor over a borrowed slice.
pub struct SliceSource<'a, T> {
    /// The borrowed items.
    data: &'a [T],
    /// Position of the next item to hand out.
    index: usize,
}

impl<'a, T> SliceSource<'a, T> {
    /// Wraps `data` with the cursor positioned at its first element.
    pub fn new(data: &'a [T]) -> Self {
        SliceSource { index: 0, data }
    }
}
/// Yields clones of the slice's elements, front to back.
impl<'a, T> MergeSource<T> for SliceSource<'a, T>
where
    T: Clone,
{
    /// Clones the element under the cursor and advances past it.
    fn next(&mut self) -> Option<T> {
        let item = self.data.get(self.index)?.clone();
        self.index += 1;
        Some(item)
    }

    /// Borrows the element under the cursor, if any.
    fn peek(&self) -> Option<&T> {
        self.data.get(self.index)
    }

    /// True once the cursor has reached the end of the slice.
    fn is_empty(&self) -> bool {
        self.data.len() <= self.index
    }

    /// Exact count of elements that have not been handed out yet.
    fn remaining_hint(&self) -> Option<usize> {
        let left = self.data.len() - self.index;
        Some(left)
    }
}
/// A read cursor over an owned vector.
pub struct VectorSource<T> {
    /// The owned items.
    data: Vec<T>,
    /// Position of the next item to hand out.
    index: usize,
}

impl<T> VectorSource<T> {
    /// Takes ownership of `data` with the cursor positioned at its first
    /// element.
    pub fn new(data: Vec<T>) -> Self {
        VectorSource { index: 0, data }
    }

    /// Borrows the items from the cursor position to the end of the vector.
    pub fn remaining(&self) -> &[T] {
        let start = self.index;
        &self.data[start..]
    }
}
/// Yields clones of the vector's elements, front to back.
impl<T> MergeSource<T> for VectorSource<T>
where
    T: Clone,
{
    /// Clones the element under the cursor and advances past it.
    fn next(&mut self) -> Option<T> {
        let item = self.data.get(self.index)?.clone();
        self.index += 1;
        Some(item)
    }

    /// Borrows the element under the cursor, if any.
    fn peek(&self) -> Option<&T> {
        self.data.get(self.index)
    }

    /// True once the cursor has reached the end of the vector.
    fn is_empty(&self) -> bool {
        self.data.len() <= self.index
    }

    /// Exact count of elements that have not been handed out yet.
    fn remaining_hint(&self) -> Option<usize> {
        let left = self.data.len() - self.index;
        Some(left)
    }
}
/// An item waiting in the merge heap, tagged with the index of the source it
/// was pulled from.
///
/// `BinaryHeap` is a max-heap, so every comparison here is reversed: the
/// entry with the *smallest* item is the "greatest" entry and is popped
/// first. Entries holding equal items are ordered by `source_id`, lowest
/// first, so a heap-driven merge emits equal items in source order instead of
/// in whatever order the heap happens to hold them.
#[derive(Debug)]
struct HeapEntry<T> {
    /// The value being merged.
    item: T,
    /// Index of the source that produced `item`.
    source_id: usize,
}

impl<T> PartialEq for HeapEntry<T>
where
    T: PartialEq,
{
    /// Two entries are equal only when both item and source match, which
    /// keeps equality consistent with the tie-breaking order below.
    fn eq(&self, other: &Self) -> bool {
        self.item == other.item && self.source_id == other.source_id
    }
}

impl<T> Eq for HeapEntry<T> where T: Eq {}

impl<T> PartialOrd for HeapEntry<T>
where
    T: PartialOrd,
{
    /// Reversed item comparison, then reversed source index on ties.
    /// Returns `None` when the items themselves are incomparable.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        other
            .item
            .partial_cmp(&self.item)
            .map(|by_item| by_item.then_with(|| other.source_id.cmp(&self.source_id)))
    }
}

impl<T> Ord for HeapEntry<T>
where
    T: Ord,
{
    /// Reversed item comparison, then reversed source index on ties.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .item
            .cmp(&self.item)
            .then_with(|| other.source_id.cmp(&self.source_id))
    }
}
/// Merges several [`MergeSource`]s into a single output vector.
pub struct MultiWayMerge {
    /// Settings that choose the merge strategy.
    config: MultiWayMergeConfig,
    /// Statistics written by `merge`.
    stats: AlgorithmStats,
}
impl MultiWayMerge {
    /// Creates a merger configured with `MultiWayMergeConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(MultiWayMergeConfig::default())
    }

    /// Creates a merger with the given configuration and zeroed statistics.
    pub fn with_config(config: MultiWayMergeConfig) -> Self {
        Self {
            config,
            stats: AlgorithmStats {
                items_processed: 0,
                processing_time_us: 0,
                memory_used: 0,
                used_parallel: false,
                used_simd: false,
            },
        }
    }

    /// Merges `sources` into one vector by repeatedly taking the smallest
    /// available head item.
    ///
    /// The output is sorted only if every source yields its items in
    /// ascending order. Strategy selection:
    ///
    /// * zero or one source — the items are drained as they come;
    /// * more than `max_merge_ways` sources — the hierarchical path;
    /// * `use_tournament_tree` set and more than 8 sources — the tournament
    ///   path;
    /// * otherwise — the binary-heap path.
    ///
    /// Statistics are refreshed on every successful call, including the
    /// zero- and one-source shortcuts, so `stats()` always describes the most
    /// recent merge.
    ///
    /// # Errors
    ///
    /// Propagates any error from the selected strategy; none of the
    /// strategies in this file currently produces one.
    pub fn merge<T, S>(&mut self, sources: Vec<S>) -> Result<Vec<T>>
    where
        T: Ord + Clone,
        S: MergeSource<T>,
    {
        let start_time = Instant::now();
        let source_count = sources.len();

        let result: Vec<T> = if source_count <= 1 {
            // Nothing to interleave: copy out whatever the lone source holds.
            let mut drained: Vec<T> = Self::output_buffer(&sources);
            if let Some(mut only) = sources.into_iter().next() {
                drained.extend(std::iter::from_fn(|| only.next()));
            }
            drained
        } else if source_count > self.config.max_merge_ways {
            self.merge_hierarchical(sources)?
        } else if self.config.use_tournament_tree && source_count > 8 {
            self.merge_tournament(sources)?
        } else {
            self.merge_heap(sources)?
        };

        let elapsed_us = start_time.elapsed().as_micros();
        self.stats = AlgorithmStats {
            items_processed: result.len(),
            // Clamp rather than truncate if the u128 count ever exceeds u64.
            processing_time_us: elapsed_us.min(u128::from(u64::MAX)) as u64,
            memory_used: result.len() * std::mem::size_of::<T>(),
            used_parallel: false,
            used_simd: false,
        };
        Ok(result)
    }

    /// Creates the output vector, reserving room for the item count the
    /// sources report through `remaining_hint`.
    ///
    /// Sources without a hint contribute nothing to the reservation. The
    /// hints are advisory: if the reservation cannot be made the vector simply
    /// grows on demand.
    fn output_buffer<T, S>(sources: &[S]) -> Vec<T>
    where
        S: MergeSource<T>,
    {
        let hinted = sources
            .iter()
            .filter_map(|source| source.remaining_hint())
            .fold(0usize, usize::saturating_add);
        let mut buffer = Vec::new();
        let _ = buffer.try_reserve(hinted);
        buffer
    }

    /// Heap-based merge: the heap holds at most one pending item per source,
    /// and each popped item is replaced by the next item of the same source.
    fn merge_heap<T, S>(&self, mut sources: Vec<S>) -> Result<Vec<T>>
    where
        T: Ord + Clone,
        S: MergeSource<T>,
    {
        let mut result: Vec<T> = Self::output_buffer(&sources);
        let mut heap = BinaryHeap::with_capacity(sources.len());

        // Seed the heap with the head of every non-empty source.
        for (source_id, source) in sources.iter_mut().enumerate() {
            if let Some(item) = source.next() {
                heap.push(HeapEntry { item, source_id });
            }
        }

        while let Some(HeapEntry { item, source_id }) = heap.pop() {
            result.push(item);
            if let Some(item) = sources[source_id].next() {
                heap.push(HeapEntry { item, source_id });
            }
        }
        Ok(result)
    }

    /// Selection-based merge: each step scans the heads of all sources that
    /// still have items and takes the smallest one.
    ///
    /// Among equal heads the source with the lowest index wins. Despite the
    /// name, no tournament tree is built — every output item costs one linear
    /// scan over the active sources.
    fn merge_tournament<T, S>(&self, mut sources: Vec<S>) -> Result<Vec<T>>
    where
        T: Ord + Clone,
        S: MergeSource<T>,
    {
        let mut result: Vec<T> = Self::output_buffer(&sources);
        let mut active: Vec<usize> = (0..sources.len())
            .filter(|&id| !sources[id].is_empty())
            .collect();

        loop {
            // `min_by` returns the first of several equal minima, and
            // `active` stays in ascending index order, so ties go to the
            // lowest source index.
            let winner = active
                .iter()
                .filter_map(|&id| sources[id].peek().map(|head| (id, head)))
                .min_by(|lhs, rhs| lhs.1.cmp(rhs.1))
                .map(|(id, _)| id);
            let winner = match winner {
                Some(id) => id,
                None => break,
            };

            if let Some(item) = sources[winner].next() {
                result.push(item);
            }
            if sources[winner].is_empty() {
                active.retain(|&id| id != winner);
            }
        }
        Ok(result)
    }

    /// Path taken when there are more sources than `max_merge_ways`.
    ///
    /// Currently this delegates straight to the heap merge over all sources;
    /// no multi-pass merging is performed.
    fn merge_hierarchical<T, S>(&self, sources: Vec<S>) -> Result<Vec<T>>
    where
        T: Ord + Clone,
        S: MergeSource<T>,
    {
        self.merge_heap(sources)
    }

    /// Statistics recorded by the most recent `merge` call on this merger
    /// (all zero before the first call).
    pub fn stats(&self) -> &AlgorithmStats {
        &self.stats
    }
}
impl Default for MultiWayMerge {
fn default() -> Self {
Self::new()
}
}
/// `Algorithm` adapter that merges vectors of `i32`.
impl Algorithm for MultiWayMerge {
    type Config = MultiWayMergeConfig;
    type Input = Vec<Vec<i32>>;
    type Output = Vec<i32>;

    /// Merges the input vectors using a fresh merger built from `config`.
    ///
    /// The merge runs on a temporary merger, so the statistics stored in
    /// `self` are not updated by this call, and `self`'s own configuration is
    /// not consulted.
    fn execute(&self, config: &Self::Config, input: Self::Input) -> Result<Self::Output> {
        let sources: Vec<VectorSource<i32>> = input.into_iter().map(VectorSource::new).collect();
        Self::with_config(config.clone()).merge(sources)
    }

    /// Returns a copy of the statistics stored in this merger.
    fn stats(&self) -> AlgorithmStats {
        self.stats.clone()
    }

    /// Estimates memory as a full heap of `max_merge_ways` entries plus an
    /// output of `input_size` `i32` values.
    ///
    /// `max_merge_ways` is caller-controlled, so the arithmetic saturates at
    /// `usize::MAX` instead of overflowing.
    fn estimate_memory(&self, input_size: usize) -> usize {
        let heap_bytes = self
            .config
            .max_merge_ways
            .saturating_mul(std::mem::size_of::<HeapEntry<i32>>());
        let output_bytes = input_size.saturating_mul(std::mem::size_of::<i32>());
        heap_bytes.saturating_add(output_bytes)
    }

    /// Always reports `true`.
    ///
    /// NOTE(review): every merge path in this file runs sequentially and
    /// records `used_parallel: false` — confirm that `true` is intended.
    fn supports_parallel(&self) -> bool {
        true
    }
}
/// Namespace for two-way merge helpers.
pub struct MergeOperations;

impl MergeOperations {
    /// Merges two vectors into one, consuming both.
    ///
    /// At each step the smaller head is moved to the output; when the heads
    /// compare equal the one from `left` goes first. The output is sorted only
    /// if both inputs are.
    pub fn merge_two<T>(left: Vec<T>, right: Vec<T>) -> Vec<T>
    where
        T: Ord,
    {
        let mut merged = Vec::with_capacity(left.len() + right.len());
        let mut lhs = left.into_iter().peekable();
        let mut rhs = right.into_iter().peekable();

        // Interleave while both sides still have a head to compare.
        while let (Some(l), Some(r)) = (lhs.peek(), rhs.peek()) {
            let taken = if l <= r { lhs.next() } else { rhs.next() };
            if let Some(item) = taken {
                merged.push(item);
            }
        }

        // At most one of these still holds items.
        merged.extend(lhs);
        merged.extend(rhs);
        merged
    }

    /// Merges the runs `data[..mid]` and `data[mid..]` and writes the result
    /// back over `data`.
    ///
    /// Does nothing when `mid` is `0` or not smaller than `data.len()`.
    /// Despite the name, the merge is staged in a temporary vector holding a
    /// clone of every element. Equal elements keep the left run's first.
    pub fn merge_in_place<T>(data: &mut [T], mid: usize)
    where
        T: Ord + Clone,
    {
        if mid == 0 || mid >= data.len() {
            return;
        }

        let staged: Vec<T> = {
            let (left, right) = data.split_at(mid);
            let mut out = Vec::with_capacity(left.len() + right.len());
            let (mut i, mut j) = (0, 0);
            while i < left.len() && j < right.len() {
                if left[i] <= right[j] {
                    out.push(left[i].clone());
                    i += 1;
                } else {
                    out.push(right[j].clone());
                    j += 1;
                }
            }
            // One run is exhausted; append whatever is left of the other.
            out.extend(left[i..].iter().cloned());
            out.extend(right[j..].iter().cloned());
            out
        };

        // `staged` has exactly `data.len()` items, so every slot is replaced.
        for (slot, item) in data.iter_mut().zip(staged) {
            *slot = item;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_source() {
        let mut source = VectorSource::new(vec![1, 2, 3, 4, 5]);
        assert_eq!(source.peek(), Some(&1));
        assert_eq!(source.next(), Some(1));
        assert_eq!(source.peek(), Some(&2));
        assert!(!source.is_empty());
        assert_eq!(source.remaining_hint(), Some(4));
        source.next();
        source.next();
        source.next();
        source.next();
        assert!(source.is_empty());
        assert_eq!(source.next(), None);
    }

    #[test]
    fn test_multiway_merge_empty() {
        let mut merger = MultiWayMerge::new();
        let sources: Vec<VectorSource<i32>> = vec![];
        let result = merger.merge(sources).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_multiway_merge_single_source() {
        let mut merger = MultiWayMerge::new();
        let sources = vec![VectorSource::new(vec![1, 3, 5, 7, 9])];
        let result = merger.merge(sources).unwrap();
        assert_eq!(result, vec![1, 3, 5, 7, 9]);
    }

    #[test]
    fn test_multiway_merge_two_sources() {
        let mut merger = MultiWayMerge::new();
        let sources = vec![
            VectorSource::new(vec![1, 3, 5, 7, 9]),
            VectorSource::new(vec![2, 4, 6, 8, 10]),
        ];
        let result = merger.merge(sources).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_multiway_merge_multiple_sources() {
        let mut merger = MultiWayMerge::new();
        let sources = vec![
            VectorSource::new(vec![1, 4, 7]),
            VectorSource::new(vec![2, 5, 8]),
            VectorSource::new(vec![3, 6, 9]),
        ];
        let result = merger.merge(sources).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_multiway_merge_uneven_sources() {
        let mut merger = MultiWayMerge::new();
        let sources = vec![
            VectorSource::new(vec![1, 2, 3, 4, 5]),
            VectorSource::new(vec![6]),
            VectorSource::new(vec![]),
            VectorSource::new(vec![7, 8]),
        ];
        let result = merger.merge(sources).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_multiway_merge_tournament() {
        let config = MultiWayMergeConfig {
            use_tournament_tree: true,
            ..Default::default()
        };
        let mut merger = MultiWayMerge::with_config(config);
        // The tournament strategy is only chosen for more than 8 sources, so
        // nine are needed to actually exercise it.
        let sources: Vec<VectorSource<i32>> = (0..9)
            .map(|i| VectorSource::new(vec![i, i + 9, i + 18]))
            .collect();
        let result = merger.merge(sources).unwrap();
        assert_eq!(result, (0..27).collect::<Vec<i32>>());
    }

    #[test]
    fn test_multiway_merge_tournament_below_threshold() {
        // With 8 or fewer sources the flag is ignored and the heap strategy
        // runs; the merged output must be the same either way.
        let config = MultiWayMergeConfig {
            use_tournament_tree: true,
            ..Default::default()
        };
        let mut merger = MultiWayMerge::with_config(config);
        let sources = vec![
            VectorSource::new(vec![1, 5, 9]),
            VectorSource::new(vec![2, 6, 10]),
            VectorSource::new(vec![3, 7, 11]),
            VectorSource::new(vec![4, 8, 12]),
        ];
        let result = merger.merge(sources).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
    }

    #[test]
    fn test_merge_two_vectors() {
        let left = vec![1, 3, 5, 7];
        let right = vec![2, 4, 6, 8];
        let result = MergeOperations::merge_two(left, right);
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_merge_two_uneven() {
        let left = vec![1, 3, 5];
        let right = vec![2, 4, 6, 7, 8, 9];
        let result = MergeOperations::merge_two(left, right);
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_merge_in_place() {
        let mut data = vec![1, 3, 5, 7, 2, 4, 6, 8];
        MergeOperations::merge_in_place(&mut data, 4);
        assert_eq!(data, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_slice_source_basic() {
        let data = vec![1, 3, 5, 7, 9];
        let mut src = SliceSource::new(&data);
        assert_eq!(src.peek(), Some(&1));
        assert_eq!(src.remaining_hint(), Some(5));
        assert_eq!(src.next(), Some(1));
        assert_eq!(src.next(), Some(3));
        assert_eq!(src.remaining_hint(), Some(3));
        assert!(!src.is_empty());
        assert_eq!(src.next(), Some(5));
        assert_eq!(src.next(), Some(7));
        assert_eq!(src.next(), Some(9));
        assert!(src.is_empty());
        assert_eq!(src.next(), None);
    }

    #[test]
    fn test_multiway_merge_with_slice_sources() {
        let a = vec![1, 4, 7, 10];
        let b = vec![2, 5, 8, 11];
        let c = vec![3, 6, 9, 12];
        let mut merger = MultiWayMerge::new();
        let sources = vec![
            SliceSource::new(&a),
            SliceSource::new(&b),
            SliceSource::new(&c),
        ];
        let result = merger.merge(sources).unwrap();
        assert_eq!(result, (1..=12).collect::<Vec<i32>>());
    }

    #[test]
    fn test_slice_source_empty() {
        let data: Vec<u32> = vec![];
        let mut src = SliceSource::new(&data);
        assert!(src.is_empty());
        assert_eq!(src.next(), None);
        assert_eq!(src.peek(), None);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            item: 5,
            source_id: 0,
        });
        heap.push(HeapEntry {
            item: 2,
            source_id: 1,
        });
        heap.push(HeapEntry {
            item: 8,
            source_id: 2,
        });
        heap.push(HeapEntry {
            item: 1,
            source_id: 3,
        });
        // The reversed ordering turns the max-heap into a min-heap.
        assert_eq!(heap.pop().unwrap().item, 1);
        assert_eq!(heap.pop().unwrap().item, 2);
        assert_eq!(heap.pop().unwrap().item, 5);
        assert_eq!(heap.pop().unwrap().item, 8);
    }

    #[test]
    fn test_algorithm_trait() {
        let merger = MultiWayMerge::new();
        assert!(merger.supports_parallel());
        let memory_estimate = merger.estimate_memory(1000);
        assert!(memory_estimate > 0);
        let input = vec![vec![1, 3, 5], vec![2, 4, 6], vec![7, 8, 9]];
        let config = MultiWayMergeConfig::default();
        let result = merger.execute(&config, input);
        assert!(result.is_ok());
        let merged = result.unwrap();
        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_multiway_merge_performance() {
        let mut merger = MultiWayMerge::new();
        let sources: Vec<VectorSource<i32>> = (0..10)
            .map(|i| {
                let data: Vec<i32> = (i * 10..(i + 1) * 10).collect();
                VectorSource::new(data)
            })
            .collect();
        let result = merger.merge(sources).unwrap();
        let expected: Vec<i32> = (0..100).collect();
        assert_eq!(result, expected);
        let stats = merger.stats();
        assert_eq!(stats.items_processed, 100);
        // The elapsed time is truncated to whole microseconds and may
        // legitimately be 0 for a merge this small, so check the
        // deterministic fields instead of the timing.
        assert_eq!(stats.memory_used, 100 * std::mem::size_of::<i32>());
        assert!(!stats.used_parallel);
    }
}