Skip to main content

bed_utils/extsort/
sort.rs

1use crate::extsort::merger::BinaryHeapMerger;
2use crate::extsort::{
3    chunk::{ExternalChunk, ExternalChunkError},
4    DiskDeserializer, DiskSerializer,
5};
6
7use rayon::slice::ParallelSliceMut;
8use rkyv::{Archive, Deserialize, Serialize};
9use std::sync::{
10    atomic::{AtomicUsize, Ordering as AOrd},
11    mpsc,
12};
13use std::{
14    cmp::Ordering,
15    error::Error,
16    fmt::{self, Display},
17    io,
18    path::{Path, PathBuf},
19};
20
21/// Errors returned by external sorting operations.
22#[derive(Debug)]
23pub enum SortError {
24    /// Temporary directory or file creation error.
25    TempDir(io::Error),
26    /// Workers thread pool initialization error.
27    ThreadPoolBuildError(rayon::ThreadPoolBuildError),
28    /// Common I/O error.
29    IO(io::Error),
30    /// Data serialization error.
31    SerializationError(rkyv::rancor::Error),
32    /// Data deserialization error.
33    DeserializationError(rkyv::rancor::Error),
34}
35
36impl Error for SortError {
37    fn source(&self) -> Option<&(dyn Error + 'static)> {
38        Some(match &self {
39            SortError::TempDir(err) => err,
40            SortError::ThreadPoolBuildError(err) => err,
41            SortError::IO(err) => err,
42            SortError::SerializationError(err) => err,
43            SortError::DeserializationError(err) => err,
44        })
45    }
46}
47
48impl Display for SortError {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match &self {
51            SortError::TempDir(err) => {
52                write!(f, "temporary directory or file not created: {}", err)
53            }
54            SortError::ThreadPoolBuildError(err) => {
55                write!(f, "thread pool initialization failed: {}", err)
56            }
57            SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
58            SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
59            SortError::DeserializationError(err) => {
60                write!(f, "data deserialization error: {}", err)
61            }
62        }
63    }
64}
65
66/// Exposes external sorting (i.e. on disk sorting) capability on arbitrarily
67/// sized iterator, even if the generated content of the iterator doesn't fit in
68/// memory.
69pub struct ExternalSorterBuilder {
70    chunk_size: usize,
71    tmp_dir: Option<PathBuf>,
72    num_threads: Option<usize>,
73    compression: u32,
74}
75
76impl ExternalSorterBuilder {
77    pub fn new() -> Self {
78        Self {
79            chunk_size: 50000000,
80            tmp_dir: None,
81            num_threads: None,
82            compression: 1,
83        }
84    }
85
86    /// Sets the maximum size of each segment in number of sorted items.
87    ///
88    /// This number of items needs to fit in memory. While sorting, a
89    /// in-memory buffer is used to collect the items to be sorted. Once
90    /// it reaches the maximum size, it is sorted and then written to disk.
91    ///
92    /// Using a higher segment size makes sorting faster by leveraging
93    /// faster in-memory operations.
94    pub fn with_chunk_size(mut self, size: usize) -> Self {
95        self.chunk_size = size;
96        self
97    }
98
99    /// Sets directory in which sorted segments will be written (if it doesn't
100    /// fit in memory).
101    pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
102        self.tmp_dir = Some(path.as_ref().to_path_buf());
103        self
104    }
105
106    /// Sets the compression level (1-16) to be used when writing sorted segments to
107    /// disk.
108    pub fn with_compression(mut self, level: u32) -> Self {
109        self.compression = level;
110        self
111    }
112
113    /// Uses Rayon to sort the in-memory buffer.
114    ///
115    /// This may not be needed if the buffer isn't big enough for parallelism to
116    /// be gainful over the overhead of multithreading.
117    pub fn num_threads(mut self, num_threads: usize) -> Self {
118        self.num_threads = Some(num_threads);
119        self
120    }
121
122    /// Builds an [`ExternalSorter`].
123    ///
124    /// Returns an I/O error if the temporary directory or rayon thread pool
125    /// cannot be initialized.
126    pub fn build(self) -> io::Result<ExternalSorter> {
127        Ok(ExternalSorter {
128            chunk_size: self.chunk_size,
129            compression: self.compression,
130            tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
131            thread_pool: _init_thread_pool(self.num_threads)?,
132        })
133    }
134}
135
136pub struct ExternalSorter {
137    chunk_size: usize,
138    compression: u32,
139    /// Sorting thread pool.
140    thread_pool: rayon::ThreadPool,
141    /// Directory to be used to store temporary data.
142    tmp_dir: tempfile::TempDir,
143}
144
145impl ExternalSorter {
146    /// Sorts items using [`Ord`] and returns a sorted iterator.
147    ///
148    /// The input is processed in bounded in-memory chunks and spilled to
149    /// temporary files as needed, so it can handle iterators larger than RAM.
150    ///
151    /// # Errors
152    /// Returns [`SortError`] if creating or reading temporary chunks fails.
153    pub fn sort<I, T>(
154        &self,
155        input: I,
156    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
157    where
158        T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + Ord,
159        T::Archived: Deserialize<T, DiskDeserializer>,
160        I: IntoIterator<Item = T>,
161    {
162        self.sort_by(input, T::cmp)
163    }
164
165    /// Sorts items with a custom comparator and returns a sorted iterator.
166    ///
167    /// This is the synchronous variant: each full chunk is sorted and written
168    /// before reading more input.
169    ///
170    /// # Errors
171    /// Returns [`SortError`] if temporary chunk creation, serialization, or
172    /// deserialization fails.
173    pub fn sort_by<I, T, F>(
174        &self,
175        input: I,
176        cmp: F,
177    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
178    where
179        T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send,
180        T::Archived: Deserialize<T, DiskDeserializer>,
181        I: IntoIterator<Item = T>,
182        F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
183    {
184        let mut chunk_buf = Vec::with_capacity(self.chunk_size);
185        let mut external_chunks = Vec::new();
186        let mut num_items = 0;
187
188        for item in input.into_iter() {
189            num_items += 1;
190            chunk_buf.push(item);
191            if chunk_buf.len() >= self.chunk_size {
192                external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
193                chunk_buf = Vec::with_capacity(self.chunk_size);
194            }
195        }
196
197        if chunk_buf.len() > 0 {
198            external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
199        }
200
201        return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
202    }
203
204    /// Asynchronously sorts items using [`Ord`] and returns a sorted iterator.
205    ///
206    /// This delegates to [`ExternalSorter::sort_by_async`] with `T::cmp`.
207    ///
208    /// # Errors
209    /// Returns [`SortError`] if temporary chunk creation, serialization, or
210    /// deserialization fails.
211    pub fn sort_async<I, T>(
212        &self,
213        input: I,
214    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
215    where
216        T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + Ord + 'static,
217        T::Archived: Deserialize<T, DiskDeserializer>,
218        I: IntoIterator<Item = T>,
219    {
220        self.sort_by_async(input, T::cmp)
221    }
222
223    /// Asynchronously sorts items with a custom comparator and returns a sorted iterator.
224    ///
225    /// Chunk sorting/spilling jobs are scheduled onto the configured rayon
226    /// thread pool while input continues to be consumed on the caller thread.
227    ///
228    /// # Errors
229    /// Returns [`SortError`] if any background job fails while creating,
230    /// serializing, or reading temporary chunks.
231    pub fn sort_by_async<I, T, F>(
232        &self,
233        input: I,
234        cmp: F,
235    ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
236    where
237        I: IntoIterator<Item = T>,
238        T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + 'static,
239        T::Archived: Deserialize<T, DiskDeserializer>,
240        F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
241    {
242        // We’ll get created chunks back through this channel.
243        let (tx, rx) = mpsc::channel::<Result<ExternalChunk<T>, SortError>>();
244
245        let num_items = AtomicUsize::new(0);
246        let tmp_dir_path: PathBuf = self.tmp_dir.path().to_path_buf();
247        let compression = self.compression;
248
249        // PRODUCER: runs on the caller’s thread -> iterator never crosses threads.
250        let mut buf: Vec<T> = Vec::with_capacity(self.chunk_size);
251
252        for item in input.into_iter() {
253            num_items.fetch_add(1, AOrd::Relaxed);
254            buf.push(item);
255            if buf.len() >= self.chunk_size {
256                let chunk = std::mem::take(&mut buf);
257                let txc = tx.clone();
258                let tmp = tmp_dir_path.clone();
259                let cmp_c = cmp;
260
261                // Spawn background job on *your* pool.
262                self.thread_pool.spawn(move || {
263                    let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
264                    let _ = txc.send(res);
265                });
266            }
267        }
268
269        if !buf.is_empty() {
270            let chunk = std::mem::take(&mut buf);
271            let txc = tx.clone();
272            let tmp = tmp_dir_path.clone();
273            let cmp_c = cmp;
274
275            self.thread_pool.spawn(move || {
276                let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
277                let _ = txc.send(res);
278            });
279        }
280
281        // Drop last sender so rx finishes once all tasks send their result.
282        drop(tx);
283
284        // CONSUMER: collect finished chunks (blocks only until workers complete).
285        let mut external_chunks = Vec::new();
286        for res in rx.iter() {
287            external_chunks.push(res?);
288        }
289
290        Ok(BinaryHeapMerger::new(
291            num_items.load(AOrd::Relaxed),
292            external_chunks,
293            cmp,
294        ))
295    }
296
297    fn create_chunk<T, F>(
298        &self,
299        mut buffer: Vec<T>,
300        compare: F,
301    ) -> Result<ExternalChunk<T>, SortError>
302    where
303        T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send,
304        T::Archived: Deserialize<T, DiskDeserializer>,
305        F: Fn(&T, &T) -> Ordering + Sync + Send,
306    {
307        self.thread_pool.install(|| {
308            buffer.par_sort_unstable_by(compare);
309        });
310
311        let tmp_file = tempfile::tempfile_in(&self.tmp_dir).unwrap();
312        let external_chunk =
313            ExternalChunk::new(tmp_file, buffer, self.compression).map_err(|err| match err {
314                ExternalChunkError::IO(err) => SortError::IO(err),
315                ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
316                ExternalChunkError::DecodeError(err) => SortError::DeserializationError(err),
317            })?;
318
319        return Ok(external_chunk);
320    }
321}
322
323/// Helper used by background tasks: sort a buffer and spill it to a temp file in `tmp_dir`.
324fn create_chunk_from_parts<T, F>(
325    mut buffer: Vec<T>,
326    compare: F,
327    tmp_dir: &std::path::Path,
328    compression: u32,
329) -> Result<ExternalChunk<T>, SortError>
330where
331    T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + 'static,
332    T::Archived: Deserialize<T, DiskDeserializer>,
333    F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
334{
335    buffer.sort_unstable_by(compare);
336    let tmp_file = tempfile::tempfile_in(tmp_dir).map_err(SortError::IO)?;
337    ExternalChunk::new(tmp_file, buffer, compression).map_err(|err| match err {
338        ExternalChunkError::IO(e) => SortError::IO(e),
339        ExternalChunkError::EncodeError(e) => SortError::SerializationError(e),
340        ExternalChunkError::DecodeError(e) => SortError::DeserializationError(e),
341    })
342}
343
344fn _init_tmp_directory(tmp_path: Option<&Path>) -> io::Result<tempfile::TempDir> {
345    if let Some(tmp_path) = tmp_path {
346        tempfile::tempdir_in(tmp_path)
347    } else {
348        tempfile::tempdir()
349    }
350}
351
352fn _init_thread_pool(threads_number: Option<usize>) -> io::Result<rayon::ThreadPool> {
353    let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
354    if let Some(threads_number) = threads_number {
355        thread_pool_builder = thread_pool_builder.num_threads(threads_number);
356    }
357    thread_pool_builder
358        .build()
359        .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
360}
361
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder};

    /// Shuffles `0..100`, sorts it with both the synchronous and the
    /// asynchronous entry point, and checks the output in ascending and
    /// descending order.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_external_sorter(#[case] reversed: bool) {
        let ascending: Vec<i32> = (0..100).collect();

        let mut shuffled = ascending.clone();
        shuffled.shuffle(&mut rand::thread_rng());

        let sorter: ExternalSorter = ExternalSorterBuilder::new()
            .num_threads(2)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        // Both branches are non-capturing closures, so they share one
        // function-pointer type.
        let compare: fn(&i32, &i32) -> Ordering = if reversed {
            |a: &i32, b: &i32| a.cmp(b).reverse()
        } else {
            |a: &i32, b: &i32| a.cmp(b)
        };

        let expected_result: Vec<i32> = if reversed {
            ascending.iter().rev().copied().collect()
        } else {
            ascending.clone()
        };

        let sync_sorted = sorter
            .sort_by(shuffled.clone(), compare)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(sync_sorted, expected_result);

        let async_sorted = sorter
            .sort_by_async(shuffled, compare)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(async_sorted, expected_result);
    }
}