1use log;
4use std::cmp::Ordering;
5use std::error::Error;
6use std::fmt;
7use std::fmt::{Debug, Display};
8use std::io;
9use std::marker::PhantomData;
10use std::path::Path;
11
12use crate::chunk::{ExternalChunk, ExternalChunkError, RmpExternalChunk};
13use crate::merger::BinaryHeapMerger;
14use crate::{ChunkBuffer, ChunkBufferBuilder, LimitedBufferBuilder};
15
16#[derive(Debug)]
18pub enum SortError<S: Error, D: Error, I: Error> {
19 TempDir(io::Error),
21 ThreadPoolBuildError(rayon::ThreadPoolBuildError),
23 IO(io::Error),
25 SerializationError(S),
27 DeserializationError(D),
29 InputError(I),
31}
32
33impl<S, D, I> Error for SortError<S, D, I>
34where
35 S: Error + 'static,
36 D: Error + 'static,
37 I: Error + 'static,
38{
39 fn source(&self) -> Option<&(dyn Error + 'static)> {
40 Some(match &self {
41 SortError::TempDir(err) => err,
42 SortError::ThreadPoolBuildError(err) => err,
43 SortError::IO(err) => err,
44 SortError::SerializationError(err) => err,
45 SortError::DeserializationError(err) => err,
46 SortError::InputError(err) => err,
47 })
48 }
49}
50
51impl<S: Error, D: Error, I: Error> Display for SortError<S, D, I> {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 match &self {
54 SortError::TempDir(err) => write!(f, "temporary directory or file not created: {}", err),
55 SortError::ThreadPoolBuildError(err) => write!(f, "thread pool initialization failed: {}", err),
56 SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
57 SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
58 SortError::DeserializationError(err) => write!(f, "data deserialization error: {}", err),
59 SortError::InputError(err) => write!(f, "input data stream error: {}", err),
60 }
61 }
62}
63
64#[derive(Clone)]
66pub struct ExternalSorterBuilder<T, E, B = LimitedBufferBuilder, C = RmpExternalChunk<T>>
67where
68 T: Send,
69 E: Error,
70 B: ChunkBufferBuilder<T>,
71 C: ExternalChunk<T>,
72{
73 threads_number: Option<usize>,
75 tmp_dir: Option<Box<Path>>,
77 rw_buf_size: Option<usize>,
79 buffer_builder: B,
81
82 external_chunk_type: PhantomData<C>,
84 item_type: PhantomData<T>,
86 input_error_type: PhantomData<E>,
88}
89
90impl<T, E, B, C> ExternalSorterBuilder<T, E, B, C>
91where
92 T: Send,
93 E: Error,
94 B: ChunkBufferBuilder<T>,
95 C: ExternalChunk<T>,
96{
97 pub fn new() -> Self {
99 ExternalSorterBuilder::default()
100 }
101
102 pub fn build(
104 self,
105 ) -> Result<ExternalSorter<T, E, B, C>, SortError<C::SerializationError, C::DeserializationError, E>> {
106 ExternalSorter::new(
107 self.threads_number,
108 self.tmp_dir.as_deref(),
109 self.buffer_builder,
110 self.rw_buf_size,
111 )
112 }
113
114 pub fn with_threads_number(mut self, threads_number: usize) -> ExternalSorterBuilder<T, E, B, C> {
116 self.threads_number = Some(threads_number);
117 return self;
118 }
119
120 pub fn with_tmp_dir(mut self, path: &Path) -> ExternalSorterBuilder<T, E, B, C> {
122 self.tmp_dir = Some(path.into());
123 return self;
124 }
125
126 pub fn with_buffer(mut self, buffer_builder: B) -> ExternalSorterBuilder<T, E, B, C> {
128 self.buffer_builder = buffer_builder;
129 return self;
130 }
131
132 pub fn with_rw_buf_size(mut self, buf_size: usize) -> ExternalSorterBuilder<T, E, B, C> {
134 self.rw_buf_size = Some(buf_size);
135 return self;
136 }
137}
138
139impl<T, E, B, C> Default for ExternalSorterBuilder<T, E, B, C>
140where
141 T: Send,
142 E: Error,
143 B: ChunkBufferBuilder<T>,
144 C: ExternalChunk<T>,
145{
146 fn default() -> Self {
147 ExternalSorterBuilder {
148 threads_number: None,
149 tmp_dir: None,
150 rw_buf_size: None,
151 buffer_builder: B::default(),
152 external_chunk_type: PhantomData,
153 item_type: PhantomData,
154 input_error_type: PhantomData,
155 }
156 }
157}
158
159pub struct ExternalSorter<T, E, B = LimitedBufferBuilder, C = RmpExternalChunk<T>>
161where
162 T: Send,
163 E: Error,
164 B: ChunkBufferBuilder<T>,
165 C: ExternalChunk<T>,
166{
167 thread_pool: rayon::ThreadPool,
169 tmp_dir: tempfile::TempDir,
171 buffer_builder: B,
173 rw_buf_size: Option<usize>,
175
176 external_chunk_type: PhantomData<C>,
178 item_type: PhantomData<T>,
180 input_error_type: PhantomData<E>,
182}
183
184impl<T, E, B, C> ExternalSorter<T, E, B, C>
185where
186 T: Send,
187 E: Error,
188 B: ChunkBufferBuilder<T>,
189 C: ExternalChunk<T>,
190{
191 pub fn new(
201 threads_number: Option<usize>,
202 tmp_path: Option<&Path>,
203 buffer_builder: B,
204 rw_buf_size: Option<usize>,
205 ) -> Result<Self, SortError<C::SerializationError, C::DeserializationError, E>> {
206 return Ok(ExternalSorter {
207 rw_buf_size,
208 buffer_builder,
209 thread_pool: Self::init_thread_pool(threads_number)?,
210 tmp_dir: Self::init_tmp_directory(tmp_path)?,
211 external_chunk_type: PhantomData,
212 item_type: PhantomData,
213 input_error_type: PhantomData,
214 });
215 }
216
217 fn init_thread_pool(
218 threads_number: Option<usize>,
219 ) -> Result<rayon::ThreadPool, SortError<C::SerializationError, C::DeserializationError, E>> {
220 let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
221
222 if let Some(threads_number) = threads_number {
223 log::info!("initializing thread-pool (threads: {})", threads_number);
224 thread_pool_builder = thread_pool_builder.num_threads(threads_number);
225 } else {
226 log::info!("initializing thread-pool (threads: default)");
227 }
228 let thread_pool = thread_pool_builder
229 .build()
230 .map_err(|err| SortError::ThreadPoolBuildError(err))?;
231
232 return Ok(thread_pool);
233 }
234
235 fn init_tmp_directory(
236 tmp_path: Option<&Path>,
237 ) -> Result<tempfile::TempDir, SortError<C::SerializationError, C::DeserializationError, E>> {
238 let tmp_dir = if let Some(tmp_path) = tmp_path {
239 tempfile::tempdir_in(tmp_path)
240 } else {
241 tempfile::tempdir()
242 }
243 .map_err(|err| SortError::TempDir(err))?;
244
245 log::info!("using {} as a temporary directory", tmp_dir.path().display());
246
247 return Ok(tmp_dir);
248 }
249
250 pub fn sort<I>(
256 &self,
257 input: I,
258 ) -> Result<
259 BinaryHeapMerger<T, C::DeserializationError, impl Fn(&T, &T) -> Ordering + Copy, C>,
260 SortError<C::SerializationError, C::DeserializationError, E>,
261 >
262 where
263 T: Ord,
264 I: IntoIterator<Item = Result<T, E>>,
265 {
266 self.sort_by(input, T::cmp)
267 }
268
269 pub fn sort_by<I, F>(
276 &self,
277 input: I,
278 compare: F,
279 ) -> Result<
280 BinaryHeapMerger<T, C::DeserializationError, F, C>,
281 SortError<C::SerializationError, C::DeserializationError, E>,
282 >
283 where
284 I: IntoIterator<Item = Result<T, E>>,
285 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
286 {
287 let mut chunk_buf = self.buffer_builder.build();
288 let mut external_chunks = Vec::new();
289
290 for item in input.into_iter() {
291 match item {
292 Ok(item) => chunk_buf.push(item),
293 Err(err) => return Err(SortError::InputError(err)),
294 }
295
296 if chunk_buf.is_full() {
297 external_chunks.push(self.create_chunk(chunk_buf, compare)?);
298 chunk_buf = self.buffer_builder.build();
299 }
300 }
301
302 if chunk_buf.len() > 0 {
303 external_chunks.push(self.create_chunk(chunk_buf, compare)?);
304 }
305
306 log::debug!("external sort preparation done");
307
308 return Ok(BinaryHeapMerger::new(external_chunks, compare));
309 }
310
311 fn create_chunk<F>(
312 &self,
313 mut buffer: impl ChunkBuffer<T>,
314 compare: F,
315 ) -> Result<C, SortError<C::SerializationError, C::DeserializationError, E>>
316 where
317 F: Fn(&T, &T) -> Ordering + Sync + Send,
318 {
319 log::debug!("sorting chunk data ...");
320 self.thread_pool.install(|| {
321 buffer.par_sort_by(compare);
322 });
323
324 log::debug!("saving chunk data");
325 let external_chunk =
326 ExternalChunk::build(&self.tmp_dir, buffer, self.rw_buf_size).map_err(|err| match err {
327 ExternalChunkError::IO(err) => SortError::IO(err),
328 ExternalChunkError::SerializationError(err) => SortError::SerializationError(err),
329 })?;
330
331 return Ok(external_chunk);
332 }
333}
334
#[cfg(test)]
mod test {
    use std::io;
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder, LimitedBufferBuilder};

    /// Shuffles 0..100, sorts it through 8-item chunks on a 2-thread pool and
    /// checks the merged output, once ascending and once descending.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_external_sorter(#[case] reversed: bool) {
        let mut input: Vec<Result<i32, io::Error>> = (0..100).map(Ok).collect();
        input.shuffle(&mut rand::thread_rng());

        let sorter: ExternalSorter<i32, _> = ExternalSorterBuilder::new()
            .with_buffer(LimitedBufferBuilder::new(8, true))
            .with_threads_number(2)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        let ascending = |a: &i32, b: &i32| a.cmp(b);
        let descending = |a: &i32, b: &i32| b.cmp(a);
        let compare: fn(&i32, &i32) -> std::cmp::Ordering =
            if reversed { descending } else { ascending };

        let merged: Vec<i32> = sorter
            .sort_by(input, compare)
            .unwrap()
            .collect::<Result<_, _>>()
            .unwrap();

        let mut expected: Vec<i32> = (0..100).collect();
        if reversed {
            expected.reverse();
        }

        assert_eq!(merged, expected)
    }
}