1use crate::extsort::merger::BinaryHeapMerger;
2use crate::extsort::chunk::{ExternalChunk, ExternalChunkError};
3
4use rayon::prelude::*;
5use std::{fmt::{self, Display}, io};
6use rayon;
7use std::error::Error;
8use bincode::{self, Decode, Encode};
9use std::{
10 cmp::Ordering,
11 path::{Path, PathBuf},
12};
13
14#[derive(Debug)]
16pub enum SortError {
17 TempDir(io::Error),
19 ThreadPoolBuildError(rayon::ThreadPoolBuildError),
21 IO(io::Error),
23 SerializationError(bincode::error::EncodeError),
25 DeserializationError(bincode::error::DecodeError),
27}
28
29impl Error for SortError
30{
31 fn source(&self) -> Option<&(dyn Error + 'static)> {
32 Some(match &self {
33 SortError::TempDir(err) => err,
34 SortError::ThreadPoolBuildError(err) => err,
35 SortError::IO(err) => err,
36 SortError::SerializationError(err) => err,
37 SortError::DeserializationError(err) => err,
38 })
39 }
40}
41
42impl Display for SortError {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match &self {
45 SortError::TempDir(err) => write!(f, "temporary directory or file not created: {}", err),
46 SortError::ThreadPoolBuildError(err) => write!(f, "thread pool initialization failed: {}", err),
47 SortError::IO(err) => write!(f, "I/O operation failed: {}", err),
48 SortError::SerializationError(err) => write!(f, "data serialization error: {}", err),
49 SortError::DeserializationError(err) => write!(f, "data deserialization error: {}", err),
50 }
51 }
52}
53
/// Collects the configuration used to construct an `ExternalSorter`.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and reused;
/// all fields are plain standard-library values.
#[derive(Debug, Clone)]
pub struct ExternalSorterBuilder {
    /// Number of items buffered in memory before the buffer is sorted and
    /// turned into an external chunk.
    chunk_size: usize,
    /// Directory in which the sorter's temporary directory is created.
    /// `None` falls back to `tempfile::tempdir()`.
    tmp_dir: Option<PathBuf>,
    /// Number of threads for the sorting thread pool.
    /// `None` leaves the rayon thread pool builder at its default.
    num_threads: Option<usize>,
    /// Optional compression level forwarded unchanged to each external chunk.
    /// NOTE(review): the meaning of the level is defined by the chunk module.
    compression: Option<u32>,
}
63
64impl ExternalSorterBuilder {
65 pub fn new() -> Self {
66 Self {
67 chunk_size: 50000000,
68 tmp_dir: None,
69 num_threads: None,
70 compression: None,
71 }
72 }
73
74 pub fn with_chunk_size(mut self, size: usize) -> Self {
83 self.chunk_size = size;
84 self
85 }
86
87 pub fn with_tmp_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
90 self.tmp_dir = Some(path.as_ref().to_path_buf());
91 self
92 }
93
94 pub fn with_compression(mut self, level: u32) -> Self {
97 self.compression = Some(level);
98 self
99 }
100
101 pub fn num_threads(mut self, num_threads: usize) -> Self {
106 self.num_threads = Some(num_threads);
107 self
108 }
109
110 pub fn build(self) -> io::Result<ExternalSorter> {
111 Ok(ExternalSorter {
112 chunk_size: self.chunk_size,
113 compression: self.compression,
114 tmp_dir: _init_tmp_directory(self.tmp_dir.as_deref())?,
115 thread_pool: _init_thread_pool(self.num_threads)?,
116 })
117 }
118}
119
120pub struct ExternalSorter {
121 chunk_size: usize,
122 compression: Option<u32>,
123 thread_pool: rayon::ThreadPool,
125 tmp_dir: tempfile::TempDir,
127}
128
129impl ExternalSorter {
130 pub fn sort<I, T>(&self, input: I) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
136 where
137 T: Encode + Decode<()> + Send + Ord,
138 I: IntoIterator<Item = T>,
139 {
140 self.sort_by(input, T::cmp)
141 }
142
143 pub fn sort_by<I, T, F>(&self, input: I, cmp: F) -> Result<impl ExactSizeIterator<Item = Result<T, ExternalChunkError>>, SortError>
145 where
146 T: Encode + Decode<()> + Send,
147 I: IntoIterator<Item = T>,
148 F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
149 {
150 let mut chunk_buf = Vec::with_capacity(self.chunk_size);
151 let mut external_chunks = Vec::new();
152 let mut num_items = 0;
153
154 for item in input.into_iter() {
155 num_items += 1;
156 chunk_buf.push(item);
157 if chunk_buf.len() >= self.chunk_size {
158 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
159 chunk_buf = Vec::with_capacity(self.chunk_size);
160 }
161 }
162
163 if chunk_buf.len() > 0 {
164 external_chunks.push(self.create_chunk(chunk_buf, cmp)?);
165 }
166
167 return Ok(BinaryHeapMerger::new(num_items, external_chunks, cmp));
168 }
169
170 fn create_chunk<T, F>(&self, mut buffer: Vec<T>, compare: F) -> Result<ExternalChunk<T>, SortError>
171 where
172 T: Encode + Send,
173 F: Fn(&T, &T) -> Ordering + Sync + Send,
174 {
175 self.thread_pool.install(|| {
176 buffer.par_sort_unstable_by(compare);
177 });
178
179 let external_chunk =
180 ExternalChunk::new(&self.tmp_dir, buffer, self.compression).map_err(|err| match err {
181 ExternalChunkError::IO(err) => SortError::IO(err),
182 ExternalChunkError::EncodeError(err) => SortError::SerializationError(err),
183 ExternalChunkError::DecodeError(err) => SortError::DeserializationError(err),
184 })?;
185
186 return Ok(external_chunk);
187 }
188}
189
190fn _init_tmp_directory(
191 tmp_path: Option<&Path>,
192) -> io::Result<tempfile::TempDir> {
193 if let Some(tmp_path) = tmp_path {
194 tempfile::tempdir_in(tmp_path)
195 } else {
196 tempfile::tempdir()
197 }
198}
199
200fn _init_thread_pool(
201 threads_number: Option<usize>,
202) -> io::Result<rayon::ThreadPool> {
203 let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
204 if let Some(threads_number) = threads_number {
205 thread_pool_builder = thread_pool_builder.num_threads(threads_number);
206 }
207 thread_pool_builder.build().map_err(|x| io::Error::new(io::ErrorKind::Other, x))
208}
209
#[cfg(test)]
mod test {
    use std::path::Path;

    use rand::seq::SliceRandom;
    use rstest::*;

    use super::{ExternalSorter, ExternalSorterBuilder};

    /// Shuffles the numbers 0..100, sorts them with the external sorter in
    /// ascending or descending order, and checks the collected result.
    #[rstest]
    #[case(false)]
    #[case(true)]
    fn test_external_sorter(#[case] reversed: bool) {
        let input_sorted = 0..100;

        let mut input: Vec<i32> = Vec::from_iter(input_sorted.clone());
        input.shuffle(&mut rand::thread_rng());

        let sorter: ExternalSorter = ExternalSorterBuilder::new()
            // Much smaller than the input, so the data is split into several
            // chunks and the merge step is actually exercised. With the
            // default chunk size all 100 items would land in a single chunk.
            .with_chunk_size(8)
            .num_threads(2)
            .with_tmp_dir(Path::new("./"))
            .build()
            .unwrap();

        let compare = if reversed {
            |a: &i32, b: &i32| a.cmp(b).reverse()
        } else {
            |a: &i32, b: &i32| a.cmp(b)
        };

        let result = sorter.sort_by(input, compare).unwrap();

        let actual_result: Result<Vec<i32>, _> = result.collect();
        let actual_result = actual_result.unwrap();
        let expected_result = if reversed {
            Vec::from_iter(input_sorted.clone().rev())
        } else {
            Vec::from_iter(input_sorted.clone())
        };

        assert_eq!(actual_result, expected_result);
    }
}