Skip to main content

bed_utils/extsort/
sort.rs

1use crate::extsort::chunk::{ExternalChunk, ExternalChunkError};
2use crate::extsort::merger::BinaryHeapMerger;
3
4use bitcode::{self, DecodeOwned, Encode};
5use rayon::slice::ParallelSliceMut;
6use std::{
7    cmp::Ordering,
8    error::Error,
9    fmt::{self, Display},
10    io,
11    path::{Path, PathBuf},
12};
13use std::sync::{mpsc, atomic::{AtomicUsize, Ordering as AOrd}};
14
15/// Sorting error.
16#[derive(Debug)]
17pub enum SortError {
18    /// Temporary directory or file creation error.
19    TempDir(io::Error),
20    /// Workers thread pool initialization error.
21    ThreadPoolBuildError(rayon::ThreadPoolBuildError),
22    /// Common I/O error.
23    IO(io::Error),
24    /// Data serialization error.
25    SerializationError(bitcode::Error),
26}
27
28impl Error for SortError {
29    fn source(&self) -> Option<&(dyn Error + 'static)> {
30        Some(match &self {
31            SortError::TempDir(err) => err,
32            SortError::ThreadPoolBuildError(err) => err,
33            SortError::IO(err) => err,
34            SortError::SerializationError(err) => err,
35        })
36    }
37}
38
39impl Display for SortError {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match &self {
42            SortError::TempDir(err) => {
43                write!(f, "temporary directory or file not created: {}", err)
44            }
45            SortError::ThreadPoolBuildError(err) => {
46                write!(f, "thread pool initialization failed: {}", err)
47            }
48            SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
49            SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
50        }
51    }
52}
53
/// Builder for [`ExternalSorter`].
///
/// The resulting sorter performs external (on-disk) sorting of an arbitrarily
/// long iterator, even when the items it yields do not all fit in memory.
pub struct ExternalSorterBuilder {
    /// Maximum number of items buffered in memory before a chunk is sorted
    /// and written to disk.
    chunk_size: usize,
    /// Directory in which the sorter's temporary directory is created;
    /// `None` selects the default location chosen by `tempfile::tempdir`.
    tmp_dir: Option<PathBuf>,
    /// Number of threads for the sorting pool; `None` leaves the
    /// `rayon::ThreadPoolBuilder` default in place.
    num_threads: Option<usize>,
    /// Compression level forwarded to every chunk written to disk.
    compression: u32,
}
63
64impl ExternalSorterBuilder {
65    pub fn new() -> Self {
66        Self {
67            chunk_size: 50000000,
68            tmp_dir: None,
69            num_threads: None,
70            compression: 1,
71        }
72    }
73
74    /// Sets the maximum size of each segment in number of sorted items.
75    ///
76    /// This number of items needs to fit in memory. While sorting, a
77    /// in-memory buffer is used to collect the items to be sorted. Once
78    /// it reaches the maximum size, it is sorted and then written to disk.
79    ///
80    /// Using a higher segment size makes sorting faster by leveraging
81    /// faster in-memory operations.
82    pub fn with_chunk_size(mut self, size: usize) -> Self {
83        self.chunk_size = size;
84        self
85    }
86
87    /// Sets directory in which sorted segments will be written (if it doesn't
88    /// fit in memory).
89    pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
90        self.tmp_dir = Some(path.as_ref().to_path_buf());
91        self
92    }
93
94    /// Sets the compression level (1-16) to be used when writing sorted segments to
95    /// disk.
96    pub fn with_compression(mut self, level: u32) -> Self {
97        self.compression = level;
98        self
99    }
100
101    /// Uses Rayon to sort the in-memory buffer.
102    ///
103    /// This may not be needed if the buffer isn't big enough for parallelism to
104    /// be gainful over the overhead of multithreading.
105    pub fn num_threads(mut self, num_threads: usize) -> Self {
106        self.num_threads = Some(num_threads);
107        self
108    }
109
110    pub fn build(self) -> io::Result<ExternalSorter> {
111        Ok(ExternalSorter {
112            chunk_size: self.chunk_size,
113            compression: self.compression,
114            tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
115            thread_pool: _init_thread_pool(self.num_threads)?,
116        })
117    }
118}
119
120pub struct ExternalSorter {
121    chunk_size: usize,
122    compression: u32,
123    /// Sorting thread pool.
124    thread_pool: rayon::ThreadPool,
125    /// Directory to be used to store temporary data.
126    tmp_dir: tempfile::TempDir,
127}
128
129impl ExternalSorter {
130    /// Sorts data from the input.
131    /// Returns an iterator that can be used to get sorted data stream.
132    ///
133    /// # Arguments
134    /// * `input` - Input stream data to be fetched from
135    pub fn sort<I, T>(
136        &self,
137        input: I,
138    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
139    where
140        T: Encode + DecodeOwned + Send + Ord,
141        I: IntoIterator<Item = T>,
142    {
143        self.sort_by(input, T::cmp)
144    }
145
146    /// Sorts a given iterator with a comparator function, returning a new iterator with items
147    pub fn sort_by<I, T, F>(
148        &self,
149        input: I,
150        cmp: F,
151    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
152    where
153        T: Encode + DecodeOwned + Send,
154        I: IntoIterator<Item = T>,
155        F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
156    {
157        let mut chunk_buf = Vec::with_capacity(self.chunk_size);
158        let mut external_chunks = Vec::new();
159        let mut num_items = 0;
160
161        for item in input.into_iter() {
162            num_items += 1;
163            chunk_buf.push(item);
164            if chunk_buf.len() >= self.chunk_size {
165                external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
166                chunk_buf = Vec::with_capacity(self.chunk_size);
167            }
168        }
169
170        if chunk_buf.len() > 0 {
171            external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
172        }
173
174        return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
175    }
176
177    pub fn sort_async<I, T>(
178        &self,
179        input: I,
180    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
181    where
182        T: Encode + DecodeOwned + Send + Ord + 'static,
183        I: IntoIterator<Item = T>,
184    {
185        self.sort_by_async(input, T::cmp)
186    }
187
188    /// Sorts a given iterator with a comparator function, returning a new iterator with items
189    pub fn sort_by_async<I, T, F>(
190        &self,
191        input: I,
192        cmp: F,
193    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
194    where
195        I: IntoIterator<Item = T>,
196        T: Encode + DecodeOwned + Send + 'static,
197        F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
198    {
199        // We’ll get created chunks back through this channel.
200        let (tx, rx) = mpsc::channel::<Result<ExternalChunk<T>, SortError>>();
201
202        let num_items = AtomicUsize::new(0);
203        let tmp_dir_path: PathBuf = self.tmp_dir.path().to_path_buf();
204        let compression = self.compression;
205
206        // PRODUCER: runs on the caller’s thread -> iterator never crosses threads.
207        let mut buf: Vec<T> = Vec::with_capacity(self.chunk_size);
208
209        for item in input.into_iter() {
210            num_items.fetch_add(1, AOrd::Relaxed);
211            buf.push(item);
212            if buf.len() >= self.chunk_size {
213                let chunk = std::mem::take(&mut buf);
214                let txc = tx.clone();
215                let tmp = tmp_dir_path.clone();
216                let cmp_c = cmp;
217
218                // Spawn background job on *your* pool.
219                self.thread_pool.spawn(move || {
220                    let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
221                    let _ = txc.send(res);
222                });
223            }
224        }
225
226        if !buf.is_empty() {
227            let chunk = std::mem::take(&mut buf);
228            let txc = tx.clone();
229            let tmp = tmp_dir_path.clone();
230            let cmp_c = cmp;
231
232            self.thread_pool.spawn(move || {
233                let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
234                let _ = txc.send(res);
235            });
236        }
237
238        // Drop last sender so rx finishes once all tasks send their result.
239        drop(tx);
240
241        // CONSUMER: collect finished chunks (blocks only until workers complete).
242        let mut external_chunks = Vec::new();
243        for res in rx.iter() {
244            external_chunks.push(res?);
245        }
246
247        Ok(BinaryHeapMerger::new(
248            num_items.load(AOrd::Relaxed),
249            external_chunks,
250            cmp,
251        ))
252    }
253
254    fn create_chunk<T, F>(
255        &self,
256        mut buffer: Vec<T>,
257        compare: F,
258    ) -> Result<ExternalChunk<T>, SortError>
259    where
260        T: Encode + Send,
261        F: Fn(&T, &T) -> Ordering + Sync + Send,
262    {
263        self.thread_pool.install(|| {
264            buffer.par_sort_unstable_by(compare);
265        });
266
267        let tmp_file = tempfile::tempfile_in(&self.tmp_dir).unwrap();
268        let external_chunk =
269            ExternalChunk::new(tmp_file, buffer, self.compression).map_err(|err| match err {
270                ExternalChunkError::IO(err) => SortError::IO(err),
271                ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
272            })?;
273
274        return Ok(external_chunk);
275    }
276}
277
278/// Helper used by background tasks: sort a buffer and spill it to a temp file in `tmp_dir`.
279fn create_chunk_from_parts<T, F>(
280    mut buffer: Vec<T>,
281    compare: F,
282    tmp_dir: &std::path::Path,
283    compression: u32,
284) -> Result<ExternalChunk<T>, SortError>
285where
286    T: Encode + Send + 'static,
287    F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
288{
289    buffer.sort_unstable_by(compare);
290    let tmp_file = tempfile::tempfile_in(tmp_dir).map_err(SortError::IO)?;
291    ExternalChunk::new(tmp_file, buffer, compression).map_err(|err| match err {
292        ExternalChunkError::IO(e) => SortError::IO(e),
293        ExternalChunkError::EncodeError(e) => SortError::SerializationError(e),
294    })
295}
296
297fn _init_tmp_directory(tmp_path: Option<&Path>) -> io::Result<tempfile::TempDir> {
298    if let Some(tmp_path) = tmp_path {
299        tempfile::tempdir_in(tmp_path)
300    } else {
301        tempfile::tempdir()
302    }
303}
304
305fn _init_thread_pool(threads_number: Option<usize>) -> io::Result<rayon::ThreadPool> {
306    let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
307    if let Some(threads_number) = threads_number {
308        thread_pool_builder = thread_pool_builder.num_threads(threads_number);
309    }
310    thread_pool_builder
311        .build()
312        .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
313}
314
#[cfg(test)]
mod test {
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder};

    /// Sorts a shuffled `0..100` in ascending and descending order through
    /// both the blocking and the background-job code paths.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_external_sorter(#[case] reversed: bool) {
        let input_sorted = 0..100;

        let mut input: Vec<i32> = Vec::from_iter(input_sorted.clone());
        input.shuffle(&mut rand::thread_rng());

        // A chunk size well below the input length forces several chunks
        // plus a partial trailing one (100 = 6 * 16 + 4), so the merge step
        // is exercised rather than a single in-memory sort.
        let sorter: ExternalSorter = ExternalSorterBuilder::new()
            .with_chunk_size(16)
            .num_threads(2)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        let compare = if reversed {
            |a: &i32, b: &i32| a.cmp(b).reverse()
        } else {
            |a: &i32, b: &i32| a.cmp(b)
        };

        let expected_result = if reversed {
            Vec::from_iter(input_sorted.clone().rev())
        } else {
            Vec::from_iter(input_sorted.clone())
        };

        let result = sorter.sort_by(input.clone(), compare).unwrap();
        assert_eq!(result.collect::<Result<Vec<_>, _>>().unwrap(), expected_result);

        let result = sorter.sort_by_async(input, compare).unwrap();
        assert_eq!(result.collect::<Result<Vec<_>, _>>().unwrap(), expected_result);
    }
}