grenad/
sorter.rs

1use std::alloc::{alloc, dealloc, Layout};
2use std::borrow::Cow;
3use std::convert::Infallible;
4use std::fmt::Debug;
5#[cfg(feature = "tempfile")]
6use std::fs::File;
7use std::io::{Cursor, Read, Seek, SeekFrom, Write};
8use std::mem::{align_of, size_of};
9use std::num::NonZeroUsize;
10use std::{cmp, io, ops, slice};
11
12use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
13
14use crate::count_write::CountWrite;
15
/// Initial size of the in-memory buffer when reallocation is allowed (128 KiB).
const INITIAL_SORTER_VEC_SIZE: usize = 128 * 1024;
/// Default dump threshold of a sorter (1 GiB).
const DEFAULT_SORTER_MEMORY: usize = 1024 * 1024 * 1024;
/// Smallest dump threshold accepted by `SorterBuilder::dump_threshold` (10 MiB).
const MIN_SORTER_MEMORY: usize = 10 * 1024 * 1024;

/// Default number of chunks that triggers merging them into a single one.
const DEFAULT_NB_CHUNKS: usize = 25;
/// Smallest value accepted by `SorterBuilder::max_nb_chunks`.
const MIN_NB_CHUNKS: usize = 1;
22
23use crate::{
24    CompressionType, Error, MergeFunction, Merger, MergerIter, Reader, ReaderCursor, Writer,
25    WriterBuilder,
26};
27
/// The kind of sort algorithm used by the sorter to sort its internal vector.
///
/// The values of equal keys are handed to the merge function in the order the sort
/// leaves them in. With [`SortAlgorithm::Stable`] that is their insertion order
/// (within one in-memory batch). With [`SortAlgorithm::Unstable`] that order is unspecified.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortAlgorithm {
    /// The stable sort algorithm maintains the relative order of values with equal keys,
    /// but it is slower than the unstable algorithm.
    Stable,
    /// The unstable sort algorithm is faster than the stable algorithm, but it
    /// does not keep the relative order of values with equal keys.
    Unstable,
}
38
/// A struct that is used to configure a [`Sorter`] to better fit your needs.
///
/// Create it with [`SorterBuilder::new`] (or `Sorter::builder`), adjust the settings
/// with its setter methods, then call [`SorterBuilder::build`].
#[derive(Debug, Clone, Copy)]
pub struct SorterBuilder<MF, CC> {
    /// Memory usage of the in-memory buffer at which the sorter dumps to a chunk.
    dump_threshold: usize,
    /// Whether the in-memory buffer may grow; when `false` it is allocated at full size.
    allow_realloc: bool,
    /// Number of chunks that triggers merging them all into a single chunk.
    max_nb_chunks: usize,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    chunk_compression_type: Option<CompressionType>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    chunk_compression_level: Option<u32>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    index_key_interval: Option<NonZeroUsize>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    block_size: Option<usize>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    index_levels: Option<u8>,
    /// Creates the chunks (temporary files, in-memory cursors, ...) the sorter writes to.
    chunk_creator: CC,
    /// Algorithm used to sort the in-memory entries before they are written.
    sort_algorithm: SortAlgorithm,
    /// Whether the in-memory entries are sorted in parallel (only effective with `rayon`).
    sort_in_parallel: bool,
    /// Merge function used to combine the values of equal keys.
    merge: MF,
}
55
56impl<MF> SorterBuilder<MF, DefaultChunkCreator> {
57    /// Creates a [`SorterBuilder`] from a merge function, it can be
58    /// used to configure your [`Sorter`] to better fit your needs.
59    pub fn new(merge: MF) -> Self {
60        SorterBuilder {
61            dump_threshold: DEFAULT_SORTER_MEMORY,
62            allow_realloc: true,
63            max_nb_chunks: DEFAULT_NB_CHUNKS,
64            chunk_compression_type: None,
65            chunk_compression_level: None,
66            index_key_interval: None,
67            block_size: None,
68            index_levels: None,
69            chunk_creator: DefaultChunkCreator::default(),
70            sort_algorithm: SortAlgorithm::Stable,
71            sort_in_parallel: false,
72            merge,
73        }
74    }
75}
76
77impl<MF, CC> SorterBuilder<MF, CC> {
78    /// The amount of memory to reach that will trigger a memory dump from in memory to disk.
79    pub fn dump_threshold(&mut self, memory: usize) -> &mut Self {
80        self.dump_threshold = cmp::max(memory, MIN_SORTER_MEMORY);
81        self
82    }
83
84    /// Whether the sorter is allowed or not to reallocate the internal vector.
85    ///
86    /// Note that reallocating involve a more important memory usage and disallowing
87    /// it will make the sorter to **always** consume the dump threshold memory.
88    pub fn allow_realloc(&mut self, allow: bool) -> &mut Self {
89        self.allow_realloc = allow;
90        self
91    }
92
93    /// The maximum number of chunks on disk, if this number of chunks is reached
94    /// they will be merged into a single chunk. Merging can reduce the disk usage.
95    pub fn max_nb_chunks(&mut self, nb_chunks: usize) -> &mut Self {
96        self.max_nb_chunks = cmp::max(nb_chunks, MIN_NB_CHUNKS);
97        self
98    }
99
100    /// Defines the compression type the built [`Sorter`] will use when buffering.
101    pub fn chunk_compression_type(&mut self, compression: CompressionType) -> &mut Self {
102        self.chunk_compression_type = Some(compression);
103        self
104    }
105
106    /// Defines the compression level that the defined compression type will use.
107    pub fn chunk_compression_level(&mut self, level: u32) -> &mut Self {
108        self.chunk_compression_level = Some(level);
109        self
110    }
111
112    /// The interval at which we store the index of a key in the
113    /// index footer, used to seek into a block.
114    pub fn index_key_interval(&mut self, interval: NonZeroUsize) -> &mut Self {
115        self.index_key_interval = Some(interval);
116        self
117    }
118
119    /// Defines the size of the blocks that the writer will write.
120    ///
121    /// The bigger the blocks are the better they are compressed
122    /// but the more time it takes to compress and decompress them.
123    pub fn block_size(&mut self, size: usize) -> &mut Self {
124        self.block_size = Some(size);
125        self
126    }
127
128    /// The number of levels/indirection we will use to write the index footer.
129    ///
130    /// An indirection of 1 or 2 is sufficient to reduce the impact of
131    /// decompressing/reading a big index footer.
132    ///
133    /// The default is 0 which means that the index footer values directly
134    /// specifies the block where the requested entry can be found. The disavantage of this
135    /// is that the index block can be quite big and take time to be decompressed and read.
136    pub fn index_levels(&mut self, levels: u8) -> &mut Self {
137        self.index_levels = Some(levels);
138        self
139    }
140
141    /// The algorithm used to sort the internal vector.
142    ///
143    /// The default is [`SortAlgorithm::Stable`](crate::SortAlgorithm::Stable).
144    pub fn sort_algorithm(&mut self, algo: SortAlgorithm) -> &mut Self {
145        self.sort_algorithm = algo;
146        self
147    }
148
149    /// Whether we use [rayon to sort](https://docs.rs/rayon/latest/rayon/slice/trait.ParallelSliceMut.html#method.par_sort_by_key) the entries.
150    ///
151    /// By default we do not sort in parallel, the value is `false`.
152    #[cfg(feature = "rayon")]
153    pub fn sort_in_parallel(&mut self, value: bool) -> &mut Self {
154        self.sort_in_parallel = value;
155        self
156    }
157
158    /// The [`ChunkCreator`] struct used to generate the chunks used
159    /// by the [`Sorter`] to bufferize when required.
160    pub fn chunk_creator<CC2>(self, creation: CC2) -> SorterBuilder<MF, CC2> {
161        SorterBuilder {
162            dump_threshold: self.dump_threshold,
163            allow_realloc: self.allow_realloc,
164            max_nb_chunks: self.max_nb_chunks,
165            chunk_compression_type: self.chunk_compression_type,
166            chunk_compression_level: self.chunk_compression_level,
167            index_key_interval: self.index_key_interval,
168            block_size: self.block_size,
169            index_levels: self.index_levels,
170            chunk_creator: creation,
171            sort_algorithm: self.sort_algorithm,
172            sort_in_parallel: self.sort_in_parallel,
173            merge: self.merge,
174        }
175    }
176}
177
178impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
179    /// Creates the [`Sorter`] configured by this builder.
180    pub fn build(self) -> Sorter<MF, CC> {
181        let capacity =
182            if self.allow_realloc { INITIAL_SORTER_VEC_SIZE } else { self.dump_threshold };
183
184        Sorter {
185            chunks: Vec::new(),
186            entries: Entries::with_capacity(capacity),
187            chunks_total_size: 0,
188            allow_realloc: self.allow_realloc,
189            dump_threshold: self.dump_threshold,
190            max_nb_chunks: self.max_nb_chunks,
191            chunk_compression_type: self.chunk_compression_type,
192            chunk_compression_level: self.chunk_compression_level,
193            index_key_interval: self.index_key_interval,
194            block_size: self.block_size,
195            index_levels: self.index_levels,
196            chunk_creator: self.chunk_creator,
197            sort_algorithm: self.sort_algorithm,
198            sort_in_parallel: self.sort_in_parallel,
199            merge_function: self.merge,
200        }
201    }
202}
203
204/// Stores entries memory efficiently in a buffer.
205struct Entries {
206    /// The internal buffer that contains the bounds of the buffer
207    /// on the front and the key and data bytes on the back of it.
208    ///
209    /// [----bounds---->--remaining--<--key+data--]
210    ///
211    buffer: EntryBoundAlignedBuffer,
212
213    /// The amount of bytes stored in the buffer.
214    entries_len: usize,
215
216    /// The number of bounds stored in the buffer.
217    bounds_count: usize,
218}
219
220impl Entries {
221    /// Creates a buffer which will consumes this amount of memory,
222    /// rounded up to the size of one `EntryBound` more.
223    ///
224    /// It will use this amount of memory until it needs to reallocate
225    /// where it will create a new buffer of twice the size of the current one
226    /// copies the entries inside and replace the current one by the new one.
227    ///
228    /// If you want to be sure about the amount of memory used you can use
229    /// the `fits` method.
230    pub fn with_capacity(capacity: usize) -> Self {
231        Self { buffer: EntryBoundAlignedBuffer::new(capacity), entries_len: 0, bounds_count: 0 }
232    }
233
234    /// Clear the entries.
235    pub fn clear(&mut self) {
236        self.entries_len = 0;
237        self.bounds_count = 0;
238    }
239
240    /// Inserts a new entry into the buffer, if there is not
241    /// enough space for it to be stored, we double the buffer size.
242    pub fn insert(&mut self, key: &[u8], data: &[u8]) {
243        assert!(key.len() <= u32::MAX as usize);
244        assert!(data.len() <= u32::MAX as usize);
245
246        if self.fits(key, data) {
247            // We store the key and data bytes one after the other at the back of the buffer.
248            self.entries_len += key.len() + data.len();
249            let entries_start = self.buffer.len() - self.entries_len;
250            self.buffer[entries_start..][..key.len()].copy_from_slice(key);
251            self.buffer[entries_start + key.len()..][..data.len()].copy_from_slice(data);
252
253            let bound = EntryBound {
254                key_start: self.entries_len,
255                key_length: key.len() as u32,
256                data_length: data.len() as u32,
257            };
258
259            // We store the bounds at the front of the buffer and grow from the end to the start
260            // of it. We interpret the front of the buffer as a slice of EntryBounds + 1 entry
261            // that is not assigned and replace it with the new one we want to insert.
262            let bounds_end = (self.bounds_count + 1) * size_of::<EntryBound>();
263            let bounds = cast_slice_mut::<_, EntryBound>(&mut self.buffer[..bounds_end]);
264            bounds[self.bounds_count] = bound;
265            self.bounds_count += 1;
266        } else {
267            self.reallocate_buffer();
268            self.insert(key, data);
269        }
270    }
271
272    /// Returns `true` if inserting this entry will not trigger a reallocation.
273    pub fn fits(&self, key: &[u8], data: &[u8]) -> bool {
274        // The number of memory aligned EntryBounds that we can store.
275        let aligned_bounds_count = unsafe { self.buffer.align_to::<EntryBound>().1.len() };
276        let remaining_aligned_bounds = aligned_bounds_count - self.bounds_count;
277
278        self.remaining() >= Self::entry_size(key, data) && remaining_aligned_bounds >= 1
279    }
280
281    /// Simply returns the size of the internal buffer.
282    pub fn memory_usage(&self) -> usize {
283        self.buffer.len()
284    }
285
286    /// Sorts the entry bounds by the entries keys, after a sort
287    /// the `iter` method will yield the entries sorted.
288    pub fn sort_by_key(&mut self, algorithm: SortAlgorithm) {
289        let bounds_end = self.bounds_count * size_of::<EntryBound>();
290        let (bounds, tail) = self.buffer.split_at_mut(bounds_end);
291        let bounds = cast_slice_mut::<_, EntryBound>(bounds);
292        let sort = match algorithm {
293            SortAlgorithm::Stable => <[EntryBound]>::sort_by_key,
294            SortAlgorithm::Unstable => <[EntryBound]>::sort_unstable_by_key,
295        };
296        sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
297    }
298
299    /// Sorts in **parallel** the entry bounds by the entries keys,
300    /// after a sort the `iter` method will yield the entries sorted.
301    #[cfg(feature = "rayon")]
302    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
303        use rayon::slice::ParallelSliceMut;
304
305        let bounds_end = self.bounds_count * size_of::<EntryBound>();
306        let (bounds, tail) = self.buffer.split_at_mut(bounds_end);
307        let bounds = cast_slice_mut::<_, EntryBound>(bounds);
308        let sort = match algorithm {
309            SortAlgorithm::Stable => <[EntryBound]>::par_sort_by_key,
310            SortAlgorithm::Unstable => <[EntryBound]>::par_sort_unstable_by_key,
311        };
312        sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
313    }
314
315    #[cfg(not(feature = "rayon"))]
316    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
317        self.sort_by_key(algorithm);
318    }
319
320    /// Returns an iterator over the keys and datas.
321    pub fn iter(&self) -> impl Iterator<Item = (&[u8], &[u8])> + '_ {
322        let bounds_end = self.bounds_count * size_of::<EntryBound>();
323        let (bounds, tail) = self.buffer.split_at(bounds_end);
324        let bounds = cast_slice::<_, EntryBound>(bounds);
325
326        bounds.iter().map(move |b| {
327            let entries_start = tail.len() - b.key_start;
328            let key = &tail[entries_start..][..b.key_length as usize];
329            let data = &tail[entries_start + b.key_length as usize..][..b.data_length as usize];
330            (key, data)
331        })
332    }
333
334    /// Returns the approximative memory usage of the rough entries.
335    ///
336    /// This is a very bad estimate in the sense that it does not calculate the amount of
337    /// duplicate entries and the fact that entries can be compressed once dumped to disk.
338    /// This estimate will always be greater than the actual end space usage on disk.
339    pub fn estimated_entries_memory_usage(&self) -> usize {
340        self.memory_usage() - self.remaining()
341    }
342
343    /// The remaining amount of bytes before we need to reallocate a new buffer.
344    fn remaining(&self) -> usize {
345        self.buffer.len() - self.entries_len - self.bounds_count * size_of::<EntryBound>()
346    }
347
348    /// The size that this entry will need to be stored in the buffer.
349    fn entry_size(key: &[u8], data: &[u8]) -> usize {
350        size_of::<EntryBound>() + key.len() + data.len()
351    }
352
353    /// Doubles the size of the internal buffer, copies the entries and bounds into the new buffer.
354    fn reallocate_buffer(&mut self) {
355        let bounds_end = self.bounds_count * size_of::<EntryBound>();
356        let bounds_bytes = &self.buffer[..bounds_end];
357
358        let entries_start = self.buffer.len() - self.entries_len;
359        let entries_bytes = &self.buffer[entries_start..];
360
361        let mut new_buffer = EntryBoundAlignedBuffer::new(self.buffer.len() * 2);
362        new_buffer[..bounds_end].copy_from_slice(bounds_bytes);
363        let new_entries_start = new_buffer.len() - self.entries_len;
364        new_buffer[new_entries_start..].copy_from_slice(entries_bytes);
365
366        self.buffer = new_buffer;
367    }
368}
369
/// The location of one entry inside the `Entries` buffer.
///
/// `repr(C)` together with the `Pod`/`Zeroable` derives lets the front of the buffer
/// be reinterpreted as a slice of `EntryBound`s with `bytemuck`.
#[derive(Default, Copy, Clone, Pod, Zeroable)]
#[repr(C)]
struct EntryBound {
    /// Distance, in bytes, from the end of the buffer to the first byte of the key.
    key_start: usize,
    /// Length of the key in bytes; the data immediately follows the key.
    key_length: u32,
    /// Length of the data in bytes.
    data_length: u32,
}
377
378/// Represents an `EntryBound` aligned buffer.
379struct EntryBoundAlignedBuffer(&'static mut [u8]);
380
381impl EntryBoundAlignedBuffer {
382    /// Allocates a new buffer of the given size, it is correctly aligned to store `EntryBound`s.
383    fn new(size: usize) -> EntryBoundAlignedBuffer {
384        let entry_bound_size = size_of::<EntryBound>();
385        let size = (size + entry_bound_size - 1) / entry_bound_size * entry_bound_size;
386        let layout = Layout::from_size_align(size, align_of::<EntryBound>()).unwrap();
387        let ptr = unsafe { alloc(layout) };
388        assert!(
389            !ptr.is_null(),
390            "the allocator is unable to allocate that much memory ({} bytes requested)",
391            size
392        );
393        let slice = unsafe { slice::from_raw_parts_mut(ptr, size) };
394        EntryBoundAlignedBuffer(slice)
395    }
396}
397
398impl ops::Deref for EntryBoundAlignedBuffer {
399    type Target = [u8];
400
401    fn deref(&self) -> &Self::Target {
402        self.0
403    }
404}
405
406impl ops::DerefMut for EntryBoundAlignedBuffer {
407    fn deref_mut(&mut self) -> &mut Self::Target {
408        self.0
409    }
410}
411
412impl Drop for EntryBoundAlignedBuffer {
413    fn drop(&mut self) {
414        let layout = Layout::from_size_align(self.0.len(), align_of::<EntryBound>()).unwrap();
415        unsafe { dealloc(self.0.as_mut_ptr(), layout) }
416    }
417}
418
/// A struct you can use to automatically sort and merge duplicate entries.
///
/// You can insert key-value pairs in arbitrary order. They are buffered in memory. When
/// the buffer cannot hold more entries (see the `dump_threshold` and `allow_realloc`
/// settings of [`SorterBuilder`]), they are sorted, merged with the merge function and
/// written to a chunk generated by the [`ChunkCreator`].
pub struct Sorter<MF, CC: ChunkCreator = DefaultChunkCreator> {
    /// The chunks written so far.
    chunks: Vec<CC::Chunk>,
    /// The in-memory entries not yet written to a chunk.
    entries: Entries,
    /// Number of bytes held by `chunks`, as counted by the chunk writers.
    chunks_total_size: u64,
    /// Whether the in-memory buffer may grow below the dump threshold.
    allow_realloc: bool,
    /// Buffer size at which the in-memory entries are dumped to a chunk.
    dump_threshold: usize,
    /// Number of chunks that triggers merging them into a single chunk.
    max_nb_chunks: usize,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    chunk_compression_type: Option<CompressionType>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    chunk_compression_level: Option<u32>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    index_key_interval: Option<NonZeroUsize>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    block_size: Option<usize>,
    /// Chunk writer setting; `None` keeps the `WriterBuilder` default.
    index_levels: Option<u8>,
    /// Creates a new chunk for every dump and every merge.
    chunk_creator: CC,
    /// Algorithm used to sort the in-memory entries.
    sort_algorithm: SortAlgorithm,
    /// Whether the in-memory entries are sorted in parallel.
    sort_in_parallel: bool,
    /// Merge function used to combine the values of equal keys.
    merge_function: MF,
}
441
442impl<MF> Sorter<MF, DefaultChunkCreator> {
443    /// Creates a [`SorterBuilder`] from a merge function, it can be
444    /// used to configure your [`Sorter`] to better fit your needs.
445    pub fn builder(merge: MF) -> SorterBuilder<MF, DefaultChunkCreator> {
446        SorterBuilder::new(merge)
447    }
448
449    /// Creates a [`Sorter`] from a merge function, with the default parameters.
450    pub fn new(merge: MF) -> Sorter<MF, DefaultChunkCreator> {
451        SorterBuilder::new(merge).build()
452    }
453
454    /// A rough estimate of how much memory usage it will take on the disk once dumped to disk.
455    ///
456    /// This is a very bad estimate in the sense that it does not calculate the amount of
457    /// duplicate entries that are in the dumped chunks and neither the fact that the in-memory
458    /// buffer will likely be compressed once written to disk. This estimate will always
459    /// be greater than the actual end space usage on disk.
460    pub fn estimated_dumped_memory_usage(&self) -> u64 {
461        self.entries.estimated_entries_memory_usage() as u64 + self.chunks_total_size
462    }
463}
464
465impl<MF, CC> Sorter<MF, CC>
466where
467    MF: MergeFunction,
468    CC: ChunkCreator,
469{
470    /// Insert an entry into the [`Sorter`] making sure that conflicts
471    /// are resolved by the provided merge function.
472    pub fn insert<K, V>(&mut self, key: K, val: V) -> crate::Result<(), MF::Error>
473    where
474        K: AsRef<[u8]>,
475        V: AsRef<[u8]>,
476    {
477        let key = key.as_ref();
478        let val = val.as_ref();
479
480        #[allow(clippy::branches_sharing_code)]
481        if self.entries.fits(key, val) || (!self.threshold_exceeded() && self.allow_realloc) {
482            self.entries.insert(key, val);
483        } else {
484            self.chunks_total_size += self.write_chunk()?;
485            self.entries.insert(key, val);
486            if self.chunks.len() >= self.max_nb_chunks {
487                self.chunks_total_size = self.merge_chunks()?;
488            }
489        }
490
491        Ok(())
492    }
493
494    fn threshold_exceeded(&self) -> bool {
495        self.entries.memory_usage() >= self.dump_threshold
496    }
497
498    /// Returns the exact amount of bytes written to disk, the value can be trusted,
499    /// this is not an estimate.
500    ///
501    /// Writes the in-memory entries to disk, using the specify settings
502    /// to compress the block and entries. It clears the in-memory entries.
503    fn write_chunk(&mut self) -> crate::Result<u64, MF::Error> {
504        let count_write_chunk = self
505            .chunk_creator
506            .create()
507            .map_err(Into::into)
508            .map_err(Error::convert_merge_error)
509            .map(CountWrite::new)?;
510
511        let mut writer_builder = WriterBuilder::new();
512        if let Some(compression_type) = self.chunk_compression_type {
513            writer_builder.compression_type(compression_type);
514        }
515        if let Some(compression_level) = self.chunk_compression_level {
516            writer_builder.compression_level(compression_level);
517        }
518        if let Some(index_key_interval) = self.index_key_interval {
519            writer_builder.index_key_interval(index_key_interval);
520        }
521        if let Some(block_size) = self.block_size {
522            writer_builder.block_size(block_size);
523        }
524        if let Some(index_levels) = self.index_levels {
525            writer_builder.index_levels(index_levels);
526        }
527        let mut writer = writer_builder.build(count_write_chunk);
528
529        if self.sort_in_parallel {
530            self.entries.par_sort_by_key(self.sort_algorithm);
531        } else {
532            self.entries.sort_by_key(self.sort_algorithm);
533        }
534
535        let mut current = None;
536        for (key, value) in self.entries.iter() {
537            match current.as_mut() {
538                None => current = Some((key, vec![Cow::Borrowed(value)])),
539                Some((current_key, vals)) => {
540                    if current_key != &key {
541                        let merged_val =
542                            self.merge_function.merge(current_key, vals).map_err(Error::Merge)?;
543                        writer.insert(&current_key, &merged_val)?;
544                        vals.clear();
545                        *current_key = key;
546                    }
547                    vals.push(Cow::Borrowed(value));
548                }
549            }
550        }
551
552        if let Some((key, vals)) = current.take() {
553            let merged_val = self.merge_function.merge(key, &vals).map_err(Error::Merge)?;
554            writer.insert(key, &merged_val)?;
555        }
556
557        // We retrieve the wrapped CountWrite and extract
558        // the amount of bytes effectively written.
559        let mut count_write_chunk = writer.into_inner()?;
560        count_write_chunk.flush()?;
561        let written_bytes = count_write_chunk.count();
562        let chunk = count_write_chunk.into_inner()?;
563
564        self.chunks.push(chunk);
565        self.entries.clear();
566
567        Ok(written_bytes)
568    }
569
570    /// Returns the exact amount of bytes written to disk, the value can be trusted,
571    /// this is not an estimate.
572    ///
573    /// Merges all of the chunks into a final chunk that replaces them.
574    /// It uses the user provided merge function to resolve merge conflicts.
575    fn merge_chunks(&mut self) -> crate::Result<u64, MF::Error> {
576        let count_write_chunk = self
577            .chunk_creator
578            .create()
579            .map_err(Into::into)
580            .map_err(Error::convert_merge_error)
581            .map(CountWrite::new)?;
582
583        let mut writer_builder = WriterBuilder::new();
584        if let Some(compression_type) = self.chunk_compression_type {
585            writer_builder.compression_type(compression_type);
586        }
587        if let Some(compression_level) = self.chunk_compression_level {
588            writer_builder.compression_level(compression_level);
589        }
590        if let Some(index_key_interval) = self.index_key_interval {
591            writer_builder.index_key_interval(index_key_interval);
592        }
593        if let Some(block_size) = self.block_size {
594            writer_builder.block_size(block_size);
595        }
596        if let Some(index_levels) = self.index_levels {
597            writer_builder.index_levels(index_levels);
598        }
599        let mut writer = writer_builder.build(count_write_chunk);
600
601        let sources: crate::Result<Vec<_>, MF::Error> = self
602            .chunks
603            .drain(..)
604            .map(|mut chunk| {
605                chunk.seek(SeekFrom::Start(0))?;
606                Reader::new(chunk).and_then(Reader::into_cursor).map_err(Error::convert_merge_error)
607            })
608            .collect();
609
610        // Create a merger to merge all those chunks.
611        let mut builder = Merger::builder(&self.merge_function);
612        builder.extend(sources?);
613        let merger = builder.build();
614
615        let mut iter = merger.into_stream_merger_iter().map_err(Error::convert_merge_error)?;
616        while let Some((key, val)) = iter.next()? {
617            writer.insert(key, val)?;
618        }
619
620        let mut count_write_chunk = writer.into_inner()?;
621        count_write_chunk.flush()?;
622        let written_bytes = count_write_chunk.count();
623        let chunk = count_write_chunk.into_inner()?;
624
625        self.chunks.push(chunk);
626
627        Ok(written_bytes)
628    }
629
630    /// Consumes this [`Sorter`] and streams the entries to the [`Writer`] given in parameter.
631    pub fn write_into_stream_writer<W: io::Write>(
632        self,
633        writer: &mut Writer<W>,
634    ) -> crate::Result<(), MF::Error> {
635        let mut iter = self.into_stream_merger_iter()?;
636        while let Some((key, val)) = iter.next()? {
637            writer.insert(key, val)?;
638        }
639        Ok(())
640    }
641
642    /// Consumes this [`Sorter`] and outputs a stream of the merged entries in key-order.
643    pub fn into_stream_merger_iter(self) -> crate::Result<MergerIter<CC::Chunk, MF>, MF::Error> {
644        let (sources, merge) = self.extract_reader_cursors_and_merger()?;
645        let mut builder = Merger::builder(merge);
646        builder.extend(sources);
647        builder.build().into_stream_merger_iter().map_err(Error::convert_merge_error)
648    }
649
650    /// Consumes this [`Sorter`] and outputs the list of reader cursors.
651    pub fn into_reader_cursors(self) -> crate::Result<Vec<ReaderCursor<CC::Chunk>>, MF::Error> {
652        self.extract_reader_cursors_and_merger().map(|(readers, _)| readers)
653    }
654
655    /// A helper function to extract the readers and the merge function.
656    #[allow(clippy::type_complexity)] // Return type is not THAT complex
657    fn extract_reader_cursors_and_merger(
658        mut self,
659    ) -> crate::Result<(Vec<ReaderCursor<CC::Chunk>>, MF), MF::Error> {
660        // Flush the pending unordered entries.
661        self.chunks_total_size = self.write_chunk()?;
662
663        let Sorter { chunks, merge_function: merge, .. } = self;
664        let result: Result<Vec<_>, _> = chunks
665            .into_iter()
666            .map(|mut chunk| {
667                chunk.seek(SeekFrom::Start(0))?;
668                Reader::new(chunk).and_then(Reader::into_cursor).map_err(Error::convert_merge_error)
669            })
670            .collect();
671
672        result.map(|readers| (readers, merge))
673    }
674}
675
676impl<MF, CC: ChunkCreator> Debug for Sorter<MF, CC> {
677    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
678        f.debug_struct("Sorter")
679            .field("chunks_count", &self.chunks.len())
680            .field("remaining_entries", &self.entries.remaining())
681            .field("chunks_total_size", &self.chunks_total_size)
682            .field("allow_realloc", &self.allow_realloc)
683            .field("dump_threshold", &self.dump_threshold)
684            .field("max_nb_chunks", &self.max_nb_chunks)
685            .field("chunk_compression_type", &self.chunk_compression_type)
686            .field("chunk_compression_level", &self.chunk_compression_level)
687            .field("index_key_interval", &self.index_key_interval)
688            .field("block_size", &self.block_size)
689            .field("index_levels", &self.index_levels)
690            .field("chunk_creator", &"[chunck creator]")
691            .field("sort_algorithm", &self.sort_algorithm)
692            .field("sort_in_parallel", &self.sort_in_parallel)
693            .field("merge", &"[merge function]")
694            .finish()
695    }
696}
697
/// A factory for the chunks a [`Sorter`] writes its sorted entries into.
///
/// The sorter asks for a new chunk every time it dumps its in-memory entries and
/// every time it merges its chunks into a single one.
pub trait ChunkCreator {
    /// The generated chunk: it is written to, then rewound and read back by the sorter.
    type Chunk: Write + Seek + Read;
    /// The error that can be returned when creating a chunk.
    type Error: Into<Error>;

    /// The method called by the sorter that returns the created chunk.
    fn create(&self) -> Result<Self::Chunk, Self::Error>;
}
708
/// The default chunk creator: [`TempFileChunk`], as the `tempfile` feature is enabled.
#[cfg(feature = "tempfile")]
pub type DefaultChunkCreator = TempFileChunk;

/// The default chunk creator: [`CursorVec`], as the `tempfile` feature is disabled.
#[cfg(not(feature = "tempfile"))]
pub type DefaultChunkCreator = CursorVec;
716
717impl<C: Write + Seek + Read, E: Into<Error>> ChunkCreator for dyn Fn() -> Result<C, E> {
718    type Chunk = C;
719    type Error = E;
720
721    fn create(&self) -> Result<Self::Chunk, Self::Error> {
722        self()
723    }
724}
725
/// A [`ChunkCreator`] that stores every chunk in a temporary [`File`]
/// created with `tempfile::tempfile`.
#[cfg(feature = "tempfile")]
#[derive(Debug, Default, Copy, Clone)]
pub struct TempFileChunk;

#[cfg(feature = "tempfile")]
impl ChunkCreator for TempFileChunk {
    type Chunk = File;
    type Error = io::Error;

    fn create(&self) -> Result<Self::Chunk, Self::Error> {
        let file = tempfile::tempfile()?;
        Ok(file)
    }
}
740
741/// A [`ChunkCreator`] that generates [`Vec`] of bytes wrapped by a [`Cursor`] for chunks.
742#[derive(Debug, Default, Copy, Clone)]
743pub struct CursorVec;
744
745impl ChunkCreator for CursorVec {
746    type Chunk = Cursor<Vec<u8>>;
747    type Error = Infallible;
748
749    fn create(&self) -> Result<Self::Chunk, Self::Error> {
750        Ok(Cursor::new(Vec::new()))
751    }
752}
753
#[cfg(test)]
mod tests {
    use std::convert::Infallible;
    use std::io::Cursor;
    use std::iter::repeat;

    use super::*;

    /// A merge function that concatenates all the values of a key,
    /// in the order they are given to it.
    #[derive(Copy, Clone)]
    struct ConcatMerger;

    impl MergeFunction for ConcatMerger {
        type Error = Infallible;

        fn merge<'a>(
            &self,
            _key: &[u8],
            values: &[Cow<'a, [u8]>],
        ) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
            Ok(values.iter().flat_map(AsRef::as_ref).cloned().collect())
        }
    }

    /// Inserts a few entries (one key twice) into an in-memory sorter and
    /// checks that the duplicated key is merged while the others are kept as-is.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn simple_cursorvec() {
        let mut sorter = SorterBuilder::new(ConcatMerger)
            .chunk_compression_type(CompressionType::Snappy)
            .chunk_creator(CursorVec)
            .build();

        sorter.insert(b"hello", "kiki").unwrap();
        sorter.insert(b"abstract", "lol").unwrap();
        sorter.insert(b"allo", "lol").unwrap();
        sorter.insert(b"abstract", "lol").unwrap();

        let mut bytes = WriterBuilder::new().memory();
        sorter.write_into_stream_writer(&mut bytes).unwrap();
        let bytes = bytes.into_inner().unwrap();

        let reader = Reader::new(Cursor::new(bytes.as_slice())).unwrap();
        let mut cursor = reader.into_cursor().unwrap();
        while let Some((key, val)) = cursor.move_on_next().unwrap() {
            match key {
                b"hello" => assert_eq!(val, b"kiki"),
                b"abstract" => assert_eq!(val, b"lollol"),
                b"allo" => assert_eq!(val, b"lol"),
                bytes => panic!("{:?}", bytes),
            }
        }
    }

    /// Uses a tiny dump threshold with reallocation disabled, so the sorter is
    /// expected to dump its buffer into chunks, then checks that all the values
    /// of the single key are merged back together.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn hard_cursorvec() {
        let mut sorter = SorterBuilder::new(ConcatMerger)
            .dump_threshold(1024) // 1KiB
            .allow_realloc(false)
            .chunk_compression_type(CompressionType::Snappy)
            .chunk_creator(CursorVec)
            .build();

        // We want to exceed the dump threshold. Keys, values and their
        // EntryBound are stored inline in the buffer, so inserting 200
        // entries of 5 + 4 bytes each is likely to reach it.
        for _ in 0..200 {
            sorter.insert(b"hello", "kiki").unwrap();
        }

        let mut bytes = WriterBuilder::new().memory();
        sorter.write_into_stream_writer(&mut bytes).unwrap();
        let bytes = bytes.into_inner().unwrap();

        let reader = Reader::new(Cursor::new(bytes.as_slice())).unwrap();
        let mut cursor = reader.into_cursor().unwrap();
        let (key, val) = cursor.move_on_next().unwrap().unwrap();
        assert_eq!(key, b"hello");
        assert!(val.iter().eq(repeat(b"kiki").take(200).flatten()));
        assert!(cursor.move_on_next().unwrap().is_none());
    }

    /// Checks that the values of a key reach the merge function in the order
    /// in which they were inserted.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn correct_key_ordering() {
        use std::borrow::Cow;

        use rand::prelude::{SeedableRng, SliceRandom};
        use rand::rngs::StdRng;

        /// This merge function concatenates bytes in the order they are received.
        struct ConcatBytesMerger;

        impl MergeFunction for ConcatBytesMerger {
            type Error = Infallible;

            fn merge<'a>(
                &self,
                _key: &[u8],
                values: &[Cow<'a, [u8]>],
            ) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
                let mut output = Vec::new();
                for value in values {
                    output.extend_from_slice(value);
                }
                Ok(Cow::from(output))
            }
        }

        // We create a sorter that concatenates the values of identical keys.
        let mut sorter = SorterBuilder::new(ConcatBytesMerger).chunk_creator(CursorVec).build();

        // We insert all the possible values of a u8 in ascending order
        // but we split them along different keys.
        let mut rng = StdRng::seed_from_u64(42);
        let possible_keys = ["first", "second", "third", "fourth", "fifth", "sixth"];
        for n in 0..=255 {
            let key = possible_keys.choose(&mut rng).unwrap();
            sorter.insert(key, [n]).unwrap();
        }

        // We can iterate over the entries in key-order; the merged bytes of
        // every key must still be ascending, i.e. in insertion order.
        let mut iter = sorter.into_stream_merger_iter().unwrap();
        while let Some((_key, value)) = iter.next().unwrap() {
            assert!(value.windows(2).all(|w| w[0] <= w[1]), "{:?}", value);
        }
    }

    /// Checks that requesting an oversized buffer panics with the expected message.
    ///
    /// `1 << 48` bytes (281474976710656, as in the message) does not fit in a
    /// 32-bit `usize` (presumably the type `new` takes), which would be a
    /// compile-time overflow error there, so this test only exists on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    #[should_panic(
        expected = "the allocator is unable to allocate that much memory (281474976710656 bytes requested)"
    )]
    #[cfg_attr(miri, ignore)]
    fn too_big_allocation() {
        EntryBoundAlignedBuffer::new(1 << 48);
    }
}