1use crate::extsort::chunk::{ExternalChunk, ExternalChunkError};
2use crate::extsort::merger::BinaryHeapMerger;
3
4use bitcode::{self, DecodeOwned, Encode};
5use rayon::slice::ParallelSliceMut;
6use std::{
7 cmp::Ordering,
8 error::Error,
9 fmt::{self, Display},
10 io,
11 path::{Path, PathBuf},
12};
13use std::sync::{mpsc, atomic::{AtomicUsize, Ordering as AOrd}};
14
15#[derive(Debug)]
17pub enum SortError {
18 TempDir(io::Error),
20 ThreadPoolBuildError(rayon::ThreadPoolBuildError),
22 IO(io::Error),
24 SerializationError(bitcode::Error),
26}
27
28impl Error for SortError {
29 fn source(&self) -> Option<&(dyn Error + 'static)> {
30 Some(match &self {
31 SortError::TempDir(err) => err,
32 SortError::ThreadPoolBuildError(err) => err,
33 SortError::IO(err) => err,
34 SortError::SerializationError(err) => err,
35 })
36 }
37}
38
39impl Display for SortError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match &self {
42 SortError::TempDir(err) => {
43 write!(f, "temporary directory or file not created: {}", err)
44 }
45 SortError::ThreadPoolBuildError(err) => {
46 write!(f, "thread pool initialization failed: {}", err)
47 }
48 SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
49 SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
50 }
51 }
52}
53
/// Configures and creates an [`ExternalSorter`].
///
/// Start from [`ExternalSorterBuilder::new`], adjust it with the `with_*` and
/// `num_threads` methods, then call `build`.
#[derive(Debug, Clone)]
pub struct ExternalSorterBuilder {
    /// Number of items buffered in memory before they are sorted and written to disk
    /// as one chunk.
    chunk_size: usize,
    /// Parent directory for the sorter's temporary directory; `None` uses the default
    /// temporary location.
    tmp_dir: Option<PathBuf>,
    /// Worker count for the sorter's thread pool; `None` leaves the choice to rayon.
    num_threads: Option<usize>,
    /// Compression level forwarded to every `ExternalChunk` when it is written.
    compression: u32,
}
63
64impl ExternalSorterBuilder {
65 pub fn new() -> Self {
66 Self {
67 chunk_size: 50000000,
68 tmp_dir: None,
69 num_threads: None,
70 compression: 1,
71 }
72 }
73
74 pub fn with_chunk_size(mut self, size: usize) -> Self {
83 self.chunk_size = size;
84 self
85 }
86
87 pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
90 self.tmp_dir = Some(path.as_ref().to_path_buf());
91 self
92 }
93
94 pub fn with_compression(mut self, level: u32) -> Self {
97 self.compression = level;
98 self
99 }
100
101 pub fn num_threads(mut self, num_threads: usize) -> Self {
106 self.num_threads = Some(num_threads);
107 self
108 }
109
110 pub fn build(self) -> io::Result<ExternalSorter> {
111 Ok(ExternalSorter {
112 chunk_size: self.chunk_size,
113 compression: self.compression,
114 tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
115 thread_pool: _init_thread_pool(self.num_threads)?,
116 })
117 }
118}
119
120pub struct ExternalSorter {
121 chunk_size: usize,
122 compression: u32,
123 thread_pool: rayon::ThreadPool,
125 tmp_dir: tempfile::TempDir,
127}
128
129impl ExternalSorter {
130 pub fn sort<I, T>(
136 &self,
137 input: I,
138 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
139 where
140 T: Encode + DecodeOwned + Send + Ord,
141 I: IntoIterator<Item = T>,
142 {
143 self.sort_by(input, T::cmp)
144 }
145
146 pub fn sort_by<I, T, F>(
148 &self,
149 input: I,
150 cmp: F,
151 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
152 where
153 T: Encode + DecodeOwned + Send,
154 I: IntoIterator<Item = T>,
155 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
156 {
157 let mut chunk_buf = Vec::with_capacity(self.chunk_size);
158 let mut external_chunks = Vec::new();
159 let mut num_items = 0;
160
161 for item in input.into_iter() {
162 num_items += 1;
163 chunk_buf.push(item);
164 if chunk_buf.len() >= self.chunk_size {
165 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
166 chunk_buf = Vec::with_capacity(self.chunk_size);
167 }
168 }
169
170 if chunk_buf.len() > 0 {
171 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
172 }
173
174 return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
175 }
176
177 pub fn sort_async<I, T>(
178 &self,
179 input: I,
180 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
181 where
182 T: Encode + DecodeOwned + Send + Ord + 'static,
183 I: IntoIterator<Item = T>,
184 {
185 self.sort_by_async(input, T::cmp)
186 }
187
188 pub fn sort_by_async<I, T, F>(
190 &self,
191 input: I,
192 cmp: F,
193 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
194 where
195 I: IntoIterator<Item = T>,
196 T: Encode + DecodeOwned + Send + 'static,
197 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
198 {
199 let (tx, rx) = mpsc::channel::<Result<ExternalChunk<T>, SortError>>();
201
202 let num_items = AtomicUsize::new(0);
203 let tmp_dir_path: PathBuf = self.tmp_dir.path().to_path_buf();
204 let compression = self.compression;
205
206 let mut buf: Vec<T> = Vec::with_capacity(self.chunk_size);
208
209 for item in input.into_iter() {
210 num_items.fetch_add(1, AOrd::Relaxed);
211 buf.push(item);
212 if buf.len() >= self.chunk_size {
213 let chunk = std::mem::take(&mut buf);
214 let txc = tx.clone();
215 let tmp = tmp_dir_path.clone();
216 let cmp_c = cmp;
217
218 self.thread_pool.spawn(move || {
220 let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
221 let _ = txc.send(res);
222 });
223 }
224 }
225
226 if !buf.is_empty() {
227 let chunk = std::mem::take(&mut buf);
228 let txc = tx.clone();
229 let tmp = tmp_dir_path.clone();
230 let cmp_c = cmp;
231
232 self.thread_pool.spawn(move || {
233 let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
234 let _ = txc.send(res);
235 });
236 }
237
238 drop(tx);
240
241 let mut external_chunks = Vec::new();
243 for res in rx.iter() {
244 external_chunks.push(res?);
245 }
246
247 Ok(BinaryHeapMerger::new(
248 num_items.load(AOrd::Relaxed),
249 external_chunks,
250 cmp,
251 ))
252 }
253
254 fn create_chunk<T, F>(
255 &self,
256 mut buffer: Vec<T>,
257 compare: F,
258 ) -> Result<ExternalChunk<T>, SortError>
259 where
260 T: Encode + Send,
261 F: Fn(&T, &T) -> Ordering + Sync + Send,
262 {
263 self.thread_pool.install(|| {
264 buffer.par_sort_unstable_by(compare);
265 });
266
267 let tmp_file = tempfile::tempfile_in(&self.tmp_dir).unwrap();
268 let external_chunk =
269 ExternalChunk::new(tmp_file, buffer, self.compression).map_err(|err| match err {
270 ExternalChunkError::IO(err) => SortError::IO(err),
271 ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
272 })?;
273
274 return Ok(external_chunk);
275 }
276}
277
278fn create_chunk_from_parts<T, F>(
280 mut buffer: Vec<T>,
281 compare: F,
282 tmp_dir: &std::path::Path,
283 compression: u32,
284) -> Result<ExternalChunk<T>, SortError>
285where
286 T: Encode + Send + 'static,
287 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
288{
289 buffer.sort_unstable_by(compare);
290 let tmp_file = tempfile::tempfile_in(tmp_dir).map_err(SortError::IO)?;
291 ExternalChunk::new(tmp_file, buffer, compression).map_err(|err| match err {
292 ExternalChunkError::IO(e) => SortError::IO(e),
293 ExternalChunkError::EncodeError(e) => SortError::SerializationError(e),
294 })
295}
296
297fn _init_tmp_directory(tmp_path: Option<&Path>) -> io::Result<tempfile::TempDir> {
298 if let Some(tmp_path) = tmp_path {
299 tempfile::tempdir_in(tmp_path)
300 } else {
301 tempfile::tempdir()
302 }
303}
304
305fn _init_thread_pool(threads_number: Option<usize>) -> io::Result<rayon::ThreadPool> {
306 let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
307 if let Some(threads_number) = threads_number {
308 thread_pool_builder = thread_pool_builder.num_threads(threads_number);
309 }
310 thread_pool_builder
311 .build()
312 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
313}
314
#[cfg(test)]
mod test {
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder};

    /// Sorts a shuffled `0..100` in ascending and descending order through both the
    /// synchronous and the thread-pool-based entry points.
    ///
    /// A chunk size of 7 splits the input into 15 chunks, so the multi-chunk merge is
    /// exercised. With 2 threads, it also exercises the in-flight cap of
    /// `sort_by_async`. The default chunk size would keep all 100 items in one chunk.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_external_sorter(#[case] reversed: bool) {
        let input_sorted = 0..100;

        let mut input: Vec<i32> = input_sorted.clone().collect();
        input.shuffle(&mut rand::thread_rng());

        let sorter: ExternalSorter = ExternalSorterBuilder::new()
            .num_threads(2)
            .with_chunk_size(7)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        let compare = if reversed {
            |a: &i32, b: &i32| a.cmp(b).reverse()
        } else {
            |a: &i32, b: &i32| a.cmp(b)
        };

        let expected_result: Vec<i32> = if reversed {
            input_sorted.clone().rev().collect()
        } else {
            input_sorted.clone().collect()
        };

        let result = sorter.sort_by(input.clone(), compare).unwrap();
        assert_eq!(result.len(), expected_result.len());
        assert_eq!(result.collect::<Result<Vec<_>, _>>().unwrap(), expected_result);

        let result = sorter.sort_by_async(input, compare).unwrap();
        assert_eq!(result.len(), expected_result.len());
        assert_eq!(result.collect::<Result<Vec<_>, _>>().unwrap(), expected_result);
    }
}