bed_utils/extsort/
sort.rs

1use crate::extsort::merger::BinaryHeapMerger;
2use crate::extsort::chunk::{ExternalChunk, ExternalChunkError};
3
4use rayon::prelude::*;
5use std::{fmt::{self, Display}, io};
6use rayon;
7use std::error::Error;
8use bincode::{self, Decode, Encode};
9use std::{
10    cmp::Ordering,
11    path::{Path, PathBuf},
12};
13
14/// Sorting error.
15#[derive(Debug)]
16pub enum SortError {
17    /// Temporary directory or file creation error.
18    TempDir(io::Error),
19    /// Workers thread pool initialization error.
20    ThreadPoolBuildError(rayon::ThreadPoolBuildError),
21    /// Common I/O error.
22    IO(io::Error),
23    /// Data serialization error.
24    SerializationError(bincode::error::EncodeError),
25    /// Data deserialization error.
26    DeserializationError(bincode::error::DecodeError),
27}
28
29impl Error for SortError
30{
31    fn source(&self) -> Option<&(dyn Error + 'static)> {
32        Some(match &self {
33            SortError::TempDir(err) => err,
34            SortError::ThreadPoolBuildError(err) => err,
35            SortError::IO(err) => err,
36            SortError::SerializationError(err) => err,
37            SortError::DeserializationError(err) => err,
38        })
39    }
40}
41
42impl Display for SortError {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match &self {
45            SortError::TempDir(err) => write!(f, "temporary directory or file not created: {}", err),
46            SortError::ThreadPoolBuildError(err) => write!(f, "thread pool initialization failed: {}", err),
47            SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
48            SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
49            SortError::DeserializationError(err) => write!(f, "data deserialization error: {}", err),
50        }
51    }
52}
53
54/// Exposes external sorting (i.e. on disk sorting) capability on arbitrarily
55/// sized iterator, even if the generated content of the iterator doesn't fit in
56/// memory.
57pub struct ExternalSorterBuilder {
58    chunk_size: usize,
59    tmp_dir: Option<PathBuf>,
60    num_threads: Option<usize>,
61    compression: Option<u32>,
62}
63
64impl ExternalSorterBuilder {
65    pub fn new() -> Self {
66        Self {
67            chunk_size: 50000000,
68            tmp_dir: None,
69            num_threads: None,
70            compression: None,
71        }
72    }
73
74    /// Sets the maximum size of each segment in number of sorted items.
75    ///
76    /// This number of items needs to fit in memory. While sorting, a
77    /// in-memory buffer is used to collect the items to be sorted. Once
78    /// it reaches the maximum size, it is sorted and then written to disk.
79    ///
80    /// Using a higher segment size makes sorting faster by leveraging
81    /// faster in-memory operations.
82    pub fn with_chunk_size(mut self, size: usize) -> Self {
83        self.chunk_size = size;
84        self
85    }
86
87    /// Sets directory in which sorted segments will be written (if it doesn't
88    /// fit in memory).
89    pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
90        self.tmp_dir = Some(path.as_ref().to_path_buf());
91        self
92    }
93
94    /// Sets the compression level (1-16) to be used when writing sorted segments to
95    /// disk.
96    pub fn with_compression(mut self, level: u32) -> Self {
97        self.compression = Some(level);
98        self
99    }
100
101    /// Uses Rayon to sort the in-memory buffer.
102    ///
103    /// This may not be needed if the buffer isn't big enough for parallelism to
104    /// be gainful over the overhead of multithreading.
105    pub fn num_threads(mut self, num_threads: usize) -> Self {
106        self.num_threads = Some(num_threads);
107        self
108    }
109
110    pub fn build(self) -> io::Result<ExternalSorter> {
111        Ok(ExternalSorter {
112            chunk_size: self.chunk_size,
113            compression: self.compression,
114            tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
115            thread_pool: _init_thread_pool(self.num_threads)?,
116        })
117    }
118}
119
120pub struct ExternalSorter {
121    chunk_size: usize,
122    compression: Option<u32>,
123    /// Sorting thread pool.
124    thread_pool: rayon::ThreadPool,
125    /// Directory to be used to store temporary data.
126    tmp_dir: tempfile::TempDir,
127}
128
129impl ExternalSorter {
130    /// Sorts data from the input.
131    /// Returns an iterator that can be used to get sorted data stream.
132    ///
133    /// # Arguments
134    /// * `input` - Input stream data to be fetched from
135    pub fn sort<I, T>(&self, input: I) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
136    where
137        T: Encode + Decode<()> + Send + Ord,
138        I: IntoIterator<Item = T>,
139    {
140        self.sort_by(input, T::cmp)
141    }
142
143    /// Sorts a given iterator with a comparator function, returning a new iterator with items
144    pub fn sort_by<I, T, F>(&self, input: I, cmp: F) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
145    where
146        T: Encode + Decode<()> + Send,
147        I: IntoIterator<Item = T>,
148        F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
149    {
150        let mut chunk_buf = Vec::with_capacity(self.chunk_size);
151        let mut external_chunks = Vec::new();
152        let mut num_items = 0;
153
154        for item in input.into_iter() {
155            num_items += 1;
156            chunk_buf.push(item);
157            if chunk_buf.len() >= self.chunk_size {
158                external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
159                chunk_buf = Vec::with_capacity(self.chunk_size);
160            }
161        }
162
163        if chunk_buf.len() > 0 {
164            external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
165        }
166
167        return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
168    }
169
170    fn create_chunk<T, F>(&self, mut buffer: Vec<T>, compare: F) -> Result<ExternalChunk<T>, SortError>
171    where
172        T: Encode + Send,
173        F: Fn(&T, &T) -> Ordering + Sync + Send,
174    {
175        self.thread_pool.install(|| {
176            buffer.par_sort_unstable_by(compare);
177        });
178
179        let external_chunk =
180            ExternalChunk::new(&self.tmp_dir, buffer, self.compression).map_err(|err| match err {
181                ExternalChunkError::IO(err) => SortError::IO(err),
182                ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
183                ExternalChunkError::DecodeError(err) => SortError::DeserializationError(err),
184            })?;
185
186        return Ok(external_chunk);
187    }
188}
189
190fn _init_tmp_directory(
191    tmp_path: Option<&Path>,
192) -> io::Result<tempfile::TempDir> {
193    if let Some(tmp_path) = tmp_path {
194        tempfile::tempdir_in(tmp_path)
195    } else {
196        tempfile::tempdir()
197    }
198}
199
200fn _init_thread_pool(
201    threads_number: Option<usize>,
202) -> io::Result<rayon::ThreadPool> {
203    let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
204    if let Some(threads_number) = threads_number {
205        thread_pool_builder = thread_pool_builder.num_threads(threads_number);
206    }
207    thread_pool_builder.build().map_err(|x| io::Error::new(io::ErrorKind::Other, x))
208}
209
210#[cfg(test)]
211mod test {
212    use std::path::Path;
213
214    use rand::seq::SliceRandom;
215    use rstest::*;
216
217    use super::{ExternalSorter, ExternalSorterBuilder};
218
219    #[rstest]
220    #[case(false)]
221    #[case(true)]
222    fn test_external_sorter(#[case] reversed: bool) {
223        let input_sorted = 0..100;
224
225        let mut input: Vec<i32> = Vec::from_iter(input_sorted.clone());
226        input.shuffle(&mut rand::thread_rng());
227
228        let sorter: ExternalSorter = ExternalSorterBuilder::new()
229            .num_threads(2)
230            .with_tmp_dir(Path::new("./"))
231            .build()
232            .unwrap();
233
234        let compare = if reversed {
235            |a: &i32, b: &i32| a.cmp(b).reverse()
236        } else {
237            |a: &i32, b: &i32| a.cmp(b)
238        };
239
240        let result = sorter.sort_by(input, compare).unwrap();
241
242        let actual_result: Result<Vec<i32>, _> = result.collect();
243        let actual_result = actual_result.unwrap();
244        let expected_result = if reversed {
245            Vec::from_iter(input_sorted.clone().rev())
246        } else {
247            Vec::from_iter(input_sorted.clone())
248        };
249
250        assert_eq!(actual_result, expected_result)
251    }
252}