use std::{
cmp::Ordering,
collections::{BinaryHeap, VecDeque},
fs::File,
io::{BufReader, Error, Seek, SeekFrom},
};
use crate::{ExternalSorterOptions, Sortable};
/// Iterator over the sorted output of an external sort.
///
/// Values are served from an in-memory queue or merged from on-disk segment
/// files, ordered by the user-supplied comparator `cmp`. Each item is an
/// `std::io::Result<T>` because reading a segment can fail.
pub struct SortedIterator<T, F>
where
    T: Sortable,
    F: Fn(&T, &T) -> Ordering + Send + Sync + Clone,
{
    /// Readers over the segment files, indexed by segment number.
    segments: Vec<Segment>,
    /// Strategy used to produce the next value.
    mode: Mode<T, F>,
    /// Value reported by `sorted_count` (presumably the total number of sorted items — confirm with caller).
    count: u64,
    /// Comparator defining the output order.
    cmp: F,
    // Declared last on purpose: struct fields are dropped in declaration order, so the
    // segment files above are closed before the temporary directory is removed.
    // Removing a directory that still has open files in it can fail on some
    // platforms (e.g. Windows), which would leak the directory.
    _tempdir: Option<tempfile::TempDir>,
}
/// Strategy a [`SortedIterator`] uses to produce its next value.
enum Mode<T: Sortable, F: Fn(&T, &T) -> Ordering + Send + Sync + Clone> {
    /// One look-ahead slot per segment (`None` once that segment is exhausted);
    /// the smallest slot is found with a linear scan.
    Peek(Vec<Option<T>>),
    /// K-way merge through a heap that buffers values read from the segments.
    Heap(BinaryHeap<HeapItem<T, F>>),
    /// Values are served in order from an in-memory queue; no merging happens.
    Passthrough(VecDeque<T>),
}
/// Read state for one on-disk segment file.
struct Segment {
    /// Set once decoding this segment hit end-of-file during a heap refill.
    done: bool,
    /// How many values from this segment are currently sitting in the merge heap.
    heap_count: usize,
    /// Buffered reader over the segment file.
    reader: BufReader<File>,
}
impl<T, F> SortedIterator<T, F>
where
    T: Sortable,
    F: Fn(&T, &T) -> Ordering + Send + Sync + Clone,
{
    /// Maximum number of values pulled from a single segment into the heap per refill.
    const HEAP_REFILL_BATCH: usize = 20;

    /// Builds the iterator over the sorted output.
    ///
    /// - `tempdir`: kept alive for as long as the iterator exists.
    /// - `pass_through_queue`: if `Some`, values are served straight from this queue
    ///   (passthrough mode) and the segments are not read.
    /// - `segment_files`: on-disk segments. Each is rewound to its start before use.
    /// - `count`: value later reported by [`Self::sorted_count`].
    /// - `cmp`: comparator defining the output order.
    /// - `options`: with fewer segments than `heap_iter_segment_count`, a linear
    ///   "peek" scan is used; otherwise a heap merge is used.
    ///
    /// # Errors
    ///
    /// Returns an error if a segment file cannot be rewound, or if priming a peek slot
    /// fails with anything other than end-of-file. An empty segment simply starts out
    /// exhausted.
    pub(crate) fn new(
        tempdir: Option<tempfile::TempDir>,
        pass_through_queue: Option<VecDeque<T>>,
        mut segment_files: Vec<File>,
        count: u64,
        cmp: F,
        options: ExternalSorterOptions,
    ) -> Result<SortedIterator<T, F>, Error> {
        // Rewind every segment so decoding starts from its first value.
        for segment_file in &mut segment_files {
            segment_file.seek(SeekFrom::Start(0))?;
        }
        let mut segments: Vec<Segment> = segment_files
            .into_iter()
            .map(|file| Segment {
                reader: BufReader::new(file),
                heap_count: 0,
                done: false,
            })
            .collect();
        let mode = if let Some(queue) = pass_through_queue {
            Mode::Passthrough(queue)
        } else if segments.len() < options.heap_iter_segment_count {
            // Prime one look-ahead value per segment; `None` marks an empty segment,
            // matching how the iterator itself treats end-of-file in peek mode.
            let next_values = segments
                .iter_mut()
                .map(Self::decode_next)
                .collect::<Result<Vec<_>, Error>>()?;
            Mode::Peek(next_values)
        } else {
            Mode::Heap(BinaryHeap::new())
        };
        Ok(SortedIterator {
            _tempdir: tempdir,
            segments,
            mode,
            count,
            cmp,
        })
    }

    /// Returns the `count` value this iterator was constructed with.
    pub fn sorted_count(&self) -> u64 {
        self.count
    }

    /// Returns the number of segment files handed to [`Self::new`].
    pub fn disk_segment_count(&self) -> usize {
        self.segments.len()
    }

    /// Decodes the next value from `segment`.
    ///
    /// Returns `Ok(None)` when decoding reports `UnexpectedEof`, which this iterator
    /// treats as "segment exhausted". Any other error is returned unchanged.
    fn decode_next(segment: &mut Segment) -> std::io::Result<Option<T>> {
        match T::decode(&mut segment.reader) {
            Ok(value) => Ok(Some(value)),
            Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
            Err(err) => Err(err),
        }
    }

    /// Loads up to `HEAP_REFILL_BATCH` values into `heap` from every segment that is
    /// not yet exhausted and currently has no values in the heap.
    ///
    /// A segment is marked `done` as soon as it reaches end-of-file.
    ///
    /// # Errors
    ///
    /// Returns the first non-EOF decode error. Values pushed before the error stay in
    /// the heap.
    fn fill_heap(
        heap: &mut BinaryHeap<HeapItem<T, F>>,
        segments: &mut [Segment],
        cmp: F,
    ) -> std::io::Result<()> {
        for (segment_index, segment) in segments.iter_mut().enumerate() {
            // Segments that still have values in the heap don't need more yet.
            if segment.done || segment.heap_count > 0 {
                continue;
            }
            for _ in 0..Self::HEAP_REFILL_BATCH {
                let Some(value) = Self::decode_next(segment)? else {
                    // Stop reading at end-of-file instead of retrying an exhausted reader.
                    segment.done = true;
                    break;
                };
                segment.heap_count += 1;
                heap.push(HeapItem {
                    segment_index,
                    value,
                    cmp: cmp.clone(),
                });
            }
        }
        Ok(())
    }
}
impl<T, F> Iterator for SortedIterator<T, F>
where
    T: Sortable,
    F: Fn(&T, &T) -> Ordering + Send + Sync + Clone,
{
    type Item = std::io::Result<T>;

    /// Yields the next value in `cmp` order, or an I/O error from reading a segment.
    fn next(&mut self) -> Option<Self::Item> {
        match &mut self.mode {
            Mode::Passthrough(queue) => queue.pop_front().map(Ok),
            Mode::Heap(heap) => {
                // An empty heap means nothing has been loaded yet or every segment is
                // drained; a refill tells the two apart.
                if heap.is_empty() {
                    if let Err(err) = Self::fill_heap(heap, &mut self.segments, self.cmp.clone()) {
                        return Some(Err(err));
                    }
                }
                let item = heap.pop()?;
                let segment = &mut self.segments[item.segment_index];
                segment.heap_count -= 1;
                // The heap's minimum is only the global minimum while every unfinished
                // segment has at least one value in it, so refill as soon as this
                // segment's last buffered value is popped.
                if segment.heap_count == 0 {
                    if let Err(err) = Self::fill_heap(heap, &mut self.segments, self.cmp.clone()) {
                        return Some(Err(err));
                    }
                }
                Some(Ok(item.value))
            }
            Mode::Peek(next_values) => {
                let cmp = &self.cmp;
                // Smallest look-ahead value across all segments; `min_by` keeps the first
                // of several equal minima, so ties go to the lowest segment index.
                let idx = next_values
                    .iter()
                    .enumerate()
                    .filter_map(|(idx, slot)| slot.as_ref().map(|value| (idx, value)))
                    .min_by(|(_, a), (_, b)| cmp(*a, *b))
                    .map(|(idx, _)| idx)?;
                let value = next_values[idx]
                    .take()
                    .expect("slot selected by the scan above holds a value");
                let segment = &mut self.segments[idx];
                match T::decode(&mut segment.reader) {
                    Ok(next) => next_values[idx] = Some(next),
                    // Segment exhausted: its slot is already `None` after `take()`.
                    Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}
                    // NOTE: on a non-EOF read error the value just taken is dropped and this
                    // segment's slot stays `None`, so the segment is not read again.
                    Err(err) => return Some(Err(err)),
                }
                Some(Ok(value))
            }
        }
    }
}
/// A value buffered in the merge heap, tagged with the segment it was read from.
///
/// Every item carries its own comparator so its `Ord` impl can order it.
struct HeapItem<T: Sortable, F: Fn(&T, &T) -> Ordering + Send + Sync> {
    /// Index of the originating segment.
    segment_index: usize,
    /// The decoded value.
    value: T,
    /// Comparator used to order this item against others.
    cmp: F,
}
/// Defers to the total order from `Ord` so the two impls always agree.
impl<T: Sortable, F: Fn(&T, &T) -> Ordering + Send + Sync> PartialOrd for HeapItem<T, F> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Spelled as a path call to make clear this is `Ord::cmp`, not the `cmp` field.
        Some(Ord::cmp(self, other))
    }
}
impl<T: Sortable, F: Fn(&T, &T) -> Ordering + Send + Sync> Ord for HeapItem<T, F> {
    /// Inverts the comparator: `BinaryHeap` is a max-heap, so reversing the order
    /// makes it pop the smallest value first.
    fn cmp(&self, other: &Self) -> Ordering {
        let natural = (self.cmp)(&self.value, &other.value);
        natural.reverse()
    }
}
impl<T: Sortable, F: Fn(&T, &T) -> Ordering + Send + Sync> PartialEq for HeapItem<T, F> {
    /// Items are equal when the comparator ranks their values equal; the segment
    /// index is ignored.
    fn eq(&self, other: &Self) -> bool {
        (self.cmp)(&self.value, &other.value).is_eq()
    }
}
/// Marker impl: equality comes from the comparator, which is assumed to be a total order.
impl<T: Sortable, F: Fn(&T, &T) -> Ordering + Send + Sync> Eq for HeapItem<T, F> {}