use std::{
num::NonZeroUsize,
sync::mpsc::{Receiver, SyncSender},
thread::ScopedJoinHandle,
};
use crate::{orderer::Orderer, tape::TapeCollection, ExtsortConfig};
use super::*;
/// Buffer cleaner that sorts buffers and adds them to a `TapeCollection` on a
/// dedicated worker thread.
///
/// The value itself only stores its inputs; the worker thread is spawned by
/// `run`.
pub struct MultithreadedBufferCleaner<O, F> {
/// Sort configuration. `run` reads the per-buffer item count, the compression
/// choice and the temp-file folder from it.
config: ExtsortConfig,
/// Passed by reference to `buffer_sort` on every call and handed back inside
/// `FinalizeContents` when the worker finishes.
orderer: O,
/// Callback invoked as `buffer_sort(&orderer, &mut buf)` on each buffer before
/// the buffer is added as a run.
buffer_sort: F,
}
/// Caller-side endpoint for the worker thread started by
/// `MultithreadedBufferCleaner::run`.
pub struct MultithreadedBufferCleanerHandle<'scope, T, O, F> {
/// Replies from the worker: either a buffer handed back to the caller, or the
/// I/O error produced by a failed `add_run`.
rx: Receiver<io::Result<Vec<T>>>,
/// Commands sent to the worker.
tx: SyncSender<BufferCleanerCommand<T>>,
/// Join handle of the worker thread; joining yields the `FinalizeContents`
/// built after the worker loop ends.
finalize_handle: ScopedJoinHandle<'scope, FinalizeContents<T, O, F>>,
/// Item count obtained from `ExtsortConfig::get_num_items_for::<T>()`;
/// `get_buffer` allocates half of it.
buffer_capacity: NonZeroUsize,
}
/// Messages sent from the handle to the worker thread.
enum BufferCleanerCommand<T> {
/// Sort the contained buffer and add it to the tape collection as a run.
CleanBuffer(Vec<T>),
/// Stop the worker loop so that the tapes can be produced.
Finalize,
}
impl<O, F> MultithreadedBufferCleaner<O, F>
where
    O: Send,
{
    /// Stores the configuration, the orderer and the sort callback.
    ///
    /// No thread is started here; that happens in [`run`](Self::run).
    pub fn new(config: ExtsortConfig, orderer: O, buffer_sort: F) -> Self {
        Self {
            config,
            orderer,
            buffer_sort,
        }
    }

    /// Spawns the worker thread inside a thread scope, passes a handle for it
    /// to `func`, and returns whatever `func` returns.
    ///
    /// The worker receives commands over a bounded channel (capacity 1). For
    /// every `CleanBuffer` it first sends back the buffer it processed
    /// previously, then runs `buffer_sort` on the new buffer and adds it to
    /// the tape collection. If adding the run fails, the error is sent to the
    /// handle and the worker stops.
    ///
    /// The worker also stops when the command channel is disconnected, i.e.
    /// when `func` dropped the handle without calling `finalize`. In that case
    /// the scope joins the worker normally instead of observing a panic.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread cannot be spawned. A panic inside the
    /// worker (for example from `buffer_sort`) is propagated by
    /// `std::thread::scope` unless the thread was joined by the handle.
    pub fn run<Fo, T, R>(self, func: Fo) -> R
    where
        Fo: FnOnce(MultithreadedBufferCleanerHandle<T, O, F>) -> R,
        F: FnMut(&O, &mut [T]) + Send,
        T: Send,
    {
        std::thread::scope(move |scope| {
            let config = self.config;
            let max_buffer_size_nonzero = config.get_num_items_for::<T>();
            let max_buffer_size = max_buffer_size_nonzero.get();
            let compression_choice = config.compression_choice();
            let tape_collection = TapeCollection::<T>::new(
                config.temp_file_folder,
                NonZeroUsize::new(256).unwrap(),
                compression_choice,
            );

            // worker -> caller: recycled buffers and I/O errors.
            let (worker_tx, rx) = std::sync::mpsc::sync_channel(1);
            // caller -> worker: commands.
            let (tx, worker_rx) = std::sync::mpsc::sync_channel(1);

            let finalize_handle = std::thread::Builder::new()
                .name("Sort-Buffer-Writer".to_owned())
                .spawn_scoped(scope, move || {
                    let mut cleaned_buffer = Vec::with_capacity(max_buffer_size / 2);
                    let orderer = self.orderer;
                    let mut tape_collection = tape_collection;
                    let mut buffer_sort = self.buffer_sort;
                    loop {
                        match worker_rx.recv() {
                            Ok(BufferCleanerCommand::CleanBuffer(mut buf)) => {
                                // Reply before doing the work: the caller only
                                // waits for this message, not for the sort and
                                // the write that follow.
                                worker_tx.send(Ok(cleaned_buffer)).ok();
                                (buffer_sort)(&orderer, &mut buf);
                                if let Err(e) = tape_collection.add_run(&mut buf) {
                                    worker_tx.send(Err(e)).ok();
                                    break;
                                }
                                cleaned_buffer = buf;
                            }
                            // A disconnected command channel means the handle
                            // was dropped without `finalize`; shut down the
                            // same way instead of panicking, because a worker
                            // panic would make `thread::scope` panic as well.
                            Ok(BufferCleanerCommand::Finalize) | Err(_) => {
                                drop(cleaned_buffer);
                                break;
                            }
                        }
                    }
                    let tapes = tape_collection.into_tapes(max_buffer_size_nonzero);
                    FinalizeContents {
                        tapes,
                        orderer,
                        sort_func: buffer_sort,
                    }
                })
                .expect("failed to spawn the Sort-Buffer-Writer thread");

            let handle = MultithreadedBufferCleanerHandle {
                rx,
                tx,
                finalize_handle,
                buffer_capacity: max_buffer_size_nonzero,
            };
            func(handle)
        })
    }
}
impl<T, O, F> MultithreadedBufferCleanerHandle<'_, T, O, F> {
    /// Sends `command` to the worker thread.
    ///
    /// # Errors
    ///
    /// The send fails only when the worker has already dropped its receiver.
    /// If the worker stopped because adding a run failed, it queued that I/O
    /// error on the reply channel before exiting; that error is returned so
    /// the real cause is not hidden. Otherwise a `BrokenPipe` error is
    /// returned.
    fn send(&mut self, command: BufferCleanerCommand<T>) -> io::Result<()> {
        if self.tx.send(command).is_ok() {
            return Ok(());
        }
        // The worker is gone. Drain whatever it left behind and surface the
        // first error; leftover buffers are simply dropped.
        while let Ok(reply) = self.rx.try_recv() {
            reply?;
        }
        Err(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "the writer thread exited unexpectedly",
        ))
    }
}
impl<T, O, F> BufferCleaner<T, O, F> for MultithreadedBufferCleanerHandle<'_, T, O, F>
where
    O: Orderer<T> + Send,
    T: Send,
    F: FnMut(&O, &mut [T]),
{
    /// Hands the contents of `buffer` to the worker and replaces `buffer` with
    /// the vector the worker sends back.
    ///
    /// `buffer` is left empty (via `mem::take`) if an error is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the command cannot be sent, if the reply channel is
    /// disconnected (`BrokenPipe`), or if the worker replied with an I/O error.
    fn clean_buffer(&mut self, buffer: &mut Vec<T>) -> io::Result<()> {
        let outgoing = core::mem::take(buffer);
        self.send(BufferCleanerCommand::CleanBuffer(outgoing))?;
        match self.rx.recv() {
            Ok(reply) => {
                *buffer = reply?;
                Ok(())
            }
            Err(disconnected) => Err(io::Error::new(io::ErrorKind::BrokenPipe, disconnected)),
        }
    }

    /// Allocates an empty vector with room for half of the configured item count.
    fn get_buffer(&mut self) -> Vec<T> {
        let half_capacity = self.buffer_capacity.get() / 2;
        Vec::with_capacity(half_capacity)
    }

    /// Tells the worker to stop, drains the reply channel and joins the worker.
    ///
    /// # Errors
    ///
    /// Returns an error if the `Finalize` command cannot be sent or if the
    /// worker reported an I/O error before it exited.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread panicked.
    fn finalize(mut self) -> io::Result<FinalizeContents<T, O, F>> {
        self.send(BufferCleanerCommand::Finalize)?;
        // The channel disconnects once the worker has finished; until then,
        // discard returned buffers and stop at the first reported error.
        loop {
            match self.rx.recv() {
                Ok(Ok(leftover)) => drop(leftover),
                Ok(Err(e)) => return Err(e),
                Err(_) => break,
            }
        }
        let contents = self.finalize_handle.join().unwrap();
        Ok(contents)
    }
}