1use crate::chunk::*;
2use crate::merge::*;
3
4use bytemuck::{bytes_of, Pod};
5use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
6use std::collections::BinaryHeap;
7use std::fs::File;
8use std::io::{self, BufWriter, Write};
9use std::path::{Path, PathBuf};
10use std::thread;
11use tempfile::tempfile_in;
12
13const ONE_MIB: usize = 1 << 20;
14
/// Streams unsorted `Chunk`s through a pool of sorter threads and a k-way
/// merge stage, finally producing one sorted key file and one value file.
pub struct SortingPipeline<K, V> {
    // Producer side of the bounded channel feeding the sorter threads;
    // dropping it (in `finish`) is what signals the pipeline to drain.
    unsorted_chunk_tx: Sender<Chunk<K, V>>,
    // Joined by `finish` to wait for completion and surface any I/O error
    // hit while sorting or merging.
    merge_initiator_thread_handle: thread::JoinHandle<Result<(), io::Error>>,
}
24
25impl<K, V> SortingPipeline<K, V>
26where
27 K: Ord + Pod + Send,
28 V: Pod + Send,
29{
30 pub fn new(
34 max_sort_concurrency: usize,
35 max_merge_concurrency: usize,
36 merge_k: usize,
37 tmp_dir_path: impl AsRef<Path>,
38 output_key_path: impl AsRef<Path>,
39 output_value_path: impl AsRef<Path>,
40 ) -> Self {
41 assert!(max_sort_concurrency > 0);
42 assert!(max_merge_concurrency > 0);
43
44 let tmp_dir_path = tmp_dir_path.as_ref().to_owned();
45 let output_key_path = output_key_path.as_ref().to_owned();
46 let output_value_path = output_value_path.as_ref().to_owned();
47
48 let (unsorted_chunk_tx, unsorted_chunk_rx) = bounded(1);
51 let (sorted_chunk_tx, sorted_chunk_rx) = bounded(1024);
54
55 for _ in 0..max_sort_concurrency {
56 let this_tmp_dir_path = tmp_dir_path.clone();
57 let this_unsorted_chunk_rx = unsorted_chunk_rx.clone();
58 let this_sorted_chunk_tx = sorted_chunk_tx.clone();
59 thread::spawn(move || {
60 run_sorter(
61 &this_tmp_dir_path,
62 this_unsorted_chunk_rx,
63 this_sorted_chunk_tx,
64 )
65 });
66 }
67
68 let merge_initiator_thread_handle = thread::spawn(move || {
70 let result = run_merge_initiator::<K, V>(
71 tmp_dir_path,
72 &output_key_path,
73 &output_value_path,
74 max_merge_concurrency,
75 merge_k,
76 sorted_chunk_rx,
77 );
78 if result.is_err() {
79 log::error!("Merge initiator exited with: {result:?}");
80 eprintln!("Merge initiator exited with: {result:?}");
81 }
82 result
83 });
84
85 Self {
86 unsorted_chunk_tx,
87 merge_initiator_thread_handle,
88 }
89 }
90
91 pub fn submit_unsorted_chunk(&self, chunk: Chunk<K, V>) {
95 if chunk.is_empty() {
96 return;
97 }
98 self.unsorted_chunk_tx.send(chunk).unwrap();
99 }
100
101 pub fn finish(self) -> Result<(), io::Error> {
103 drop(self.unsorted_chunk_tx);
105
106 self.merge_initiator_thread_handle.join().unwrap()
107 }
108}
109
110fn run_sorter<K, V>(
111 tmp_dir_path: &Path,
112 unsorted_chunk_rx: Receiver<Chunk<K, V>>,
113 sorted_chunk_tx: Sender<Result<SortedChunkFiles, io::Error>>,
114) where
115 K: Ord + Pod,
116 V: Pod,
117{
118 while let Ok(unsorted_chunk) = unsorted_chunk_rx.recv() {
119 sorted_chunk_tx
120 .send(sort_and_persist_chunk(tmp_dir_path, unsorted_chunk))
121 .unwrap();
122 }
123}
124
125fn run_merger<K, V>(
126 tmp_dir_path: &Path,
127 chunk_pair_rx: Receiver<Vec<SortedChunkFiles>>,
128 merged_chunk_tx: Sender<Result<SortedChunkFiles, io::Error>>,
129) where
130 K: Ord + Pod,
131 V: Pod,
132{
133 while let Ok(chunks) = chunk_pair_rx.recv() {
134 merged_chunk_tx
135 .send(merge_chunks_into_tempfiles::<K, V>(chunks, tmp_dir_path))
136 .unwrap();
137 }
138}
139
140fn sort_and_persist_chunk<K, V>(
141 tmp_dir_path: &Path,
142 mut chunk: Chunk<K, V>,
143) -> Result<SortedChunkFiles, io::Error>
144where
145 K: Ord + Pod,
146 V: Pod,
147{
148 let sort_span = tracing::info_span!("sort_chunk");
149 sort_span.in_scope(|| chunk.sort());
150
151 let num_entries = chunk.len();
152
153 let persist_span = tracing::info_span!("persist_sorted_chunk");
155 let _gaurd = persist_span.enter();
156 let mut key_writer = BufWriter::with_capacity(ONE_MIB, tempfile_in(tmp_dir_path)?);
157 let mut value_writer = BufWriter::with_capacity(ONE_MIB, tempfile_in(tmp_dir_path)?);
158 for (k, v) in chunk.entries.into_iter() {
159 key_writer.write_all(bytes_of(&k))?;
160 value_writer.write_all(bytes_of(&v))?;
161 }
162
163 SortedChunkFiles::new(
164 key_writer.into_inner()?,
165 value_writer.into_inner()?,
166 num_entries,
167 )
168}
169
/// Coordinates the merge stage: receives sorted chunk files from the sorter
/// threads, dispatches k-way merges (of width `merge_k`) to a pool of
/// `max_merge_concurrency` merger threads, and finally writes the fully
/// merged keys/values to `output_key_path` / `output_value_path`.
///
/// Returns the first I/O error reported by any sort or merge, if any.
fn run_merge_initiator<K, V>(
    tmp_dir_path: PathBuf,
    output_key_path: &Path,
    output_value_path: &Path,
    max_merge_concurrency: usize,
    merge_k: usize,
    sorted_chunk_rx: Receiver<Result<SortedChunkFiles, io::Error>>,
) -> Result<(), io::Error>
where
    K: Ord + Pod,
    V: Pod,
{
    // Capacity 1: the initiator blocks handing out batches rather than
    // queueing unbounded sets of open temp files.
    let (chunk_pair_tx, chunk_pair_rx) = bounded(1);
    let (merged_chunk_tx, merged_chunk_rx) = unbounded();

    // Merger worker pool; workers exit when `chunk_pair_tx` is dropped
    // (at the end of this function).
    for _ in 0..max_merge_concurrency {
        let this_tmp_dir_path = tmp_dir_path.clone();
        let this_chunk_pair_rx = chunk_pair_rx.clone();
        let this_merged_chunk_tx = merged_chunk_tx.clone();
        thread::spawn(move || {
            run_merger::<K, V>(&this_tmp_dir_path, this_chunk_pair_rx, this_merged_chunk_tx)
        });
    }

    // Chunks awaiting a merge slot. NOTE(review): popping order relies on
    // `SortedChunkFiles: Ord` — confirm its `Ord` impl matches the intended
    // merge-scheduling policy (e.g. smallest chunks first).
    let mut merge_queue = BinaryHeap::new();

    let mut num_sorted_chunks_received = 0;
    let mut num_merges_started = 0;
    let mut num_merges_completed = 0;
    // Merges dispatched to workers but not yet reported back.
    macro_rules! num_pending_merges {
        () => {
            num_merges_started - num_merges_completed
        };
    }

    // Phase 1: while sorters are still producing, interleave receiving
    // sorted chunks with dispatching merges and collecting finished ones.
    while let Ok(sorted_chunk_result) = sorted_chunk_rx.recv() {
        num_sorted_chunks_received += 1;
        log::debug!("# sorted chunks received = {num_sorted_chunks_received}");

        // `?` propagates a failed sort; workers are left to exit on their
        // own once the channels drop.
        merge_queue.push(sorted_chunk_result?);

        // Dispatch while a worker slot is free and a full batch is queued.
        while num_pending_merges!() < max_merge_concurrency && merge_queue.len() >= merge_k {
            let chunks: Vec<_> = (0..merge_k).filter_map(|_| merge_queue.pop()).collect();
            chunk_pair_tx.send(chunks).unwrap();
            num_merges_started += 1;
        }

        log::info!(
            "Merge queue length = {}, # pending merges = {}",
            merge_queue.len(),
            num_pending_merges!()
        );

        // Non-blocking drain of any merges that finished meanwhile.
        while let Ok(merged_chunk) = merged_chunk_rx.try_recv() {
            num_merges_completed += 1;
            merge_queue.push(merged_chunk?);
        }
    }
    log::info!("All chunks sorted, only merge work remains");
    log::info!(
        "Merge queue length = {}, # pending merges = {}",
        merge_queue.len(),
        num_pending_merges!()
    );

    // Phase 2: keep merging until at most `merge_k` chunks remain in total
    // (queued + in flight); that final batch is merged straight to output.
    while merge_queue.len() + num_pending_merges!() > merge_k {
        while merge_queue.len() >= merge_k {
            let chunks: Vec<_> = (0..merge_k).filter_map(|_| merge_queue.pop()).collect();
            chunk_pair_tx.send(chunks).unwrap();
            num_merges_started += 1;
        }

        log::info!(
            "Merge queue length = {}, # pending merges = {}",
            merge_queue.len(),
            num_pending_merges!()
        );

        // Block for at least one completion so the outer loop progresses.
        if num_pending_merges!() > 0 {
            let merged_chunk_result = merged_chunk_rx.recv().unwrap();
            num_merges_completed += 1;
            merge_queue.push(merged_chunk_result?);
        }
    }

    // Phase 3: wait out the remaining in-flight merges.
    while num_pending_merges!() > 0 {
        let merged_chunk_result = merged_chunk_rx.recv().unwrap();
        num_merges_completed += 1;
        merge_queue.push(merged_chunk_result?);
    }

    let mut output_key_file = File::create(output_key_path)?;
    let mut output_value_file = File::create(output_value_path)?;

    // No input at all: the freshly created (empty) output files remain.
    if merge_queue.is_empty() {
        return Ok(());
    }

    // Exactly one chunk: no merge needed, copy its files to the outputs.
    if merge_queue.len() == 1 {
        log::info!("Only one chunk: just copying to output location");
        let mut final_chunk = merge_queue.pop().unwrap();
        io::copy(&mut final_chunk.key_file, &mut output_key_file)?;
        io::copy(&mut final_chunk.value_file, &mut output_value_file)?;
        return Ok(());
    }

    // Final merge of the remaining 2..=merge_k chunks directly into the
    // output files; `filter_map` tolerates a queue shorter than `merge_k`.
    assert!(merge_queue.len() <= merge_k);
    let chunks: Vec<_> = (0..merge_k).filter_map(|_| merge_queue.pop()).collect();
    let _ = merge_chunks::<K, V>(chunks, output_key_file, output_value_file)?;
    num_merges_completed += 1;

    log::info!("Done merging! Performed {num_merges_completed} merge(s) total");
    Ok(())
}