Skip to main content

bed_utils/extsort/
sort.rs

1use crate::extsort::chunk::{ExternalChunk, ExternalChunkError};
2use crate::extsort::merger::BinaryHeapMerger;
3
4use bincode_next::{Decode, Encode};
5use rayon::slice::ParallelSliceMut;
6use std::{
7    cmp::Ordering,
8    error::Error,
9    fmt::{self, Display},
10    io,
11    path::{Path, PathBuf},
12};
13use std::sync::{mpsc, atomic::{AtomicUsize, Ordering as AOrd}};
14
15/// Sorting error.
16#[derive(Debug)]
17pub enum SortError {
18    /// Temporary directory or file creation error.
19    TempDir(io::Error),
20    /// Workers thread pool initialization error.
21    ThreadPoolBuildError(rayon::ThreadPoolBuildError),
22    /// Common I/O error.
23    IO(io::Error),
24    /// Data serialization error.
25    SerializationError(bincode_next::error::EncodeError),
26    /// Data deserialization error.
27    DeserializationError(bincode_next::error::DecodeError),
28}
29
30impl Error for SortError {
31    fn source(&self) -> Option<&(dyn Error + 'static)> {
32        Some(match &self {
33            SortError::TempDir(err) => err,
34            SortError::ThreadPoolBuildError(err) => err,
35            SortError::IO(err) => err,
36            SortError::SerializationError(err) => err,
37            SortError::DeserializationError(err) => err,
38        })
39    }
40}
41
42impl Display for SortError {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match &self {
45            SortError::TempDir(err) => {
46                write!(f, "temporary directory or file not created: {}", err)
47            }
48            SortError::ThreadPoolBuildError(err) => {
49                write!(f, "thread pool initialization failed: {}", err)
50            }
51            SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
52            SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
53            SortError::DeserializationError(err) => {
54                write!(f, "data deserialization error: {}", err)
55            }
56        }
57    }
58}
59
/// Builder for [`ExternalSorter`], which exposes external (i.e. on-disk)
/// sorting of arbitrarily sized iterators, even when the content generated by
/// the iterator does not fit in memory.
pub struct ExternalSorterBuilder {
    /// Maximum number of items buffered in memory before a chunk is sorted
    /// and written to disk.
    chunk_size: usize,
    /// Directory in which the sorter's temporary directory is created;
    /// `None` falls back to `tempfile::tempdir()`.
    tmp_dir: Option<PathBuf>,
    /// Number of threads for the sorting pool; `None` leaves the pool builder
    /// at its default.
    num_threads: Option<usize>,
    /// Compression level passed on when chunks are written to disk.
    compression: u32,
}
69
70impl ExternalSorterBuilder {
71    pub fn new() -> Self {
72        Self {
73            chunk_size: 50000000,
74            tmp_dir: None,
75            num_threads: None,
76            compression: 1,
77        }
78    }
79
80    /// Sets the maximum size of each segment in number of sorted items.
81    ///
82    /// This number of items needs to fit in memory. While sorting, a
83    /// in-memory buffer is used to collect the items to be sorted. Once
84    /// it reaches the maximum size, it is sorted and then written to disk.
85    ///
86    /// Using a higher segment size makes sorting faster by leveraging
87    /// faster in-memory operations.
88    pub fn with_chunk_size(mut self, size: usize) -> Self {
89        self.chunk_size = size;
90        self
91    }
92
93    /// Sets directory in which sorted segments will be written (if it doesn't
94    /// fit in memory).
95    pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
96        self.tmp_dir = Some(path.as_ref().to_path_buf());
97        self
98    }
99
100    /// Sets the compression level (1-16) to be used when writing sorted segments to
101    /// disk.
102    pub fn with_compression(mut self, level: u32) -> Self {
103        self.compression = level;
104        self
105    }
106
107    /// Uses Rayon to sort the in-memory buffer.
108    ///
109    /// This may not be needed if the buffer isn't big enough for parallelism to
110    /// be gainful over the overhead of multithreading.
111    pub fn num_threads(mut self, num_threads: usize) -> Self {
112        self.num_threads = Some(num_threads);
113        self
114    }
115
116    pub fn build(self) -> io::Result<ExternalSorter> {
117        Ok(ExternalSorter {
118            chunk_size: self.chunk_size,
119            compression: self.compression,
120            tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
121            thread_pool: _init_thread_pool(self.num_threads)?,
122        })
123    }
124}
125
126pub struct ExternalSorter {
127    chunk_size: usize,
128    compression: u32,
129    /// Sorting thread pool.
130    thread_pool: rayon::ThreadPool,
131    /// Directory to be used to store temporary data.
132    tmp_dir: tempfile::TempDir,
133}
134
135impl ExternalSorter {
136    /// Sorts data from the input.
137    /// Returns an iterator that can be used to get sorted data stream.
138    ///
139    /// # Arguments
140    /// * `input` - Input stream data to be fetched from
141    pub fn sort<I, T>(
142        &self,
143        input: I,
144    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
145    where
146        T: Encode + Decode<()> + Send + Ord,
147        I: IntoIterator<Item = T>,
148    {
149        self.sort_by(input, T::cmp)
150    }
151
152    /// Sorts a given iterator with a comparator function, returning a new iterator with items
153    pub fn sort_by<I, T, F>(
154        &self,
155        input: I,
156        cmp: F,
157    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
158    where
159        T: Encode + Decode<()> + Send,
160        I: IntoIterator<Item = T>,
161        F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
162    {
163        let mut chunk_buf = Vec::with_capacity(self.chunk_size);
164        let mut external_chunks = Vec::new();
165        let mut num_items = 0;
166
167        for item in input.into_iter() {
168            num_items += 1;
169            chunk_buf.push(item);
170            if chunk_buf.len() >= self.chunk_size {
171                external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
172                chunk_buf = Vec::with_capacity(self.chunk_size);
173            }
174        }
175
176        if chunk_buf.len() > 0 {
177            external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
178        }
179
180        return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
181    }
182
183    pub fn sort_async<I, T>(
184        &self,
185        input: I,
186    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
187    where
188        T: Encode + Decode<()> + Send + Ord + 'static,
189        I: IntoIterator<Item = T>,
190    {
191        self.sort_by_async(input, T::cmp)
192    }
193
194    /// Sorts a given iterator with a comparator function, returning a new iterator with items
195    pub fn sort_by_async<I, T, F>(
196        &self,
197        input: I,
198        cmp: F,
199    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
200    where
201        I: IntoIterator<Item = T>,
202        T: Encode + Decode<()> + Send + 'static,
203        F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
204    {
205        // We’ll get created chunks back through this channel.
206        let (tx, rx) = mpsc::channel::<Result<ExternalChunk<T>, SortError>>();
207
208        let num_items = AtomicUsize::new(0);
209        let tmp_dir_path: PathBuf = self.tmp_dir.path().to_path_buf();
210        let compression = self.compression;
211
212        // PRODUCER: runs on the caller’s thread -> iterator never crosses threads.
213        let mut buf: Vec<T> = Vec::with_capacity(self.chunk_size);
214
215        for item in input.into_iter() {
216            num_items.fetch_add(1, AOrd::Relaxed);
217            buf.push(item);
218            if buf.len() >= self.chunk_size {
219                let chunk = std::mem::take(&mut buf);
220                let txc = tx.clone();
221                let tmp = tmp_dir_path.clone();
222                let cmp_c = cmp;
223
224                // Spawn background job on *your* pool.
225                self.thread_pool.spawn(move || {
226                    let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
227                    let _ = txc.send(res);
228                });
229            }
230        }
231
232        if !buf.is_empty() {
233            let chunk = std::mem::take(&mut buf);
234            let txc = tx.clone();
235            let tmp = tmp_dir_path.clone();
236            let cmp_c = cmp;
237
238            self.thread_pool.spawn(move || {
239                let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
240                let _ = txc.send(res);
241            });
242        }
243
244        // Drop last sender so rx finishes once all tasks send their result.
245        drop(tx);
246
247        // CONSUMER: collect finished chunks (blocks only until workers complete).
248        let mut external_chunks = Vec::new();
249        for res in rx.iter() {
250            external_chunks.push(res?);
251        }
252
253        Ok(BinaryHeapMerger::new(
254            num_items.load(AOrd::Relaxed),
255            external_chunks,
256            cmp,
257        ))
258    }
259
260    fn create_chunk<T, F>(
261        &self,
262        mut buffer: Vec<T>,
263        compare: F,
264    ) -> Result<ExternalChunk<T>, SortError>
265    where
266        T: Encode + Send,
267        F: Fn(&T, &T) -> Ordering + Sync + Send,
268    {
269        self.thread_pool.install(|| {
270            buffer.par_sort_unstable_by(compare);
271        });
272
273        let tmp_file = tempfile::tempfile_in(&self.tmp_dir).unwrap();
274        let external_chunk =
275            ExternalChunk::new(tmp_file, buffer, self.compression).map_err(|err| match err {
276                ExternalChunkError::IO(err) => SortError::IO(err),
277                ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
278                ExternalChunkError::DecodeError(err) => SortError::DeserializationError(err),
279            })?;
280
281        return Ok(external_chunk);
282    }
283}
284
285/// Helper used by background tasks: sort a buffer and spill it to a temp file in `tmp_dir`.
286fn create_chunk_from_parts<T, F>(
287    mut buffer: Vec<T>,
288    compare: F,
289    tmp_dir: &std::path::Path,
290    compression: u32,
291) -> Result<ExternalChunk<T>, SortError>
292where
293    T: Encode + Send + 'static,
294    F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
295{
296    buffer.sort_unstable_by(compare);
297    let tmp_file = tempfile::tempfile_in(tmp_dir).map_err(SortError::IO)?;
298    ExternalChunk::new(tmp_file, buffer, compression).map_err(|err| match err {
299        ExternalChunkError::IO(e) => SortError::IO(e),
300        ExternalChunkError::EncodeError(e) => SortError::SerializationError(e),
301        ExternalChunkError::DecodeError(e) => SortError::DeserializationError(e),
302    })
303}
304
305fn _init_tmp_directory(tmp_path: Option<&Path>) -> io::Result<tempfile::TempDir> {
306    if let Some(tmp_path) = tmp_path {
307        tempfile::tempdir_in(tmp_path)
308    } else {
309        tempfile::tempdir()
310    }
311}
312
313fn _init_thread_pool(threads_number: Option<usize>) -> io::Result<rayon::ThreadPool> {
314    let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
315    if let Some(threads_number) = threads_number {
316        thread_pool_builder = thread_pool_builder.num_threads(threads_number);
317    }
318    thread_pool_builder
319        .build()
320        .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
321}
322
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder};

    /// Shuffles the numbers 0..100 and checks that both the blocking and the
    /// background sorting paths return them in the order defined by the
    /// comparator (ascending, or descending when `reversed`).
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_external_sorter(#[case] reversed: bool) {
        let mut expected_result: Vec<i32> = (0..100).collect();

        let mut input = expected_result.clone();
        input.shuffle(&mut rand::thread_rng());

        if reversed {
            expected_result.reverse();
        }

        // Both branches are non-capturing closures, so they share a plain
        // function pointer type.
        let compare: fn(&i32, &i32) -> Ordering = if reversed {
            |a: &i32, b: &i32| a.cmp(b).reverse()
        } else {
            |a: &i32, b: &i32| a.cmp(b)
        };

        let sorter: ExternalSorter = ExternalSorterBuilder::new()
            .num_threads(2)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        let blocking = sorter
            .sort_by(input.clone(), compare)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(blocking, expected_result);

        let background = sorter
            .sort_by_async(input, compare)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(background, expected_result);
    }
}
366}