use std::alloc::{alloc, dealloc, Layout};
use std::borrow::Cow;
use std::convert::Infallible;
use std::fmt::Debug;
#[cfg(feature = "tempfile")]
use std::fs::File;
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
use std::mem::{align_of, size_of};
use std::num::NonZeroUsize;
use std::{cmp, io, ops, slice};
use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use crate::count_write::CountWrite;
/// Size in bytes of the in-memory entries buffer a sorter starts with when
/// it is allowed to reallocate (128 KiB).
const INITIAL_SORTER_VEC_SIZE: usize = 128 * 1024;
/// Dump threshold in bytes used when the caller does not set one (1 GiB).
const DEFAULT_SORTER_MEMORY: usize = 1024 * 1024 * 1024;
/// Lowest dump threshold in bytes the builder accepts; smaller requests are
/// raised to this value (10 MiB).
const MIN_SORTER_MEMORY: usize = 10 * 1024 * 1024;
/// Number of chunks tolerated before they are merged together, by default.
const DEFAULT_NB_CHUNKS: usize = 25;
/// Lowest chunk limit the builder accepts.
const MIN_NB_CHUNKS: usize = 1;
use crate::{
CompressionType, Error, MergeFunction, Merger, MergerIter, Reader, ReaderCursor, Writer,
WriterBuilder,
};
/// Selects how the in-memory entries are ordered before being written to a chunk.
#[derive(Debug, Clone, Copy)]
pub enum SortAlgorithm {
    /// Use the slice `sort_by_key` family: entries with equal keys keep their
    /// insertion order.
    Stable,
    /// Use the slice `sort_unstable_by_key` family: entries with equal keys
    /// may be reordered.
    Unstable,
}
/// Collects the settings used to create a [`Sorter`].
///
/// `MF` is the merge function type and `CC` the chunk creator type.
#[derive(Debug, Clone, Copy)]
pub struct SorterBuilder<MF, CC> {
    /// Buffer size, in bytes, from which the sorter stops growing its
    /// in-memory buffer and writes the entries to a chunk instead.
    dump_threshold: usize,
    /// Whether the in-memory buffer starts small and grows on demand, or is
    /// allocated at `dump_threshold` bytes up front.
    allow_realloc: bool,
    /// Number of chunks at which all chunks are merged into a single one.
    max_nb_chunks: usize,
    /// Compression type forwarded to the chunk writer, when set.
    chunk_compression_type: Option<CompressionType>,
    /// Compression level forwarded to the chunk writer, when set.
    chunk_compression_level: Option<u32>,
    /// Index key interval forwarded to the chunk writer, when set.
    index_key_interval: Option<NonZeroUsize>,
    /// Block size forwarded to the chunk writer, when set.
    block_size: Option<usize>,
    /// Number of index levels forwarded to the chunk writer, when set.
    index_levels: Option<u8>,
    /// Factory producing the storage each chunk is written to.
    chunk_creator: CC,
    /// Algorithm used to sort entries before writing a chunk.
    sort_algorithm: SortAlgorithm,
    /// Whether the parallel sorting entry point is used when writing a chunk.
    sort_in_parallel: bool,
    /// Function merging the values that share a key.
    merge: MF,
}
impl<MF> SorterBuilder<MF, DefaultChunkCreator> {
    /// Starts a builder that uses `merge` to combine values sharing a key.
    ///
    /// Every other setting takes its default: the default dump threshold and
    /// chunk limit, reallocation enabled, no writer option overridden, the
    /// default chunk creator, and a stable, non-parallel sort.
    pub fn new(merge: MF) -> Self {
        Self {
            // Memory and chunk limits.
            dump_threshold: DEFAULT_SORTER_MEMORY,
            allow_realloc: true,
            max_nb_chunks: DEFAULT_NB_CHUNKS,
            // Writer options are left to the writer's own defaults.
            chunk_compression_type: None,
            chunk_compression_level: None,
            index_key_interval: None,
            block_size: None,
            index_levels: None,
            // Storage and ordering.
            chunk_creator: Default::default(),
            sort_algorithm: SortAlgorithm::Stable,
            sort_in_parallel: false,
            merge,
        }
    }
}
impl<MF, CC> SorterBuilder<MF, CC> {
    /// Sets the buffer size, in bytes, from which entries are dumped to a chunk.
    ///
    /// Values below `MIN_SORTER_MEMORY` are raised to that minimum.
    pub fn dump_threshold(&mut self, memory: usize) -> &mut Self {
        self.dump_threshold = memory.max(MIN_SORTER_MEMORY);
        self
    }

    /// Chooses whether the in-memory buffer may start small and grow.
    ///
    /// When disabled, the buffer is allocated at the dump threshold when the
    /// sorter is built.
    pub fn allow_realloc(&mut self, allow: bool) -> &mut Self {
        self.allow_realloc = allow;
        self
    }

    /// Sets the number of chunks at which they are merged into a single one.
    ///
    /// Values below `MIN_NB_CHUNKS` are raised to that minimum.
    pub fn max_nb_chunks(&mut self, nb_chunks: usize) -> &mut Self {
        self.max_nb_chunks = nb_chunks.max(MIN_NB_CHUNKS);
        self
    }

    /// Sets the compression type handed to the chunk writer.
    pub fn chunk_compression_type(&mut self, compression: CompressionType) -> &mut Self {
        self.chunk_compression_type = Some(compression);
        self
    }

    /// Sets the compression level handed to the chunk writer.
    pub fn chunk_compression_level(&mut self, level: u32) -> &mut Self {
        self.chunk_compression_level = Some(level);
        self
    }

    /// Sets the index key interval handed to the chunk writer.
    pub fn index_key_interval(&mut self, interval: NonZeroUsize) -> &mut Self {
        self.index_key_interval = Some(interval);
        self
    }

    /// Sets the block size handed to the chunk writer.
    pub fn block_size(&mut self, size: usize) -> &mut Self {
        self.block_size = Some(size);
        self
    }

    /// Sets the number of index levels handed to the chunk writer.
    pub fn index_levels(&mut self, levels: u8) -> &mut Self {
        self.index_levels = Some(levels);
        self
    }

    /// Selects the algorithm used to sort entries before writing a chunk.
    pub fn sort_algorithm(&mut self, algo: SortAlgorithm) -> &mut Self {
        self.sort_algorithm = algo;
        self
    }

    /// Chooses whether entries are sorted through the parallel entry point.
    #[cfg(feature = "rayon")]
    pub fn sort_in_parallel(&mut self, value: bool) -> &mut Self {
        self.sort_in_parallel = value;
        self
    }

    /// Replaces the chunk creator, keeping every other setting as it is.
    ///
    /// This consumes the builder because the chunk creator type changes.
    pub fn chunk_creator<CC2>(self, creation: CC2) -> SorterBuilder<MF, CC2> {
        // Take the builder apart; the previous chunk creator is left behind.
        let SorterBuilder {
            dump_threshold,
            allow_realloc,
            max_nb_chunks,
            chunk_compression_type,
            chunk_compression_level,
            index_key_interval,
            block_size,
            index_levels,
            chunk_creator: _,
            sort_algorithm,
            sort_in_parallel,
            merge,
        } = self;

        SorterBuilder {
            dump_threshold,
            allow_realloc,
            max_nb_chunks,
            chunk_compression_type,
            chunk_compression_level,
            index_key_interval,
            block_size,
            index_levels,
            chunk_creator: creation,
            sort_algorithm,
            sort_in_parallel,
            merge,
        }
    }
}
impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
pub fn build(self) -> Sorter<MF, CC> {
let capacity =
if self.allow_realloc { INITIAL_SORTER_VEC_SIZE } else { self.dump_threshold };
Sorter {
chunks: Vec::new(),
entries: Entries::with_capacity(capacity),
chunks_total_size: 0,
allow_realloc: self.allow_realloc,
dump_threshold: self.dump_threshold,
max_nb_chunks: self.max_nb_chunks,
chunk_compression_type: self.chunk_compression_type,
chunk_compression_level: self.chunk_compression_level,
index_key_interval: self.index_key_interval,
block_size: self.block_size,
index_levels: self.index_levels,
chunk_creator: self.chunk_creator,
sort_algorithm: self.sort_algorithm,
sort_in_parallel: self.sort_in_parallel,
merge_function: self.merge,
}
}
}
/// In-memory store of key/data entries packed into a single byte buffer.
///
/// The `EntryBound` records are written from the front of the buffer while
/// the key and data bytes are written from the back, so both regions grow
/// towards each other.
struct Entries {
    /// Backing storage shared by the bounds (front) and the entry bytes (back).
    buffer: EntryBoundAlignedBuffer,
    /// Total number of key and data bytes stored at the back of the buffer.
    entries_len: usize,
    /// Number of `EntryBound` records stored at the front of the buffer.
    bounds_count: usize,
}
impl Entries {
    /// Creates an empty store backed by a buffer of at least `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self { buffer: EntryBoundAlignedBuffer::new(capacity), entries_len: 0, bounds_count: 0 }
    }

    /// Forgets every stored entry; the buffer allocation is kept for reuse.
    pub fn clear(&mut self) {
        self.entries_len = 0;
        self.bounds_count = 0;
    }

    /// Stores `key` and `data`, growing the buffer as much as needed.
    ///
    /// # Panics
    ///
    /// Panics when `key` or `data` is longer than `u32::MAX` bytes, as the
    /// lengths are recorded as `u32` in the entry bound.
    pub fn insert(&mut self, key: &[u8], data: &[u8]) {
        assert!(key.len() <= u32::MAX as usize);
        assert!(data.len() <= u32::MAX as usize);

        // Each reallocation doubles the buffer; repeat until the entry fits.
        while !self.fits(key, data) {
            self.reallocate_buffer();
        }

        // The entry bytes are packed backwards from the end of the buffer:
        // the key first, immediately followed by the data.
        self.entries_len += key.len() + data.len();
        let entries_start = self.buffer.len() - self.entries_len;
        self.buffer[entries_start..][..key.len()].copy_from_slice(key);
        self.buffer[entries_start + key.len()..][..data.len()].copy_from_slice(data);

        // `key_start` is the distance of the key from the end of the buffer,
        // which stays valid when the entry bytes are moved to the end of a
        // bigger buffer.
        let bound = EntryBound {
            key_start: self.entries_len,
            key_length: key.len() as u32,
            data_length: data.len() as u32,
        };

        // The bound goes in the next free slot at the front of the buffer.
        let bounds_end = (self.bounds_count + 1) * size_of::<EntryBound>();
        let bounds = cast_slice_mut::<_, EntryBound>(&mut self.buffer[..bounds_end]);
        bounds[self.bounds_count] = bound;
        self.bounds_count += 1;
    }

    /// Returns `true` when `key` and `data`, plus their bound, fit in the
    /// free space of the current buffer.
    pub fn fits(&self, key: &[u8], data: &[u8]) -> bool {
        // The buffer is allocated with the alignment of `EntryBound` and a
        // length that is a multiple of its size, so the number of bound slots
        // is a plain division. The length of the middle slice returned by
        // `align_to` must not be relied upon for correctness.
        let bounds_capacity = self.buffer.len() / size_of::<EntryBound>();
        let remaining_bounds = bounds_capacity - self.bounds_count;
        self.remaining() >= Self::entry_size(key, data) && remaining_bounds >= 1
    }

    /// Returns the size of the whole buffer, in bytes.
    pub fn memory_usage(&self) -> usize {
        self.buffer.len()
    }

    /// Sorts the bounds by the key bytes they refer to.
    ///
    /// Only the bounds are reordered, the entry bytes are left in place.
    pub fn sort_by_key(&mut self, algorithm: SortAlgorithm) {
        let bounds_end = self.bounds_count * size_of::<EntryBound>();
        let (bounds, tail) = self.buffer.split_at_mut(bounds_end);
        let bounds = cast_slice_mut::<_, EntryBound>(bounds);
        let sort = match algorithm {
            SortAlgorithm::Stable => <[EntryBound]>::sort_by_key,
            SortAlgorithm::Unstable => <[EntryBound]>::sort_unstable_by_key,
        };
        // `key_start` counts from the end of the buffer, which is also the end of `tail`.
        sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
    }

    /// Same as [`Entries::sort_by_key`] but through rayon's parallel slice sorts.
    #[cfg(feature = "rayon")]
    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
        use rayon::slice::ParallelSliceMut;
        let bounds_end = self.bounds_count * size_of::<EntryBound>();
        let (bounds, tail) = self.buffer.split_at_mut(bounds_end);
        let bounds = cast_slice_mut::<_, EntryBound>(bounds);
        let sort = match algorithm {
            SortAlgorithm::Stable => <[EntryBound]>::par_sort_by_key,
            SortAlgorithm::Unstable => <[EntryBound]>::par_sort_unstable_by_key,
        };
        sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
    }

    /// Without the `rayon` feature this simply calls [`Entries::sort_by_key`].
    #[cfg(not(feature = "rayon"))]
    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
        self.sort_by_key(algorithm);
    }

    /// Iterates over the `(key, data)` pairs in the current order of the bounds.
    pub fn iter(&self) -> impl Iterator<Item = (&[u8], &[u8])> + '_ {
        let bounds_end = self.bounds_count * size_of::<EntryBound>();
        let (bounds, tail) = self.buffer.split_at(bounds_end);
        let bounds = cast_slice::<_, EntryBound>(bounds);
        bounds.iter().map(move |b| {
            let entries_start = tail.len() - b.key_start;
            let key = &tail[entries_start..][..b.key_length as usize];
            let data = &tail[entries_start + b.key_length as usize..][..b.data_length as usize];
            (key, data)
        })
    }

    /// Returns the number of bytes taken by the stored bounds and entry bytes.
    pub fn estimated_entries_memory_usage(&self) -> usize {
        self.memory_usage() - self.remaining()
    }

    /// Returns the number of free bytes between the bounds and the entry bytes.
    fn remaining(&self) -> usize {
        self.buffer.len() - self.entries_len - self.bounds_count * size_of::<EntryBound>()
    }

    /// Returns the number of buffer bytes an entry needs, its bound included.
    fn entry_size(key: &[u8], data: &[u8]) -> usize {
        size_of::<EntryBound>() + key.len() + data.len()
    }

    /// Moves the content into a new buffer twice the size of the current one.
    ///
    /// The bounds are copied to the front and the entry bytes to the back of
    /// the new buffer, the added free space ends up between them.
    fn reallocate_buffer(&mut self) {
        let bounds_end = self.bounds_count * size_of::<EntryBound>();
        let bounds_bytes = &self.buffer[..bounds_end];

        let entries_start = self.buffer.len() - self.entries_len;
        let entries_bytes = &self.buffer[entries_start..];

        let mut new_buffer = EntryBoundAlignedBuffer::new(self.buffer.len() * 2);
        new_buffer[..bounds_end].copy_from_slice(bounds_bytes);
        let new_entries_start = new_buffer.len() - self.entries_len;
        new_buffer[new_entries_start..].copy_from_slice(entries_bytes);

        self.buffer = new_buffer;
    }
}
/// Location of one entry inside the `Entries` buffer.
///
/// The records are stored in the buffer itself and read back through
/// `bytemuck` casts, hence the `Pod` derive and the `repr(C)` layout.
#[derive(Default, Copy, Clone, Pod, Zeroable)]
#[repr(C)]
struct EntryBound {
    /// Distance, in bytes, from the start of the key to the end of the buffer.
    key_start: usize,
    /// Length of the key, in bytes.
    key_length: u32,
    /// Length of the data, in bytes; the data directly follows the key.
    data_length: u32,
}
struct EntryBoundAlignedBuffer(&'static mut [u8]);
impl EntryBoundAlignedBuffer {
fn new(size: usize) -> EntryBoundAlignedBuffer {
let entry_bound_size = size_of::<EntryBound>();
let size = (size + entry_bound_size - 1) / entry_bound_size * entry_bound_size;
let layout = Layout::from_size_align(size, align_of::<EntryBound>()).unwrap();
let ptr = unsafe { alloc(layout) };
assert!(
!ptr.is_null(),
"the allocator is unable to allocate that much memory ({} bytes requested)",
size
);
let slice = unsafe { slice::from_raw_parts_mut(ptr, size) };
EntryBoundAlignedBuffer(slice)
}
}
/// Gives read access to the buffer as a plain byte slice.
impl ops::Deref for EntryBoundAlignedBuffer {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        &*self.0
    }
}
/// Gives write access to the buffer as a plain byte slice.
impl ops::DerefMut for EntryBoundAlignedBuffer {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}
/// Gives the memory back to the global allocator.
impl Drop for EntryBoundAlignedBuffer {
    fn drop(&mut self) {
        // Rebuild the layout from the slice length and the alignment of
        // `EntryBound`, the same two values the allocation was made with.
        let byte_len = self.0.len();
        let layout = Layout::from_size_align(byte_len, align_of::<EntryBound>()).unwrap();
        let ptr = self.0.as_mut_ptr();
        // SAFETY: the slice was obtained from `alloc` with this size and
        // alignment when the buffer was created, and is not used afterwards.
        unsafe { dealloc(ptr, layout) }
    }
}
/// Accumulates key/value entries in memory and spills them, sorted and
/// merged, into chunks once the memory limit is reached.
///
/// `MF` is the merge function type and `CC` the chunk creator type.
pub struct Sorter<MF, CC: ChunkCreator = DefaultChunkCreator> {
    /// Chunks written so far.
    chunks: Vec<CC::Chunk>,
    /// Entries not yet written to a chunk.
    entries: Entries,
    /// Bytes written to the chunks: increased on every dump and replaced by
    /// the size of the merged chunk after a merge.
    chunks_total_size: u64,
    /// Whether the entries buffer may grow while below the dump threshold.
    allow_realloc: bool,
    /// Buffer size, in bytes, from which the entries are dumped to a chunk.
    dump_threshold: usize,
    /// Number of chunks at which all chunks are merged into a single one.
    max_nb_chunks: usize,
    /// Compression type forwarded to the chunk writer, when set.
    chunk_compression_type: Option<CompressionType>,
    /// Compression level forwarded to the chunk writer, when set.
    chunk_compression_level: Option<u32>,
    /// Index key interval forwarded to the chunk writer, when set.
    index_key_interval: Option<NonZeroUsize>,
    /// Block size forwarded to the chunk writer, when set.
    block_size: Option<usize>,
    /// Number of index levels forwarded to the chunk writer, when set.
    index_levels: Option<u8>,
    /// Factory producing the storage each chunk is written to.
    chunk_creator: CC,
    /// Algorithm used to sort the entries before writing a chunk.
    sort_algorithm: SortAlgorithm,
    /// Whether the parallel sorting entry point is used when writing a chunk.
    sort_in_parallel: bool,
    /// Function merging the values that share a key.
    merge_function: MF,
}
impl<MF> Sorter<MF, DefaultChunkCreator> {
    /// Returns a [`SorterBuilder`] with default settings using `merge` to
    /// combine the values sharing a key.
    pub fn builder(merge: MF) -> SorterBuilder<MF, DefaultChunkCreator> {
        SorterBuilder::new(merge)
    }

    /// Creates a sorter with default settings using `merge` to combine the
    /// values sharing a key.
    pub fn new(merge: MF) -> Sorter<MF, DefaultChunkCreator> {
        Self::builder(merge).build()
    }

    /// Returns the bytes written to chunks so far plus the bytes taken by
    /// the entries still held in memory.
    pub fn estimated_dumped_memory_usage(&self) -> u64 {
        let in_memory_bytes = self.entries.estimated_entries_memory_usage() as u64;
        in_memory_bytes + self.chunks_total_size
    }
}
impl<MF, CC> Sorter<MF, CC>
where
MF: MergeFunction,
CC: ChunkCreator,
{
pub fn insert<K, V>(&mut self, key: K, val: V) -> crate::Result<(), MF::Error>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
let key = key.as_ref();
let val = val.as_ref();
#[allow(clippy::branches_sharing_code)]
if self.entries.fits(key, val) || (!self.threshold_exceeded() && self.allow_realloc) {
self.entries.insert(key, val);
} else {
self.chunks_total_size += self.write_chunk()?;
self.entries.insert(key, val);
if self.chunks.len() >= self.max_nb_chunks {
self.chunks_total_size = self.merge_chunks()?;
}
}
Ok(())
}
fn threshold_exceeded(&self) -> bool {
self.entries.memory_usage() >= self.dump_threshold
}
fn write_chunk(&mut self) -> crate::Result<u64, MF::Error> {
let count_write_chunk = self
.chunk_creator
.create()
.map_err(Into::into)
.map_err(Error::convert_merge_error)
.map(CountWrite::new)?;
let mut writer_builder = WriterBuilder::new();
if let Some(compression_type) = self.chunk_compression_type {
writer_builder.compression_type(compression_type);
}
if let Some(compression_level) = self.chunk_compression_level {
writer_builder.compression_level(compression_level);
}
if let Some(index_key_interval) = self.index_key_interval {
writer_builder.index_key_interval(index_key_interval);
}
if let Some(block_size) = self.block_size {
writer_builder.block_size(block_size);
}
if let Some(index_levels) = self.index_levels {
writer_builder.index_levels(index_levels);
}
let mut writer = writer_builder.build(count_write_chunk);
if self.sort_in_parallel {
self.entries.par_sort_by_key(self.sort_algorithm);
} else {
self.entries.sort_by_key(self.sort_algorithm);
}
let mut current = None;
for (key, value) in self.entries.iter() {
match current.as_mut() {
None => current = Some((key, vec![Cow::Borrowed(value)])),
Some((current_key, vals)) => {
if current_key != &key {
let merged_val =
self.merge_function.merge(current_key, vals).map_err(Error::Merge)?;
writer.insert(¤t_key, &merged_val)?;
vals.clear();
*current_key = key;
}
vals.push(Cow::Borrowed(value));
}
}
}
if let Some((key, vals)) = current.take() {
let merged_val = self.merge_function.merge(key, &vals).map_err(Error::Merge)?;
writer.insert(key, &merged_val)?;
}
let mut count_write_chunk = writer.into_inner()?;
count_write_chunk.flush()?;
let written_bytes = count_write_chunk.count();
let chunk = count_write_chunk.into_inner()?;
self.chunks.push(chunk);
self.entries.clear();
Ok(written_bytes)
}
fn merge_chunks(&mut self) -> crate::Result<u64, MF::Error> {
let count_write_chunk = self
.chunk_creator
.create()
.map_err(Into::into)
.map_err(Error::convert_merge_error)
.map(CountWrite::new)?;
let mut writer_builder = WriterBuilder::new();
if let Some(compression_type) = self.chunk_compression_type {
writer_builder.compression_type(compression_type);
}
if let Some(compression_level) = self.chunk_compression_level {
writer_builder.compression_level(compression_level);
}
if let Some(index_key_interval) = self.index_key_interval {
writer_builder.index_key_interval(index_key_interval);
}
if let Some(block_size) = self.block_size {
writer_builder.block_size(block_size);
}
if let Some(index_levels) = self.index_levels {
writer_builder.index_levels(index_levels);
}
let mut writer = writer_builder.build(count_write_chunk);
let sources: crate::Result<Vec<_>, MF::Error> = self
.chunks
.drain(..)
.map(|mut chunk| {
chunk.seek(SeekFrom::Start(0))?;
Reader::new(chunk).and_then(Reader::into_cursor).map_err(Error::convert_merge_error)
})
.collect();
let mut builder = Merger::builder(&self.merge_function);
builder.extend(sources?);
let merger = builder.build();
let mut iter = merger.into_stream_merger_iter().map_err(Error::convert_merge_error)?;
while let Some((key, val)) = iter.next()? {
writer.insert(key, val)?;
}
let mut count_write_chunk = writer.into_inner()?;
count_write_chunk.flush()?;
let written_bytes = count_write_chunk.count();
let chunk = count_write_chunk.into_inner()?;
self.chunks.push(chunk);
Ok(written_bytes)
}
pub fn write_into_stream_writer<W: io::Write>(
self,
writer: &mut Writer<W>,
) -> crate::Result<(), MF::Error> {
let mut iter = self.into_stream_merger_iter()?;
while let Some((key, val)) = iter.next()? {
writer.insert(key, val)?;
}
Ok(())
}
pub fn into_stream_merger_iter(self) -> crate::Result<MergerIter<CC::Chunk, MF>, MF::Error> {
let (sources, merge) = self.extract_reader_cursors_and_merger()?;
let mut builder = Merger::builder(merge);
builder.extend(sources);
builder.build().into_stream_merger_iter().map_err(Error::convert_merge_error)
}
pub fn into_reader_cursors(self) -> crate::Result<Vec<ReaderCursor<CC::Chunk>>, MF::Error> {
self.extract_reader_cursors_and_merger().map(|(readers, _)| readers)
}
#[allow(clippy::type_complexity)] fn extract_reader_cursors_and_merger(
mut self,
) -> crate::Result<(Vec<ReaderCursor<CC::Chunk>>, MF), MF::Error> {
self.chunks_total_size = self.write_chunk()?;
let Sorter { chunks, merge_function: merge, .. } = self;
let result: Result<Vec<_>, _> = chunks
.into_iter()
.map(|mut chunk| {
chunk.seek(SeekFrom::Start(0))?;
Reader::new(chunk).and_then(Reader::into_cursor).map_err(Error::convert_merge_error)
})
.collect();
result.map(|readers| (readers, merge))
}
}
/// Prints the sorter settings and counters; the chunk creator and the merge
/// function are shown as placeholders since they are not required to be `Debug`.
impl<MF, CC: ChunkCreator> Debug for Sorter<MF, CC> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Sorter");

        // Counters.
        out.field("chunks_count", &self.chunks.len());
        out.field("remaining_entries", &self.entries.remaining());
        out.field("chunks_total_size", &self.chunks_total_size);

        // Memory and chunk limits.
        out.field("allow_realloc", &self.allow_realloc);
        out.field("dump_threshold", &self.dump_threshold);
        out.field("max_nb_chunks", &self.max_nb_chunks);

        // Writer options.
        out.field("chunk_compression_type", &self.chunk_compression_type);
        out.field("chunk_compression_level", &self.chunk_compression_level);
        out.field("index_key_interval", &self.index_key_interval);
        out.field("block_size", &self.block_size);
        out.field("index_levels", &self.index_levels);

        // Storage and ordering.
        out.field("chunk_creator", &"[chunck creator]");
        out.field("sort_algorithm", &self.sort_algorithm);
        out.field("sort_in_parallel", &self.sort_in_parallel);
        out.field("merge", &"[merge function]");

        out.finish()
    }
}
/// A factory for the storage the sorter writes its chunks to.
pub trait ChunkCreator {
    /// The storage type: it is written to, rewound, then read back.
    type Chunk: Write + Seek + Read;
    /// The error returned when a chunk cannot be created.
    type Error: Into<Error>;

    /// Creates a new chunk.
    fn create(&self) -> Result<Self::Chunk, Self::Error>;
}
/// The chunk creator used by default: `TempFileChunk` when the `tempfile`
/// feature is enabled.
#[cfg(feature = "tempfile")]
pub type DefaultChunkCreator = TempFileChunk;
/// The chunk creator used by default: `CursorVec` when the `tempfile`
/// feature is disabled.
#[cfg(not(feature = "tempfile"))]
pub type DefaultChunkCreator = CursorVec;
/// Lets a `dyn Fn() -> Result<C, E>` trait object act as a chunk creator:
/// creating a chunk calls the function.
impl<C: Write + Seek + Read, E: Into<Error>> ChunkCreator for dyn Fn() -> Result<C, E> {
    type Chunk = C;
    type Error = E;

    fn create(&self) -> Result<C, E> {
        let make_chunk = self;
        make_chunk()
    }
}
/// A [`ChunkCreator`] storing each chunk in a file obtained from
/// `tempfile::tempfile`.
#[cfg(feature = "tempfile")]
#[derive(Debug, Default, Copy, Clone)]
pub struct TempFileChunk;

#[cfg(feature = "tempfile")]
impl ChunkCreator for TempFileChunk {
    type Chunk = File;
    type Error = io::Error;

    /// Asks `tempfile` for a new file, forwarding its I/O error on failure.
    fn create(&self) -> io::Result<File> {
        tempfile::tempfile()
    }
}
/// A [`ChunkCreator`] storing each chunk in memory, in a `Cursor<Vec<u8>>`.
#[derive(Debug, Default, Copy, Clone)]
pub struct CursorVec;

impl ChunkCreator for CursorVec {
    type Chunk = Cursor<Vec<u8>>;
    type Error = Infallible;

    /// Returns a cursor over a new, empty vector; this never fails.
    fn create(&self) -> Result<Cursor<Vec<u8>>, Infallible> {
        let backing: Vec<u8> = Vec::new();
        Ok(Cursor::new(backing))
    }
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;
    use std::io::Cursor;
    use std::iter::repeat;
    use super::*;
    /// Merge function concatenating the values of a key, in the order given.
    #[derive(Copy, Clone)]
    struct ConcatMerger;
    impl MergeFunction for ConcatMerger {
        type Error = Infallible;
        fn merge<'a>(
            &self,
            _key: &[u8],
            values: &[Cow<'a, [u8]>],
        ) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
            Ok(values.iter().flat_map(AsRef::as_ref).cloned().collect())
        }
    }
    /// A few entries, one key inserted twice: the values of the duplicated
    /// key must come out concatenated.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn simple_cursorvec() {
        let mut sorter = SorterBuilder::new(ConcatMerger)
            .chunk_compression_type(CompressionType::Snappy)
            .chunk_creator(CursorVec)
            .build();
        sorter.insert(b"hello", "kiki").unwrap();
        sorter.insert(b"abstract", "lol").unwrap();
        sorter.insert(b"allo", "lol").unwrap();
        sorter.insert(b"abstract", "lol").unwrap();
        let mut bytes = WriterBuilder::new().memory();
        sorter.write_into_stream_writer(&mut bytes).unwrap();
        let bytes = bytes.into_inner().unwrap();
        let reader = Reader::new(Cursor::new(bytes.as_slice())).unwrap();
        let mut cursor = reader.into_cursor().unwrap();
        while let Some((key, val)) = cursor.move_on_next().unwrap() {
            match key {
                b"hello" => assert_eq!(val, b"kiki"),
                b"abstract" => assert_eq!(val, b"lollol"),
                b"allo" => assert_eq!(val, b"lol"),
                bytes => panic!("{:?}", bytes),
            }
        }
    }
    /// The same key inserted 200 times with reallocation disabled: a single
    /// entry holding the 200 concatenated values must come out.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn hard_cursorvec() {
        // The requested threshold is below `MIN_SORTER_MEMORY`, so the
        // builder raises it to that minimum.
        let mut sorter = SorterBuilder::new(ConcatMerger)
            .dump_threshold(1024) .allow_realloc(false)
            .chunk_compression_type(CompressionType::Snappy)
            .chunk_creator(CursorVec)
            .build();
        for _ in 0..200 {
            sorter.insert(b"hello", "kiki").unwrap();
        }
        let mut bytes = WriterBuilder::new().memory();
        sorter.write_into_stream_writer(&mut bytes).unwrap();
        let bytes = bytes.into_inner().unwrap();
        let reader = Reader::new(Cursor::new(bytes.as_slice())).unwrap();
        let mut cursor = reader.into_cursor().unwrap();
        let (key, val) = cursor.move_on_next().unwrap().unwrap();
        assert_eq!(key, b"hello");
        assert!(val.iter().eq(repeat(b"kiki").take(200).flatten()));
        assert!(cursor.move_on_next().unwrap().is_none());
    }
    /// Increasing one-byte values inserted under randomly chosen keys: with
    /// the default (stable) sort, the merged value of every key must list its
    /// bytes in insertion order, i.e. in non-decreasing order.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn correct_key_ordering() {
        use std::borrow::Cow;
        use rand::prelude::{SeedableRng, SliceRandom};
        use rand::rngs::StdRng;
        /// Merge function concatenating the values of a key, in the order given.
        struct ConcatBytesMerger;
        impl MergeFunction for ConcatBytesMerger {
            type Error = Infallible;
            fn merge<'a>(
                &self,
                _key: &[u8],
                values: &[Cow<'a, [u8]>],
            ) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
                let mut output = Vec::new();
                for value in values {
                    output.extend_from_slice(value);
                }
                Ok(Cow::from(output))
            }
        }
        let mut sorter = SorterBuilder::new(ConcatBytesMerger).chunk_creator(CursorVec).build();
        // Fixed seed so that the key choices are reproducible.
        let mut rng = StdRng::seed_from_u64(42);
        let possible_keys = ["first", "second", "third", "fourth", "fifth", "sixth"];
        for n in 0..=255 {
            let key = possible_keys.choose(&mut rng).unwrap();
            sorter.insert(key, [n]).unwrap();
        }
        let mut iter = sorter.into_stream_merger_iter().unwrap();
        while let Some((_key, value)) = iter.next().unwrap() {
            assert!(value.windows(2).all(|w| w[0] <= w[1]), "{:?}", value);
        }
    }
    /// Requesting 2^48 bytes must fail with the allocation panic message
    /// rather than return an unusable buffer.
    #[test]
    #[should_panic(
        expected = "the allocator is unable to allocate that much memory (281474976710656 bytes requested)"
    )]
    #[cfg_attr(miri, ignore)]
    fn too_big_allocation() {
        EntryBoundAlignedBuffer::new(1 << 48);
    }
}