use get_size2::GetSize;
use crate::{
SortedItemsError,
chunk::{ChunkSorter, Sequential},
codec::{Codec, CodecWriter, DeriveKey, KeyedCodec, KeyedCodecWriter},
compare::{Compare, Natural},
dedup::{Dedup, Identity},
key::{KeyCompare, SortKey},
merge::{KeyedRunMerge, MergeConfig, MergeError, RunMerge, RunMerger, RunWriter, SortedRun},
};
/// Policy deciding when a sorter's in-memory buffer is written out as a sorted run.
///
/// Consulted on every `push`: the buffer is flushed as soon as the active
/// variant's threshold is reached.
enum FlushStrategy<T> {
/// Flush once the running total of `item_size` results reaches `budget`.
Bytes {
/// Threshold the accumulated size total is compared against (`>=`).
budget: usize,
/// Callback returning the size charged to the budget for one item.
item_size: Box<dyn Fn(&T) -> usize + Send + Sync>,
},
/// Flush once the buffer holds at least `max_items` items.
Items { max_items: usize },
}
/// Configuration carried by a [`Builder`] into the [`Sorter`] it builds.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot construct it
/// with a struct literal; start from `SorterConfig::default()` instead.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SorterConfig {
/// Settings cloned into every run writer and run merger the sorter creates.
pub merge: MergeConfig,
}
/// Typestate marker: the [`Builder`] has no sort key yet (initial state).
pub struct NeedsSortKey;
/// Typestate marker: the [`Builder`] holds the sort key given to `Builder::key`.
pub struct HasSortKey<SK>(SK);
/// Typestate marker: the [`Builder`] has no codec yet (initial state).
pub struct NeedsCodec;
/// Typestate marker: the [`Builder`] holds a codec given to `Builder::codec`.
pub struct HasCodec<Cod>(Cod);
/// Typestate marker: the [`Builder`] holds a codec given to `Builder::keyed_codec`.
pub struct HasKeyedCodec<Cod>(Cod);
/// Typestate marker: no flush strategy has been chosen yet (initial state).
pub struct NeedsFlushStrategy;
/// Typestate marker: wraps the flush strategy chosen through one of the budget setters.
pub struct HasFlushStrategy<T>(FlushStrategy<T>);
/// Mode marker for a [`Sorter`] built from a `HasCodec` builder.
pub struct Basic;
/// Mode marker for a [`Sorter`] built from a `HasKeyedCodec` builder.
pub struct Keyed;
/// Typestate builder for [`Sorter`].
///
/// `SK`, `Cod` and `Flush` start as the `Needs*` markers and are swapped for the
/// matching `Has*` wrappers by the setter methods. `build` is only implemented
/// once all three are in a `Has*` state. `Cmp`, `D` and `CS` default to
/// `Natural`, `Identity` and `Sequential`.
pub struct Builder<SK, Cod, Flush, Cmp = Natural, D = Identity, CS = Sequential> {
// Sort-key slot: `NeedsSortKey` or `HasSortKey<_>`.
sort_key: SK,
// Codec slot: `NeedsCodec`, `HasCodec<_>` or `HasKeyedCodec<_>`.
codec: Cod,
// Flush-strategy slot: `NeedsFlushStrategy` or `HasFlushStrategy<_>`.
flush: Flush,
// Comparator applied to sort keys.
compare: Cmp,
// Dedup stage applied to the merged output in `Sorter::finish`.
dedup: D,
// Strategy used to sort the in-memory buffer before each flush.
chunk_sort: CS,
// Remaining configuration, handed to the sorter unchanged.
config: SorterConfig,
}
impl Builder<NeedsSortKey, NeedsCodec, NeedsFlushStrategy> {
    /// Creates a builder with nothing configured yet.
    ///
    /// The sort key, codec and flush strategy slots hold their `Needs*`
    /// markers; comparison, dedup and chunk sorting use `Natural`, `Identity`
    /// and `Sequential`; the configuration is `SorterConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        let config = SorterConfig::default();
        Self {
            sort_key: NeedsSortKey,
            codec: NeedsCodec,
            flush: NeedsFlushStrategy,
            compare: Natural,
            dedup: Identity,
            chunk_sort: Sequential,
            config,
        }
    }
}
impl Default for Builder<NeedsSortKey, NeedsCodec, NeedsFlushStrategy> {
fn default() -> Self {
Self::new()
}
}
impl<SK, Cod, Flush, Cmp, D, CS> Builder<SK, Cod, Flush, Cmp, D, CS> {
    /// Installs `sort_key`, moving the sort-key slot to `HasSortKey`.
    ///
    /// Every other slot is carried over unchanged.
    #[must_use]
    pub fn key<SK2>(self, sort_key: SK2) -> Builder<HasSortKey<SK2>, Cod, Flush, Cmp, D, CS> {
        let Builder {
            codec,
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
            ..
        } = self;
        Builder {
            sort_key: HasSortKey(sort_key),
            codec,
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
        }
    }

    /// Installs `codec`, moving the codec slot to `HasCodec`.
    ///
    /// A builder in this state builds a `Basic` sorter.
    #[must_use]
    pub fn codec<Cod2>(self, codec: Cod2) -> Builder<SK, HasCodec<Cod2>, Flush, Cmp, D, CS> {
        let Builder {
            sort_key,
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
            ..
        } = self;
        Builder {
            sort_key,
            codec: HasCodec(codec),
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
        }
    }

    /// Installs `codec`, moving the codec slot to `HasKeyedCodec`.
    ///
    /// A builder in this state builds a `Keyed` sorter.
    #[must_use]
    pub fn keyed_codec<Cod2>(
        self,
        codec: Cod2,
    ) -> Builder<SK, HasKeyedCodec<Cod2>, Flush, Cmp, D, CS> {
        let Builder {
            sort_key,
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
            ..
        } = self;
        Builder {
            sort_key,
            codec: HasKeyedCodec(codec),
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
        }
    }

    /// Replaces the key comparator.
    #[must_use]
    pub fn compare<Cmp2>(self, compare: Cmp2) -> Builder<SK, Cod, Flush, Cmp2, D, CS> {
        let Builder {
            sort_key,
            codec,
            flush,
            dedup,
            chunk_sort,
            config,
            ..
        } = self;
        Builder {
            sort_key,
            codec,
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
        }
    }

    /// Replaces the dedup stage applied to the merged output.
    #[must_use]
    pub fn dedup<D2>(self, dedup: D2) -> Builder<SK, Cod, Flush, Cmp, D2, CS> {
        let Builder {
            sort_key,
            codec,
            flush,
            compare,
            chunk_sort,
            config,
            ..
        } = self;
        Builder {
            sort_key,
            codec,
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
        }
    }

    /// Replaces the strategy used to sort each in-memory chunk.
    #[must_use]
    pub fn chunk_sort<CS2>(self, chunk_sort: CS2) -> Builder<SK, Cod, Flush, Cmp, D, CS2> {
        let Builder {
            sort_key,
            codec,
            flush,
            compare,
            dedup,
            config,
            ..
        } = self;
        Builder {
            sort_key,
            codec,
            flush,
            compare,
            dedup,
            chunk_sort,
            config,
        }
    }

    /// Overwrites the merge settings stored in the builder's configuration.
    #[must_use]
    pub fn merge_config(self, merge: MergeConfig) -> Self {
        let mut updated = self;
        updated.config.merge = merge;
        updated
    }
}
impl<SK, Cod, Cmp, D, CS> Builder<SK, Cod, NeedsFlushStrategy, Cmp, D, CS> {
#[must_use]
pub fn memory_budget<T: 'static>(
self,
budget: usize,
item_size: impl Fn(&T) -> usize + Send + Sync + 'static,
) -> Builder<SK, Cod, HasFlushStrategy<T>, Cmp, D, CS> {
Builder {
sort_key: self.sort_key,
codec: self.codec,
flush: HasFlushStrategy(FlushStrategy::Bytes {
budget,
item_size: Box::new(item_size),
}),
compare: self.compare,
dedup: self.dedup,
chunk_sort: self.chunk_sort,
config: self.config,
}
}
#[must_use]
pub fn measured_budget<T: GetSize + 'static>(
self,
budget: usize,
) -> Builder<SK, Cod, HasFlushStrategy<T>, Cmp, D, CS> {
Builder {
sort_key: self.sort_key,
codec: self.codec,
flush: HasFlushStrategy(FlushStrategy::Bytes {
budget,
item_size: Box::new(GetSize::get_size),
}),
compare: self.compare,
dedup: self.dedup,
chunk_sort: self.chunk_sort,
config: self.config,
}
}
#[must_use]
pub fn fixed_size_budget<T: 'static>(
self,
budget: usize,
) -> Builder<SK, Cod, HasFlushStrategy<T>, Cmp, D, CS> {
Builder {
sort_key: self.sort_key,
codec: self.codec,
flush: HasFlushStrategy(FlushStrategy::Bytes {
budget,
item_size: Box::new(|_| std::mem::size_of::<T>()),
}),
compare: self.compare,
dedup: self.dedup,
chunk_sort: self.chunk_sort,
config: self.config,
}
}
#[must_use]
pub fn max_buffer_items<T>(
self,
max_items: usize,
) -> Builder<SK, Cod, HasFlushStrategy<T>, Cmp, D, CS> {
Builder {
sort_key: self.sort_key,
codec: self.codec,
flush: HasFlushStrategy(FlushStrategy::Items { max_items }),
compare: self.compare,
dedup: self.dedup,
chunk_sort: self.chunk_sort,
config: self.config,
}
}
}
impl<T, SK, Cod, Cmp, D, CS> Builder<HasSortKey<SK>, HasCodec<Cod>, HasFlushStrategy<T>, Cmp, D, CS>
where
SK: SortKey<T> + Copy,
Cod: Codec<Item = T> + Copy,
Cmp: for<'a> Compare<SK::Key<'a>> + Copy,
CS: ChunkSorter<T>,
{
#[must_use]
pub fn build(self) -> Sorter<T, SK, Cod, Cmp, D, CS, Basic> {
Sorter {
sort_key: self.sort_key.0,
codec: self.codec.0,
compare: self.compare,
dedup: Some(self.dedup),
chunk_sort: self.chunk_sort,
flush: self.flush.0,
buffer: Vec::new(),
buffer_bytes: 0,
spilled_runs: Vec::new(),
config: self.config,
_marker: std::marker::PhantomData,
}
}
}
impl<T, SK, Cod, Cmp, D, CS>
Builder<HasSortKey<SK>, HasKeyedCodec<Cod>, HasFlushStrategy<T>, Cmp, D, CS>
where
SK: SortKey<T> + Copy,
Cod: KeyedCodec<Item = T> + Copy,
Cmp: for<'a> Compare<SK::Key<'a>> + Compare<Cod::Key> + Copy,
CS: ChunkSorter<T>,
{
#[must_use]
pub fn build(self) -> Sorter<T, SK, Cod, Cmp, D, CS, Keyed> {
Sorter {
sort_key: self.sort_key.0,
codec: self.codec.0,
compare: self.compare,
dedup: Some(self.dedup),
chunk_sort: self.chunk_sort,
flush: self.flush.0,
buffer: Vec::new(),
buffer_bytes: 0,
spilled_runs: Vec::new(),
config: self.config,
_marker: std::marker::PhantomData,
}
}
}
/// Sorter that buffers pushed items, writes the buffer out as a sorted run
/// whenever the flush strategy triggers, and merges all runs in `finish`.
///
/// `M` is a mode marker (`Basic` or `Keyed`) selecting which `impl` block, and
/// therefore which run writer and merge routine, applies.
pub struct Sorter<T, SK, Cod, Cmp, D, CS, M = Basic> {
// Key extractor; paired with `compare` to order whole items.
sort_key: SK,
// Codec handed to the run writer and the run merger.
codec: Cod,
// Comparator applied to sort keys.
compare: Cmp,
// Dedup stage; wrapped in `Option` so `finish` can `take()` it.
dedup: Option<D>,
// Sorts `buffer` in place before each flush.
chunk_sort: CS,
// Decides when `push` flushes the buffer.
flush: FlushStrategy<T>,
// Items pushed since the last flush.
buffer: Vec<T>,
// Running total of `item_size` results since the last flush; only updated
// under `FlushStrategy::Bytes`.
buffer_bytes: usize,
// Runs written so far, merged by `finish`.
spilled_runs: Vec<SortedRun>,
// Configuration inherited from the builder.
config: SorterConfig,
// Zero-sized tag carrying the mode marker `M`.
_marker: std::marker::PhantomData<M>,
}
/// Output handle returned by `Sorter::finish`, wrapping the merged (and
/// deduplicated) source.
///
/// Iterates like the wrapped source when that source is an `Iterator`; `items`
/// converts it into a [`SortedItems`] for callback-style traversal.
pub struct Sorted<I> {
// The wrapped source; every trait impl on `Sorted` delegates to it.
source: I,
}
/// Callback-style view over a sorted source, produced by `Sorted::items`.
///
/// Its `try_for_each` hands every item to a fallible closure via
/// `VisitSortedItems::visit_items`.
pub struct SortedItems<S> {
// The source whose items are visited.
source: S,
}
/// Holds the `Sealed` marker trait that `VisitSortedItems` requires as a supertrait.
// NOTE(review): presumably meant to restrict who can implement `VisitSortedItems`;
// the module itself is `pub`, so confirm how the crate root exposes it.
#[doc(hidden)]
pub mod sealed {
/// Marker supertrait of `VisitSortedItems`; it has no methods.
pub trait Sealed {}
}
/// Internal-iteration interface: a source that can feed its items to a
/// fallible callback.
///
/// Hidden from the docs and gated on `sealed::Sealed`.
#[doc(hidden)]
pub trait VisitSortedItems: sealed::Sealed {
/// Item type handed to the callback; it may borrow for the lifetime `'a`.
type Item<'a>
where
Self: 'a;
/// Error type of the source itself.
type Error;
/// Consumes the source and calls `f` on its items.
///
/// `FE` is the callback's own error type; the result is a
/// `SortedItemsError` parameterized over both error types.
fn visit_items<F, FE>(self, f: F) -> Result<(), SortedItemsError<Self::Error, FE>>
where
F: for<'a> FnMut(Self::Item<'a>) -> Result<(), FE>;
}
impl<I> Sorted<I> {
fn new(source: I) -> Self {
Self { source }
}
pub fn items(self) -> SortedItems<I> {
SortedItems {
source: self.source,
}
}
}
impl<S> SortedItems<S>
where
S: VisitSortedItems,
{
pub fn try_for_each<F, FE>(self, f: F) -> Result<(), SortedItemsError<S::Error, FE>>
where
F: for<'a> FnMut(S::Item<'a>) -> Result<(), FE>,
{
self.source.visit_items(f)
}
}
/// `Sorted` is a transparent iterator adapter: every call is forwarded to the
/// wrapped source.
impl<I> Iterator for Sorted<I>
where
    I: Iterator,
{
    type Item = I::Item;

    /// Yields the next item of the wrapped source.
    fn next(&mut self) -> Option<Self::Item> {
        self.source.next()
    }

    /// Forwards the wrapped source's bounds.
    ///
    /// Without this override the default `(0, None)` would hide the inner
    /// iterator's bounds from consumers such as `collect` and `extend`, which
    /// use them to pre-allocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.source.size_hint()
    }
}
// Marks `Sorted<I>` as `Sealed` under the same bounds as its `VisitSortedItems`
// impl, which requires this supertrait.
impl<I> sealed::Sealed for Sorted<I> where I: VisitSortedItems + 'static {}
impl<I> VisitSortedItems for Sorted<I>
where
I: VisitSortedItems + 'static,
{
type Item<'a>
= I::Item<'a>
where
Self: 'a;
type Error = I::Error;
fn visit_items<F, FE>(self, f: F) -> Result<(), SortedItemsError<Self::Error, FE>>
where
F: for<'a> FnMut(Self::Item<'a>) -> Result<(), FE>,
{
self.source.visit_items(f)
}
}
impl<T, SK, Cod, Cmp, D, CS> Sorter<T, SK, Cod, Cmp, D, CS, Basic>
where
    T: 'static,
    SK: SortKey<T> + Copy + Send + Sync + 'static,
    Cod: Codec<Item = T> + Copy + 'static,
    for<'a> Cod::Writer<&'a mut std::fs::File>: CodecWriter<T, Error = Cod::Error>,
    Cmp: for<'a> Compare<SK::Key<'a>> + Copy + Send + Sync + 'static,
    CS: ChunkSorter<T>,
{
    /// Adds `item` to the in-memory buffer and flushes the buffer as a sorted
    /// run once the configured flush strategy's threshold is reached.
    ///
    /// # Errors
    ///
    /// Returns the error produced while writing the run if a flush is
    /// triggered and fails. The item stays in the buffer in that case.
    pub fn push(&mut self, item: T) -> Result<(), MergeError<Cod::Error>> {
        let should_flush = match &self.flush {
            FlushStrategy::Bytes { budget, item_size } => {
                // `item_size` is caller-supplied, so saturate instead of
                // overflowing: a wrapped total would silently skip the flush.
                self.buffer_bytes = self.buffer_bytes.saturating_add(item_size(&item));
                self.buffer.push(item);
                self.buffer_bytes >= *budget
            }
            FlushStrategy::Items { max_items } => {
                self.buffer.push(item);
                self.buffer.len() >= *max_items
            }
        };
        if should_flush {
            self.flush_basic()?;
        }
        Ok(())
    }

    /// Sorts the buffer, writes it out as one unkeyed run, records the run and
    /// resets the buffer and its byte counter.
    ///
    /// On a write error the buffer and counter are left untouched.
    fn flush_basic(&mut self) -> Result<(), MergeError<Cod::Error>> {
        // Order whole items by comparing their extracted sort keys.
        let item_cmp = KeyCompare::new(self.sort_key, self.compare);
        self.chunk_sort.sort(&mut self.buffer, move |a: &T, b: &T| {
            Compare::compare(&item_cmp, a, b)
        });
        let run_writer = RunWriter::new_unkeyed(self.codec, self.config.merge.clone());
        let run = run_writer.write_sorted(&self.buffer)?;
        self.spilled_runs.push(run);
        self.buffer.clear();
        self.buffer_bytes = 0;
        Ok(())
    }

    /// Consumes the sorter and returns the merged, deduplicated output.
    ///
    /// Any items still buffered are first written out as a final run; all
    /// runs are then merged and passed through the configured dedup stage.
    ///
    /// # Errors
    ///
    /// Returns the error from the final flush or from setting up the merge.
    ///
    /// # Panics
    ///
    /// Panics if the dedup stage has already been taken. `build` stores it as
    /// `Some` and this method consumes `self`, so that is not expected.
    #[allow(clippy::type_complexity)]
    pub fn finish(
        mut self,
    ) -> Result<
        Sorted<<D as Dedup<RunMerge<T, Cod, KeyCompare<SK, Cmp>>>>::Deduped>,
        MergeError<Cod::Error>,
    >
    where
        D: Dedup<RunMerge<T, Cod, KeyCompare<SK, Cmp>>>,
    {
        // Skip the flush for an empty buffer so no empty run is written.
        if !self.buffer.is_empty() {
            self.flush_basic()?;
        }
        let item_cmp = KeyCompare::new(self.sort_key, self.compare);
        let run_merger = RunMerger::new(self.codec, item_cmp, self.config.merge.clone());
        let merged = run_merger.merge(std::mem::take(&mut self.spilled_runs))?;
        let dedup = self
            .dedup
            .take()
            .expect("dedup is always Some until finish() consumes it");
        Ok(Sorted::new(dedup.dedup(merged)))
    }
}
impl<T, SK, Cod, Cmp, D, CS> Sorter<T, SK, Cod, Cmp, D, CS, Keyed>
where
    T: 'static,
    SK: SortKey<T> + Copy + Send + Sync + 'static,
    Cod: KeyedCodec<Item = T> + DeriveKey<T> + Copy + 'static,
    for<'a> Cod::KeyedWriter<&'a mut std::fs::File>:
        KeyedCodecWriter<T, Cod::Key, Error = Cod::Error>,
    Cmp: for<'a> Compare<SK::Key<'a>> + Compare<Cod::Key> + Copy + Send + Sync + 'static,
    CS: ChunkSorter<T>,
{
    /// Adds `item` to the in-memory buffer and flushes the buffer as a sorted
    /// run once the configured flush strategy's threshold is reached.
    ///
    /// # Errors
    ///
    /// Returns the error produced while writing the run if a flush is
    /// triggered and fails. The item stays in the buffer in that case.
    pub fn push(&mut self, item: T) -> Result<(), MergeError<Cod::Error>> {
        let should_flush = match &self.flush {
            FlushStrategy::Bytes { budget, item_size } => {
                // `item_size` is caller-supplied, so saturate instead of
                // overflowing: a wrapped total would silently skip the flush.
                self.buffer_bytes = self.buffer_bytes.saturating_add(item_size(&item));
                self.buffer.push(item);
                self.buffer_bytes >= *budget
            }
            FlushStrategy::Items { max_items } => {
                self.buffer.push(item);
                self.buffer.len() >= *max_items
            }
        };
        if should_flush {
            self.flush_keyed()?;
        }
        Ok(())
    }

    /// Sorts the buffer, writes it out as one keyed run, records the run and
    /// resets the buffer and its byte counter.
    ///
    /// On a write error the buffer and counter are left untouched.
    fn flush_keyed(&mut self) -> Result<(), MergeError<Cod::Error>> {
        // Order whole items by comparing their extracted sort keys.
        let item_cmp = KeyCompare::new(self.sort_key, self.compare);
        self.chunk_sort.sort(&mut self.buffer, move |a: &T, b: &T| {
            Compare::compare(&item_cmp, a, b)
        });
        let run_writer = RunWriter::new_keyed(self.codec, self.config.merge.clone());
        let run = run_writer.write_sorted(&self.buffer)?;
        self.spilled_runs.push(run);
        self.buffer.clear();
        self.buffer_bytes = 0;
        Ok(())
    }

    /// Consumes the sorter and returns the merged, deduplicated output.
    ///
    /// Any items still buffered are first written out as a final run; all
    /// runs are then merged with `merge_keyed`, which also receives the key
    /// comparator, and passed through the configured dedup stage.
    ///
    /// # Errors
    ///
    /// Returns the error from the final flush or from setting up the merge.
    ///
    /// # Panics
    ///
    /// Panics if the dedup stage has already been taken. `build` stores it as
    /// `Some` and this method consumes `self`, so that is not expected.
    #[allow(clippy::type_complexity)]
    pub fn finish(
        mut self,
    ) -> Result<
        Sorted<<D as Dedup<KeyedRunMerge<T, Cod, Cmp, KeyCompare<SK, Cmp>>>>::Deduped>,
        MergeError<Cod::Error>,
    >
    where
        D: Dedup<KeyedRunMerge<T, Cod, Cmp, KeyCompare<SK, Cmp>>>,
    {
        // Skip the flush for an empty buffer so no empty run is written.
        if !self.buffer.is_empty() {
            self.flush_keyed()?;
        }
        let item_cmp = KeyCompare::new(self.sort_key, self.compare);
        let run_merger = RunMerger::new(self.codec, item_cmp, self.config.merge.clone());
        let merged =
            run_merger.merge_keyed(std::mem::take(&mut self.spilled_runs), self.compare)?;
        let dedup = self
            .dedup
            .take()
            .expect("dedup is always Some until finish() consumes it");
        Ok(Sorted::new(dedup.dedup(merged)))
    }
}
#[cfg(test)]
mod tests {
use std::io::{BufWriter, Read, Write};
use super::*;
use crate::{
codec::{CodecCursor, CodecWriter},
compare::Reverse,
dedup::AdjacentDedup,
key::Owned,
};
/// Test codec storing each `u64` as 8 little-endian bytes.
#[derive(Clone, Copy)]
struct U64Codec;
/// Writer half of `U64Codec`.
struct U64Writer<W: Write> {
// Buffered destination; flushed by `finish`.
inner: BufWriter<W>,
}
impl<W: Write> CodecWriter<u64> for U64Writer<W> {
type Error = std::io::Error;
fn write(&mut self, item: &u64) -> Result<(), Self::Error> {
self.inner.write_all(&item.to_le_bytes())
}
fn finish(mut self) -> Result<(), Self::Error> {
self.inner.flush()
}
}
/// Cursor half of `U64Codec`.
struct U64Reader<R: Read> {
// Byte source being decoded.
inner: R,
// Value decoded by the last successful `advance`; `None` before the first
// call and after end of input.
current: Option<u64>,
}
impl<R: Read> CodecCursor<u64> for U64Reader<R> {
    type Error = std::io::Error;
    type Current<'a>
        = u64
    where
        Self: 'a;

    /// Decodes the next value, returning `Ok(false)` at a clean end of input.
    ///
    /// A single probe byte is read first so that end-of-input at a value
    /// boundary can be told apart from a truncated value: once a first byte
    /// has arrived, the remaining 7 bytes are mandatory.
    fn advance(&mut self) -> Result<bool, Self::Error> {
        let mut buf = [0u8; 8];
        // `Read::read` may fail with `Interrupted`, which is non-fatal and
        // should be retried rather than surfaced to the caller.
        let probed = loop {
            match self.inner.read(&mut buf[..1]) {
                Ok(n) => break n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        };
        if probed == 0 {
            self.current = None;
            return Ok(false);
        }
        self.inner.read_exact(&mut buf[1..])?;
        self.current = Some(u64::from_le_bytes(buf));
        Ok(true)
    }

    /// Returns the value decoded by the last successful `advance`.
    ///
    /// Fails if there is no current value: before the first `advance` or
    /// after end of input.
    fn current(&mut self) -> Result<u64, Self::Error> {
        self.current
            .ok_or_else(|| std::io::Error::other("current called before advance"))
    }

    /// Applies `f` to the current value, with the same failure mode as
    /// `current`.
    fn with_current<'a, F>(
        &'a mut self,
        f: impl FnOnce(Self::Current<'a>) -> F,
    ) -> Result<F, Self::Error> {
        self.current().map(f)
    }
}
impl Codec for U64Codec {
    type Item = u64;
    type Error = std::io::Error;
    type Writer<W: Write> = U64Writer<W>;
    type Cursor<R: Read> = U64Reader<R>;

    /// Wraps `dest` in a buffered writer.
    fn writer<W: Write>(&self, dest: W) -> U64Writer<W> {
        let inner = BufWriter::new(dest);
        U64Writer { inner }
    }

    /// Starts a cursor over `source` with no current value.
    fn cursor<R: Read>(&self, source: R) -> U64Reader<R> {
        U64Reader {
            current: None,
            inner: source,
        }
    }
}
/// Concrete sorter type returned by `u64_sorter`: `u64` items keyed by an
/// `Owned` fn pointer, with the default comparator, dedup and chunk sorter.
type U64Sorter = Sorter<u64, Owned<fn(&u64) -> u64>, U64Codec, Natural, Identity, Sequential>;
/// Builds a `u64` sorter keyed on the value itself that flushes once the
/// buffer holds `max_items` items.
fn u64_sorter(max_items: usize) -> U64Sorter {
    let identity = (|v: &u64| *v) as fn(&u64) -> u64;
    let builder = Builder::new().key(Owned(identity)).codec(U64Codec);
    builder.max_buffer_items::<u64>(max_items).build()
}
/// In-memory `VisitSortedItems` source used to test `SortedItems` without
/// going through a real sorter.
struct TestVisitable {
// Items handed to the visitor, in this order.
items: Vec<u64>,
}
// `VisitSortedItems` requires the `Sealed` supertrait.
impl sealed::Sealed for TestVisitable {}
impl VisitSortedItems for TestVisitable {
    type Item<'a>
        = u64
    where
        Self: 'a;
    type Error = std::convert::Infallible;

    /// Feeds the stored items to `f` in order, stopping at the first callback
    /// error and reporting it as `SortedItemsError::Sink`.
    fn visit_items<F, FE>(self, mut f: F) -> Result<(), SortedItemsError<Self::Error, FE>>
    where
        F: for<'a> FnMut(u64) -> Result<(), FE>,
    {
        self.items
            .into_iter()
            .try_for_each(|item| f(item).map_err(SortedItemsError::Sink))
    }
}
/// A single pushed value comes back as the only output item.
#[test]
fn sort_single_item() {
    let mut sorter = u64_sorter(100);
    sorter.push(42).expect("push");

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert_eq!(results, vec![42]);
}
/// `SortedItems::try_for_each` hands every item to the visitor, in source order.
#[test]
fn sorted_items_visits_items_in_order() {
    let source = TestVisitable {
        items: vec![1, 2, 3],
    };
    let mut visited = Vec::new();

    let outcome = Sorted::new(source).items().try_for_each(|item| {
        visited.push(item);
        Ok::<(), std::convert::Infallible>(())
    });

    outcome.expect("visit items");
    assert_eq!(visited, vec![1, 2, 3]);
}
/// A visitor error aborts the traversal and is reported as `SortedItemsError::Sink`.
#[test]
fn sorted_items_stops_on_visitor_error() {
    #[derive(Debug, PartialEq, Eq)]
    struct SinkError;

    impl std::fmt::Display for SinkError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("sink error")
        }
    }

    impl std::error::Error for SinkError {}

    let source = TestVisitable {
        items: vec![1, 2, 3],
    };
    let mut visited = Vec::new();

    // The visitor fails on the second item, so the third must never be seen.
    let outcome = Sorted::new(source).items().try_for_each(|item| {
        visited.push(item);
        if item == 2 { Err(SinkError) } else { Ok(()) }
    });
    let err = outcome.expect_err("sink error should stop traversal");

    assert!(matches!(err, SortedItemsError::Sink(SinkError)));
    assert_eq!(visited, vec![1, 2]);
}
/// Input that is already ascending comes out unchanged.
#[test]
fn sort_already_sorted() {
    let input = [1, 2, 3, 4, 5];
    let mut sorter = u64_sorter(100);
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert_eq!(results, vec![1, 2, 3, 4, 5]);
}
/// Descending input comes out ascending.
#[test]
fn sort_reverse_input() {
    let input = [5, 4, 3, 2, 1];
    let mut sorter = u64_sorter(100);
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert_eq!(results, vec![1, 2, 3, 4, 5]);
}
/// Ten items with a 3-item buffer limit force several flushes; the merged
/// output must still be fully sorted.
#[test]
fn sort_with_spilling() {
    let input = [9, 7, 5, 3, 1, 2, 4, 6, 8, 10];
    let mut sorter = u64_sorter(3);
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert_eq!(results, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
}
/// Finishing a sorter that never received an item yields no output.
#[test]
fn sort_empty_input() {
    let sorter = u64_sorter(100);

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert!(results.is_empty());
}
/// With the default dedup stage, equal values are all kept.
#[test]
fn sort_preserves_duplicates() {
    let input = [3, 1, 2, 1, 3, 2];
    let mut sorter = u64_sorter(3);
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert_eq!(results, vec![1, 1, 2, 2, 3, 3]);
}
/// Wrapping the comparator in `Reverse` produces descending output.
#[test]
fn sort_with_reverse_comparator() {
    let identity = (|v: &u64| *v) as fn(&u64) -> u64;
    let mut sorter = Builder::new()
        .key(Owned(identity))
        .codec(U64Codec)
        .compare(Reverse(Natural))
        .max_buffer_items::<u64>(3)
        .build();
    let input = [1, 5, 3, 2, 4];
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert_eq!(results, vec![5, 4, 3, 2, 1]);
}
/// With an `AdjacentDedup` on equality, each distinct value appears once.
#[test]
fn sort_with_dedup() {
    let identity = (|v: &u64| *v) as fn(&u64) -> u64;
    let mut sorter = Builder::new()
        .key(Owned(identity))
        .codec(U64Codec)
        .dedup(AdjacentDedup::new(|a: &u64, b: &u64| a == b))
        .max_buffer_items::<u64>(3)
        .build();
    let input = [3, 1, 2, 1, 3, 2];
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted
        .map(|item: Result<u64, _>| item.expect("read"))
        .collect();

    assert_eq!(results, vec![1, 2, 3]);
}
/// A 24-byte fixed-size budget charges 8 bytes per `u64`, so the buffer is
/// flushed every three items; the merged output must still be sorted.
#[test]
fn sort_with_byte_budget() {
    let identity = (|v: &u64| *v) as fn(&u64) -> u64;
    let mut sorter = Builder::new()
        .key(Owned(identity))
        .codec(U64Codec)
        .fixed_size_budget::<u64>(24)
        .build();
    let input = [9, 7, 5, 3, 1, 2, 4, 6, 8, 10];
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted
        .map(|item: Result<u64, _>| item.expect("read"))
        .collect();

    assert_eq!(results, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
}
/// With a buffer limit far above the input size, no `push` triggers a flush;
/// the output must still be sorted.
#[test]
fn sort_all_in_memory_no_spill() {
    let input = [5, 3, 1, 4, 2];
    let mut sorter = u64_sorter(1000);
    for value in input {
        sorter.push(value).expect("push");
    }

    let sorted = sorter.finish().expect("finish");
    let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

    assert_eq!(results, vec![1, 2, 3, 4, 5]);
}
mod proptests {
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // For arbitrary input and buffer limits, the output is non-decreasing.
        #[test]
        fn output_is_always_sorted(
            data in proptest::collection::vec(0u64..10_000, 0..500),
            max_items in 3usize..50,
        ) {
            let mut sorter = u64_sorter(max_items);
            for &value in &data {
                sorter.push(value).expect("push");
            }

            let sorted = sorter.finish().expect("finish");
            let results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();

            prop_assert!(
                results.windows(2).all(|pair| pair[0] <= pair[1]),
                "output must be sorted"
            );
        }

        // For arbitrary input and buffer limits, the output holds the same
        // values as the input (compared after sorting both sides).
        #[test]
        fn output_preserves_all_items(
            data in proptest::collection::vec(0u64..1_000, 0..200),
            max_items in 3usize..50,
        ) {
            let mut sorter = u64_sorter(max_items);
            for &value in &data {
                sorter.push(value).expect("push");
            }

            let sorted = sorter.finish().expect("finish");
            let mut results: Vec<u64> = sorted.map(|item| item.expect("read")).collect();
            results.sort_unstable();

            let mut expected = data;
            expected.sort_unstable();

            prop_assert_eq!(results, expected);
        }
    }
}
}