//! Spill-chunk writer: stages uncompressed record bytes, compresses full BGZF
//! blocks on a shared worker pool, and streams the compressed blocks to disk
//! from a dedicated I/O thread.

use crate::sort::bgzf_io::{StagingBuffer, io_writer_loop};
use crate::sort::keys::RawSortKey;
use crate::sort::worker_pool::{CompressResult, PermitPool, SortWorkerPool};
use anyhow::Result;
use crossbeam_channel::bounded;
use fgumi_bgzf::BGZF_MAX_BLOCK_SIZE;
use std::io::BufWriter;
use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc;
use std::thread::{self, JoinHandle};
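/// Writes `(key, record)` pairs into a BGZF-compressed spill chunk.
///
/// Uncompressed bytes accumulate in a [`StagingBuffer`]; full blocks are
/// compressed on the shared [`SortWorkerPool`] and handed to a dedicated I/O
/// thread that writes them to the chunk file.
///
/// A minimal usage sketch (the `SortWorkerPool::new(workers, readers, level)`
/// call matches the tests in this module; `TemplateKey` stands in for any
/// [`RawSortKey`]):
///
/// ```ignore
/// let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
/// let mut writer = PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)?;
/// writer.write_record(&key, &record_bytes)?;
/// writer.finish()?; // or: let handle = writer.start_finish()?; ... handle.wait()?;
/// ```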
pub struct PooledChunkWriter<K: RawSortKey> {
    /// Staging buffer for uncompressed bytes; `None` once finishing has begun.
    staging: Option<StagingBuffer>,
    /// Scratch buffer reused for key serialization across records.
    key_buf: Vec<u8>,
    /// Join handle for the background I/O writer thread.
    io_handle: Option<JoinHandle<Result<()>>>,
    _phantom: PhantomData<K>,
}
impl<K: RawSortKey> PooledChunkWriter<K> {
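    /// Creates the chunk file at `path` and spawns the background I/O thread.
    ///
    /// # Errors
    /// Returns an error if the file cannot be created.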
pub fn new(pool: Arc<SortWorkerPool>, path: &Path) -> Result<Self> {
let file = std::fs::File::create(path)?;
let writer = BufWriter::with_capacity(256 * 1024, file);
        // Back-pressure: the bounded result channel and the permit pool share
        // one capacity, so at most `reorder_capacity` compressed blocks are in
        // flight between the compression workers and the I/O thread.
        let reorder_capacity = pool.num_workers() * 4;
        let (result_tx, result_rx) = bounded::<CompressResult>(reorder_capacity);
        let buffer_pool = pool.buffer_pool.clone();
        let permit_pool = Arc::new(PermitPool::new(reorder_capacity));
        let pp = Arc::clone(&permit_pool);
        // Dedicated I/O thread: drains compressed blocks from `result_rx` and
        // writes them to the file.
        let io_handle = thread::spawn(move || io_writer_loop(writer, result_rx, buffer_pool, pp));
Ok(Self {
staging: Some(StagingBuffer::new(pool, result_tx, permit_pool)),
key_buf: Vec::new(),
io_handle: Some(io_handle),
_phantom: PhantomData,
})
}
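    /// Appends one `(key, record)` pair to the staging buffer, flushing
    /// completed blocks to the compression workers as they fill.
    ///
    /// # Panics
    /// Panics if called after [`Self::start_finish`].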
#[allow(clippy::cast_possible_truncation)]
pub fn write_record(&mut self, key: &K, record: &[u8]) -> Result<()> {
let staging = self.staging.as_mut().expect("write_record called after start_finish");
        if K::EMBEDDED_IN_RECORD {
            // The key is already embedded in the record bytes, so only a
            // u32 length prefix is written.
            let needed = 4 + record.len();
            if staging.buf().len() + needed > BGZF_MAX_BLOCK_SIZE {
                staging.flush()?;
            }
            staging.buf().extend_from_slice(&(record.len() as u32).to_le_bytes());
            if record.len() > BGZF_MAX_BLOCK_SIZE.saturating_sub(4) {
                // Too large for a single block even after a flush: split the
                // record across blocks.
                staging.write_chunked(record)?;
            } else {
                staging.buf().extend_from_slice(record);
                staging.flush_if_full()?;
            }
        } else {
            // The key is serialized separately: layout is `key | u32 length | record`.
            self.key_buf.clear();
            key.write_to(&mut self.key_buf)?;
            let needed = self.key_buf.len() + 4 + record.len();
            if staging.buf().len() + needed > BGZF_MAX_BLOCK_SIZE {
                staging.flush()?;
            }
            staging.buf().extend_from_slice(&self.key_buf);
            staging.buf().extend_from_slice(&(record.len() as u32).to_le_bytes());
            staging.write_chunked(record)?;
        }
Ok(())
}
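    /// Flushes any remaining staged bytes and blocks until the I/O thread
    /// has written everything to disk.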
pub fn finish(self) -> Result<()> {
self.start_finish()?.wait()
}
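    /// Flushes remaining staged bytes and returns a [`SpillWriteHandle`]
    /// without waiting for the I/O thread, so the caller can overlap the tail
    /// of this write with other work. A pipelining sketch:
    ///
    /// ```ignore
    /// let handle = writer.start_finish()?;
    /// // ... prepare the next chunk while the I/O thread drains ...
    /// handle.wait()?;
    /// ```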
pub fn start_finish(mut self) -> Result<SpillWriteHandle> {
if let Some(mut staging) = self.staging.take() {
if !staging.buf().is_empty() {
staging.flush()?;
}
            // Dropping the staging buffer closes `result_tx`, letting the I/O
            // thread finish once all queued blocks are written.
            drop(staging);
        }
Ok(SpillWriteHandle::new(self.io_handle.take()))
}
}
impl<K: RawSortKey> Drop for PooledChunkWriter<K> {
fn drop(&mut self) {
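        // Best-effort cleanup for a writer dropped without `finish()`:
        // dropping the staging buffer closes the result channel so the I/O
        // thread can drain and exit; any error is logged, not propagated.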
if self.io_handle.is_some() {
drop(self.staging.take());
if let Some(handle) = self.io_handle.take() {
match handle.join() {
Ok(Ok(())) => {}
Ok(Err(e)) => log::error!("PooledChunkWriter: I/O writer thread error: {e}"),
Err(_) => log::error!("PooledChunkWriter: I/O writer thread panicked"),
}
}
}
}
}
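/// Handle to a spill write whose final blocks may still be in flight on the
/// background I/O thread.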
#[must_use = "call wait() to propagate write errors; dropping only logs them"]
pub struct SpillWriteHandle {
io_handle: Option<JoinHandle<Result<()>>>,
}
impl SpillWriteHandle {
pub(crate) fn new(io_handle: Option<JoinHandle<Result<()>>>) -> Self {
Self { io_handle }
}
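    /// Blocks until the I/O thread has finished, propagating any write error
    /// or thread panic as an `anyhow::Error`.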
pub fn wait(mut self) -> Result<()> {
if let Some(handle) = self.io_handle.take() {
handle.join().map_err(|_| anyhow::anyhow!("I/O writer thread panicked"))??;
}
Ok(())
}
}
impl Drop for SpillWriteHandle {
fn drop(&mut self) {
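        // Fallback when `wait()` was never called: join the I/O thread and
        // log failures instead of panicking inside drop.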
if let Some(handle) = self.io_handle.take() {
match handle.join() {
Ok(Ok(())) => {}
Ok(Err(e)) => log::error!("SpillWriteHandle: I/O writer thread error: {e}"),
Err(_) => log::error!("SpillWriteHandle: I/O writer thread panicked"),
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sort::inline_buffer::TemplateKey;
use crate::sort::raw::GenericKeyedChunkReader;
use tempfile::TempDir;
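    /// Builds a deterministic `TemplateKey` fixture from an index.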
#[allow(clippy::cast_possible_truncation)]
fn make_key(i: u64) -> TemplateKey {
        TemplateKey::new(i as i32, i as i32, false, i32::MAX, i32::MAX, false, 0, 0, (0, false), i, false)
}
#[test]
#[allow(clippy::cast_possible_truncation)]
fn test_pooled_writer_roundtrip() {
let dir = TempDir::new().unwrap();
let chunk_path = dir.path().join("test_chunk.keyed");
let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
let records: Vec<(TemplateKey, Vec<u8>)> = (0..100)
.map(|i| {
let key = make_key(i);
let record = vec![(i % 256) as u8; 200 + (i as usize % 50)];
(key, record)
})
.collect();
{
let mut writer = PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)
.expect("create writer");
for (key, record) in &records {
writer.write_record(key, record).expect("write record");
}
writer.finish().expect("finish writer");
}
let mut reader =
GenericKeyedChunkReader::<TemplateKey>::open(&chunk_path, None).expect("open reader");
let mut buf = Vec::new();
let mut read_records = Vec::new();
while let Some(key) = reader.next_record(&mut buf).expect("read record") {
read_records.push((key, buf.clone()));
}
assert_eq!(records.len(), read_records.len(), "record count mismatch");
for (i, ((expected_key, expected_data), (actual_key, actual_data))) in
records.iter().zip(read_records.iter()).enumerate()
{
assert_eq!(*expected_key, *actual_key, "key mismatch at {i}");
assert_eq!(expected_data, actual_data, "data mismatch at {i}");
}
if let Ok(pool) = Arc::try_unwrap(pool) {
pool.shutdown();
}
}
#[test]
fn test_pooled_writer_empty() {
let dir = TempDir::new().unwrap();
let chunk_path = dir.path().join("empty_chunk.keyed");
let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
{
let writer = PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)
.expect("create writer");
writer.finish().expect("finish empty writer");
}
assert!(chunk_path.exists());
let metadata = std::fs::metadata(&chunk_path).expect("stat file");
assert!(metadata.len() > 0, "file should not be empty (has EOF marker)");
if let Ok(pool) = Arc::try_unwrap(pool) {
pool.shutdown();
}
}
#[test]
#[allow(clippy::cast_possible_truncation)]
fn test_pooled_writer_large_records() {
let dir = TempDir::new().unwrap();
let chunk_path = dir.path().join("large_chunk.keyed");
let pool = Arc::new(SortWorkerPool::new(4, 1, 6));
let records: Vec<(TemplateKey, Vec<u8>)> = (0..500)
.map(|i| {
let key = make_key(i);
let record = vec![(i % 256) as u8; 1000];
(key, record)
})
.collect();
{
let mut writer = PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)
.expect("create writer");
for (key, record) in &records {
writer.write_record(key, record).expect("write record");
}
writer.finish().expect("finish writer");
}
let mut reader =
GenericKeyedChunkReader::<TemplateKey>::open(&chunk_path, None).expect("open reader");
let mut buf = Vec::new();
let mut count = 0;
while let Some(key) = reader.next_record(&mut buf).expect("read record") {
assert_eq!(key, records[count].0, "key mismatch at {count}");
assert_eq!(buf, records[count].1, "data mismatch at {count}");
count += 1;
}
assert_eq!(count, records.len());
if let Ok(pool) = Arc::try_unwrap(pool) {
pool.shutdown();
}
}
#[test]
#[allow(clippy::cast_possible_truncation)]
fn test_start_finish_and_wait() {
let dir = TempDir::new().unwrap();
let chunk_path = dir.path().join("pipelined_chunk.keyed");
let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
let records: Vec<(TemplateKey, Vec<u8>)> =
(0..50).map(|i| (make_key(i), vec![(i % 256) as u8; 100])).collect();
let handle = {
let mut writer = PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)
.expect("create writer");
for (key, record) in &records {
writer.write_record(key, record).expect("write record");
}
writer.start_finish().expect("start_finish")
};
handle.wait().expect("wait should succeed");
let mut reader =
GenericKeyedChunkReader::<TemplateKey>::open(&chunk_path, None).expect("open reader");
let mut buf = Vec::new();
let mut count = 0;
while let Some(key) = reader.next_record(&mut buf).expect("read record") {
assert_eq!(key, records[count].0, "key mismatch at {count}");
count += 1;
}
assert_eq!(count, records.len());
if let Ok(pool) = Arc::try_unwrap(pool) {
pool.shutdown();
}
}
#[test]
fn test_spill_write_handle_drop_without_wait() {
let dir = TempDir::new().unwrap();
let chunk_path = dir.path().join("dropped_chunk.keyed");
let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
let handle = {
let mut writer = PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)
.expect("create writer");
writer.write_record(&make_key(0), &[1, 2, 3]).expect("write");
writer.start_finish().expect("start_finish")
};
drop(handle);
assert!(chunk_path.exists());
assert!(std::fs::metadata(&chunk_path).unwrap().len() > 0);
if let Ok(pool) = Arc::try_unwrap(pool) {
pool.shutdown();
}
}
#[test]
fn test_drop_before_finish() {
let dir = TempDir::new().unwrap();
let chunk_path = dir.path().join("dropped_writer.keyed");
let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
{
let mut writer = PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)
.expect("create writer");
writer.write_record(&make_key(0), &[1, 2, 3]).expect("write");
}
assert!(chunk_path.exists());
if let Ok(pool) = Arc::try_unwrap(pool) {
pool.shutdown();
}
}
}