1use crate::extsort::chunk::{ExternalChunk, ExternalChunkError};
2use crate::extsort::merger::BinaryHeapMerger;
3
4use bincode_next::{Decode, Encode};
5use rayon::slice::ParallelSliceMut;
6use std::{
7 cmp::Ordering,
8 error::Error,
9 fmt::{self, Display},
10 io,
11 path::{Path, PathBuf},
12};
13use std::sync::{mpsc, atomic::{AtomicUsize, Ordering as AOrd}};
14
15#[derive(Debug)]
17pub enum SortError {
18 TempDir(io::Error),
20 ThreadPoolBuildError(rayon::ThreadPoolBuildError),
22 IO(io::Error),
24 SerializationError(bincode_next::error::EncodeError),
26 DeserializationError(bincode_next::error::DecodeError),
28}
29
30impl Error for SortError {
31 fn source(&self) -> Option<&(dyn Error + 'static)> {
32 Some(match &self {
33 SortError::TempDir(err) => err,
34 SortError::ThreadPoolBuildError(err) => err,
35 SortError::IO(err) => err,
36 SortError::SerializationError(err) => err,
37 SortError::DeserializationError(err) => err,
38 })
39 }
40}
41
42impl Display for SortError {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match &self {
45 SortError::TempDir(err) => {
46 write!(f, "temporary directory or file not created: {}", err)
47 }
48 SortError::ThreadPoolBuildError(err) => {
49 write!(f, "thread pool initialization failed: {}", err)
50 }
51 SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
52 SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
53 SortError::DeserializationError(err) => {
54 write!(f, "data deserialization error: {}", err)
55 }
56 }
57 }
58}
59
/// Configures and creates an [`ExternalSorter`].
pub struct ExternalSorterBuilder {
    /// Maximum number of items buffered in memory before a chunk is sorted and spilled.
    chunk_size: usize,
    /// Parent directory for the sorter's temporary directory (`None`: system default).
    tmp_dir: Option<std::path::PathBuf>,
    /// Number of worker threads in the sorting pool (`None`: rayon's default).
    num_threads: Option<usize>,
    /// Compression level forwarded to every on-disk chunk.
    compression: u32,
}
69
70impl ExternalSorterBuilder {
71 pub fn new() -> Self {
72 Self {
73 chunk_size: 50000000,
74 tmp_dir: None,
75 num_threads: None,
76 compression: 1,
77 }
78 }
79
80 pub fn with_chunk_size(mut self, size: usize) -> Self {
89 self.chunk_size = size;
90 self
91 }
92
93 pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
96 self.tmp_dir = Some(path.as_ref().to_path_buf());
97 self
98 }
99
100 pub fn with_compression(mut self, level: u32) -> Self {
103 self.compression = level;
104 self
105 }
106
107 pub fn num_threads(mut self, num_threads: usize) -> Self {
112 self.num_threads = Some(num_threads);
113 self
114 }
115
116 pub fn build(self) -> io::Result<ExternalSorter> {
117 Ok(ExternalSorter {
118 chunk_size: self.chunk_size,
119 compression: self.compression,
120 tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
121 thread_pool: _init_thread_pool(self.num_threads)?,
122 })
123 }
124}
125
126pub struct ExternalSorter {
127 chunk_size: usize,
128 compression: u32,
129 thread_pool: rayon::ThreadPool,
131 tmp_dir: tempfile::TempDir,
133}
134
135impl ExternalSorter {
136 pub fn sort<I, T>(
142 &self,
143 input: I,
144 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
145 where
146 T: Encode + Decode<()> + Send + Ord,
147 I: IntoIterator<Item = T>,
148 {
149 self.sort_by(input, T::cmp)
150 }
151
152 pub fn sort_by<I, T, F>(
154 &self,
155 input: I,
156 cmp: F,
157 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
158 where
159 T: Encode + Decode<()> + Send,
160 I: IntoIterator<Item = T>,
161 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
162 {
163 let mut chunk_buf = Vec::with_capacity(self.chunk_size);
164 let mut external_chunks = Vec::new();
165 let mut num_items = 0;
166
167 for item in input.into_iter() {
168 num_items += 1;
169 chunk_buf.push(item);
170 if chunk_buf.len() >= self.chunk_size {
171 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
172 chunk_buf = Vec::with_capacity(self.chunk_size);
173 }
174 }
175
176 if chunk_buf.len() > 0 {
177 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
178 }
179
180 return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
181 }
182
183 pub fn sort_async<I, T>(
184 &self,
185 input: I,
186 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
187 where
188 T: Encode + Decode<()> + Send + Ord + 'static,
189 I: IntoIterator<Item = T>,
190 {
191 self.sort_by_async(input, T::cmp)
192 }
193
194 pub fn sort_by_async<I, T, F>(
196 &self,
197 input: I,
198 cmp: F,
199 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
200 where
201 I: IntoIterator<Item = T>,
202 T: Encode + Decode<()> + Send + 'static,
203 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
204 {
205 let (tx, rx) = mpsc::channel::<Result<ExternalChunk<T>, SortError>>();
207
208 let num_items = AtomicUsize::new(0);
209 let tmp_dir_path: PathBuf = self.tmp_dir.path().to_path_buf();
210 let compression = self.compression;
211
212 let mut buf: Vec<T> = Vec::with_capacity(self.chunk_size);
214
215 for item in input.into_iter() {
216 num_items.fetch_add(1, AOrd::Relaxed);
217 buf.push(item);
218 if buf.len() >= self.chunk_size {
219 let chunk = std::mem::take(&mut buf);
220 let txc = tx.clone();
221 let tmp = tmp_dir_path.clone();
222 let cmp_c = cmp;
223
224 self.thread_pool.spawn(move || {
226 let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
227 let _ = txc.send(res);
228 });
229 }
230 }
231
232 if !buf.is_empty() {
233 let chunk = std::mem::take(&mut buf);
234 let txc = tx.clone();
235 let tmp = tmp_dir_path.clone();
236 let cmp_c = cmp;
237
238 self.thread_pool.spawn(move || {
239 let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
240 let _ = txc.send(res);
241 });
242 }
243
244 drop(tx);
246
247 let mut external_chunks = Vec::new();
249 for res in rx.iter() {
250 external_chunks.push(res?);
251 }
252
253 Ok(BinaryHeapMerger::new(
254 num_items.load(AOrd::Relaxed),
255 external_chunks,
256 cmp,
257 ))
258 }
259
260 fn create_chunk<T, F>(
261 &self,
262 mut buffer: Vec<T>,
263 compare: F,
264 ) -> Result<ExternalChunk<T>, SortError>
265 where
266 T: Encode + Send,
267 F: Fn(&T, &T) -> Ordering + Sync + Send,
268 {
269 self.thread_pool.install(|| {
270 buffer.par_sort_unstable_by(compare);
271 });
272
273 let tmp_file = tempfile::tempfile_in(&self.tmp_dir).unwrap();
274 let external_chunk =
275 ExternalChunk::new(tmp_file, buffer, self.compression).map_err(|err| match err {
276 ExternalChunkError::IO(err) => SortError::IO(err),
277 ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
278 ExternalChunkError::DecodeError(err) => SortError::DeserializationError(err),
279 })?;
280
281 return Ok(external_chunk);
282 }
283}
284
285fn create_chunk_from_parts<T, F>(
287 mut buffer: Vec<T>,
288 compare: F,
289 tmp_dir: &std::path::Path,
290 compression: u32,
291) -> Result<ExternalChunk<T>, SortError>
292where
293 T: Encode + Send + 'static,
294 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
295{
296 buffer.sort_unstable_by(compare);
297 let tmp_file = tempfile::tempfile_in(tmp_dir).map_err(SortError::IO)?;
298 ExternalChunk::new(tmp_file, buffer, compression).map_err(|err| match err {
299 ExternalChunkError::IO(e) => SortError::IO(e),
300 ExternalChunkError::EncodeError(e) => SortError::SerializationError(e),
301 ExternalChunkError::DecodeError(e) => SortError::DeserializationError(e),
302 })
303}
304
305fn _init_tmp_directory(tmp_path: Option<&Path>) -> io::Result<tempfile::TempDir> {
306 if let Some(tmp_path) = tmp_path {
307 tempfile::tempdir_in(tmp_path)
308 } else {
309 tempfile::tempdir()
310 }
311}
312
313fn _init_thread_pool(threads_number: Option<usize>) -> io::Result<rayon::ThreadPool> {
314 let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
315 if let Some(threads_number) = threads_number {
316 thread_pool_builder = thread_pool_builder.num_threads(threads_number);
317 }
318 thread_pool_builder
319 .build()
320 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
321}
322
#[cfg(test)]
mod test {
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder};

    /// Sorts a shuffled `0..100` through both the synchronous and asynchronous
    /// paths and checks the merged output.
    ///
    /// `chunk_size` is varied so the data is split into:
    /// - many single-item chunks;
    /// - several chunks with a partial tail;
    /// - a single chunk.
    ///
    /// This exercises the k-way merge, which the default chunk size never did.
    #[rstest]
    #[case(false, 1)]
    #[case(true, 1)]
    #[case(false, 7)]
    #[case(true, 7)]
    #[case(false, 1000)]
    #[case(true, 1000)]
    fn test_external_sorter(#[case] reversed: bool, #[case] chunk_size: usize) {
        let input_sorted = 0..100;

        let mut input: Vec<i32> = Vec::from_iter(input_sorted.clone());
        input.shuffle(&mut rand::thread_rng());

        let sorter: ExternalSorter = ExternalSorterBuilder::new()
            .num_threads(2)
            .with_chunk_size(chunk_size)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        let compare = if reversed {
            |a: &i32, b: &i32| a.cmp(b).reverse()
        } else {
            |a: &i32, b: &i32| a.cmp(b)
        };

        let expected_result: Vec<i32> = if reversed {
            input_sorted.clone().rev().collect()
        } else {
            input_sorted.clone().collect()
        };

        let result = sorter.sort_by(input.clone(), compare).unwrap();
        assert_eq!(result.collect::<Result<Vec<_>, _>>().unwrap(), expected_result);

        let result = sorter.sort_by_async(input, compare).unwrap();
        assert_eq!(result.collect::<Result<Vec<_>, _>>().unwrap(), expected_result);
    }
}