1use crate::extsort::merger::BinaryHeapMerger;
2use crate::extsort::{
3 chunk::{ExternalChunk, ExternalChunkError},
4 DiskDeserializer, DiskSerializer,
5};
6
7use rayon::slice::ParallelSliceMut;
8use rkyv::{Archive, Deserialize, Serialize};
9use std::sync::{
10 atomic::{AtomicUsize, Ordering as AOrd},
11 mpsc,
12};
13use std::{
14 cmp::Ordering,
15 error::Error,
16 fmt::{self, Display},
17 io,
18 path::{Path, PathBuf},
19};
20
/// Errors reported by the external sorter.
///
/// Every variant wraps the underlying cause, which is also exposed through
/// [`std::error::Error::source`].
#[derive(Debug)]
pub enum SortError {
    /// A temporary directory or file could not be created.
    TempDir(io::Error),
    /// The worker thread pool could not be initialized.
    ThreadPoolBuildError(rayon::ThreadPoolBuildError),
    /// An I/O operation failed.
    IO(io::Error),
    /// Data could not be serialized.
    SerializationError(rkyv::rancor::Error),
    /// Data could not be deserialized.
    DeserializationError(rkyv::rancor::Error),
}
35
36impl Error for SortError {
37 fn source(&self) -> Option<&(dyn Error + 'static)> {
38 Some(match &self {
39 SortError::TempDir(err) => err,
40 SortError::ThreadPoolBuildError(err) => err,
41 SortError::IO(err) => err,
42 SortError::SerializationError(err) => err,
43 SortError::DeserializationError(err) => err,
44 })
45 }
46}
47
48impl Display for SortError {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match &self {
51 SortError::TempDir(err) => {
52 write!(f, "temporary directory or file not created: {}", err)
53 }
54 SortError::ThreadPoolBuildError(err) => {
55 write!(f, "thread pool initialization failed: {}", err)
56 }
57 SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
58 SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
59 SortError::DeserializationError(err) => {
60 write!(f, "data deserialization error: {}", err)
61 }
62 }
63 }
64}
65
/// Builder that collects the settings used to construct an [`ExternalSorter`].
pub struct ExternalSorterBuilder {
    /// Number of items buffered in memory before a chunk is sorted and
    /// written out.
    chunk_size: usize,
    /// Parent directory for the sorter's temporary directory; `None` lets
    /// `tempfile` pick its default location.
    tmp_dir: Option<PathBuf>,
    /// Worker thread count; `None` leaves the rayon pool builder's default.
    num_threads: Option<usize>,
    /// Compression level passed through to each chunk when it is created.
    compression: u32,
}
75
76impl ExternalSorterBuilder {
77 pub fn new() -> Self {
78 Self {
79 chunk_size: 50000000,
80 tmp_dir: None,
81 num_threads: None,
82 compression: 1,
83 }
84 }
85
86 pub fn with_chunk_size(mut self, size: usize) -> Self {
95 self.chunk_size = size;
96 self
97 }
98
99 pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
102 self.tmp_dir = Some(path.as_ref().to_path_buf());
103 self
104 }
105
106 pub fn with_compression(mut self, level: u32) -> Self {
109 self.compression = level;
110 self
111 }
112
113 pub fn num_threads(mut self, num_threads: usize) -> Self {
118 self.num_threads = Some(num_threads);
119 self
120 }
121
122 pub fn build(self) -> io::Result<ExternalSorter> {
127 Ok(ExternalSorter {
128 chunk_size: self.chunk_size,
129 compression: self.compression,
130 tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
131 thread_pool: _init_thread_pool(self.num_threads)?,
132 })
133 }
134}
135
/// Sorter that spills sorted chunks of its input to temporary files.
///
/// Construct it with [`ExternalSorterBuilder`].
pub struct ExternalSorter {
    /// Number of items buffered in memory before a chunk is written out.
    chunk_size: usize,
    /// Compression level passed to each chunk when it is created.
    compression: u32,
    /// Pool on which chunk sorting runs.
    // Fields drop in declaration order, so the pool is dropped before the
    // temporary directory. NOTE(review): presumably intentional — confirm
    // before reordering these fields.
    thread_pool: rayon::ThreadPool,
    /// Directory in which the chunk files are created.
    tmp_dir: tempfile::TempDir,
}
144
145impl ExternalSorter {
146 pub fn sort<I, T>(
154 &self,
155 input: I,
156 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
157 where
158 T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + Ord,
159 T::Archived: Deserialize<T, DiskDeserializer>,
160 I: IntoIterator<Item = T>,
161 {
162 self.sort_by(input, T::cmp)
163 }
164
165 pub fn sort_by<I, T, F>(
174 &self,
175 input: I,
176 cmp: F,
177 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
178 where
179 T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send,
180 T::Archived: Deserialize<T, DiskDeserializer>,
181 I: IntoIterator<Item = T>,
182 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
183 {
184 let mut chunk_buf = Vec::with_capacity(self.chunk_size);
185 let mut external_chunks = Vec::new();
186 let mut num_items = 0;
187
188 for item in input.into_iter() {
189 num_items += 1;
190 chunk_buf.push(item);
191 if chunk_buf.len() >= self.chunk_size {
192 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
193 chunk_buf = Vec::with_capacity(self.chunk_size);
194 }
195 }
196
197 if chunk_buf.len() > 0 {
198 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
199 }
200
201 return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
202 }
203
204 pub fn sort_async<I, T>(
212 &self,
213 input: I,
214 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
215 where
216 T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + Ord + 'static,
217 T::Archived: Deserialize<T, DiskDeserializer>,
218 I: IntoIterator<Item = T>,
219 {
220 self.sort_by_async(input, T::cmp)
221 }
222
223 pub fn sort_by_async<I, T, F>(
232 &self,
233 input: I,
234 cmp: F,
235 ) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
236 where
237 I: IntoIterator<Item = T>,
238 T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + 'static,
239 T::Archived: Deserialize<T, DiskDeserializer>,
240 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
241 {
242 let (tx, rx) = mpsc::channel::<Result<ExternalChunk<T>, SortError>>();
244
245 let num_items = AtomicUsize::new(0);
246 let tmp_dir_path: PathBuf = self.tmp_dir.path().to_path_buf();
247 let compression = self.compression;
248
249 let mut buf: Vec<T> = Vec::with_capacity(self.chunk_size);
251
252 for item in input.into_iter() {
253 num_items.fetch_add(1, AOrd::Relaxed);
254 buf.push(item);
255 if buf.len() >= self.chunk_size {
256 let chunk = std::mem::take(&mut buf);
257 let txc = tx.clone();
258 let tmp = tmp_dir_path.clone();
259 let cmp_c = cmp;
260
261 self.thread_pool.spawn(move || {
263 let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
264 let _ = txc.send(res);
265 });
266 }
267 }
268
269 if !buf.is_empty() {
270 let chunk = std::mem::take(&mut buf);
271 let txc = tx.clone();
272 let tmp = tmp_dir_path.clone();
273 let cmp_c = cmp;
274
275 self.thread_pool.spawn(move || {
276 let res = create_chunk_from_parts(chunk, cmp_c, &tmp, compression);
277 let _ = txc.send(res);
278 });
279 }
280
281 drop(tx);
283
284 let mut external_chunks = Vec::new();
286 for res in rx.iter() {
287 external_chunks.push(res?);
288 }
289
290 Ok(BinaryHeapMerger::new(
291 num_items.load(AOrd::Relaxed),
292 external_chunks,
293 cmp,
294 ))
295 }
296
297 fn create_chunk<T, F>(
298 &self,
299 mut buffer: Vec<T>,
300 compare: F,
301 ) -> Result<ExternalChunk<T>, SortError>
302 where
303 T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send,
304 T::Archived: Deserialize<T, DiskDeserializer>,
305 F: Fn(&T, &T) -> Ordering + Sync + Send,
306 {
307 self.thread_pool.install(|| {
308 buffer.par_sort_unstable_by(compare);
309 });
310
311 let tmp_file = tempfile::tempfile_in(&self.tmp_dir).unwrap();
312 let external_chunk =
313 ExternalChunk::new(tmp_file, buffer, self.compression).map_err(|err| match err {
314 ExternalChunkError::IO(err) => SortError::IO(err),
315 ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
316 ExternalChunkError::DecodeError(err) => SortError::DeserializationError(err),
317 })?;
318
319 return Ok(external_chunk);
320 }
321}
322
323fn create_chunk_from_parts<T, F>(
325 mut buffer: Vec<T>,
326 compare: F,
327 tmp_dir: &std::path::Path,
328 compression: u32,
329) -> Result<ExternalChunk<T>, SortError>
330where
331 T: Archive + for<'a> Serialize<DiskSerializer<'a>> + Send + 'static,
332 T::Archived: Deserialize<T, DiskDeserializer>,
333 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy + 'static,
334{
335 buffer.sort_unstable_by(compare);
336 let tmp_file = tempfile::tempfile_in(tmp_dir).map_err(SortError::IO)?;
337 ExternalChunk::new(tmp_file, buffer, compression).map_err(|err| match err {
338 ExternalChunkError::IO(e) => SortError::IO(e),
339 ExternalChunkError::EncodeError(e) => SortError::SerializationError(e),
340 ExternalChunkError::DecodeError(e) => SortError::DeserializationError(e),
341 })
342}
343
344fn _init_tmp_directory(tmp_path: Option<&Path>) -> io::Result<tempfile::TempDir> {
345 if let Some(tmp_path) = tmp_path {
346 tempfile::tempdir_in(tmp_path)
347 } else {
348 tempfile::tempdir()
349 }
350}
351
352fn _init_thread_pool(threads_number: Option<usize>) -> io::Result<rayon::ThreadPool> {
353 let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
354 if let Some(threads_number) = threads_number {
355 thread_pool_builder = thread_pool_builder.num_threads(threads_number);
356 }
357 thread_pool_builder
358 .build()
359 .map_err(|x| io::Error::new(io::ErrorKind::Other, x))
360}
361
#[cfg(test)]
mod test {
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder};

    /// Sorts a shuffled `0..100` with both the blocking and the pool-offloaded
    /// entry points, in ascending and descending order.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_external_sorter(#[case] reversed: bool) {
        let input_sorted = 0..100;

        let mut input: Vec<i32> = Vec::from_iter(input_sorted.clone());
        input.shuffle(&mut rand::thread_rng());

        // A chunk size far below the input length forces several chunks
        // (twelve full ones plus a partial one), so the merge step and the
        // trailing-chunk path are exercised rather than a single in-memory
        // sort.
        let sorter: ExternalSorter = ExternalSorterBuilder::new()
            .with_chunk_size(8)
            .num_threads(2)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        let compare = if reversed {
            |a: &i32, b: &i32| a.cmp(b).reverse()
        } else {
            |a: &i32, b: &i32| a.cmp(b)
        };

        let expected_result = if reversed {
            Vec::from_iter(input_sorted.clone().rev())
        } else {
            Vec::from_iter(input_sorted.clone())
        };

        let result = sorter.sort_by(input.clone(), compare).unwrap();
        assert_eq!(
            result.collect::<Result<Vec<_>, _>>().unwrap(),
            expected_result
        );

        let result = sorter.sort_by_async(input, compare).unwrap();
        assert_eq!(
            result.collect::<Result<Vec<_>, _>>().unwrap(),
            expected_result
        );
    }
}