use crate::sort::worker_pool::{
BufferPool, CompressJob, CompressResult, PermitPool, SortWorkerPool,
};
use anyhow::Result;
use crossbeam_channel::{Receiver, Sender};
use fgumi_bgzf::{BGZF_EOF, BGZF_MAX_BLOCK_SIZE};
use std::collections::BTreeMap;
use std::io::{BufWriter, Write};
use std::sync::Arc;
/// Headroom beyond one BGZF block so callers can overshoot the block
/// boundary slightly (via `buf()` / `write_chunked`) without forcing a
/// reallocation before the next flush.
const STAGING_PADDING: usize = 4096;

/// Accumulates uncompressed output bytes and hands full BGZF-sized blocks
/// to the worker pool for compression, tagging each with a serial number
/// so the writer loop can restore output order.
pub(crate) struct StagingBuffer {
    // Worker pool that compresses flushed blocks; also owns the buffer
    // pool that recycled staging buffers are checked out from.
    pool: Arc<SortWorkerPool>,
    // Raw bytes staged for the next compression job.
    buf: Vec<u8>,
    // Monotonically increasing block id; next value to assign on flush.
    next_serial: u64,
    // Cloned into every job so workers can report compressed results.
    result_tx: Sender<CompressResult>,
    // Bounds the number of submitted-but-unwritten blocks (backpressure).
    permit_pool: Arc<PermitPool>,
}
impl StagingBuffer {
    /// Creates a staging buffer that feeds compression jobs to `pool`.
    ///
    /// Each flushed block is assigned a monotonically increasing serial and
    /// its result is reported through `result_tx`; `permit_pool` bounds how
    /// many blocks may be in flight at once.
    #[must_use]
    pub(crate) fn new(
        pool: Arc<SortWorkerPool>,
        result_tx: Sender<CompressResult>,
        permit_pool: Arc<PermitPool>,
    ) -> Self {
        Self {
            pool,
            // Padding lets callers overshoot BGZF_MAX_BLOCK_SIZE slightly
            // without reallocating before the flush.
            buf: Vec::with_capacity(BGZF_MAX_BLOCK_SIZE + STAGING_PADDING),
            next_serial: 0,
            result_tx,
            permit_pool,
        }
    }

    /// Mutable access to the raw staging bytes; callers append data here
    /// and then invoke [`Self::flush_if_full`].
    pub(crate) fn buf(&mut self) -> &mut Vec<u8> {
        &mut self.buf
    }

    /// Returns `true` once at least one full BGZF block has accumulated.
    #[inline]
    pub(crate) fn is_full(&self) -> bool {
        self.buf.len() >= BGZF_MAX_BLOCK_SIZE
    }

    /// Submits the buffered bytes as one compression job, swapping in a
    /// recycled buffer from the pool. No-op on an empty buffer.
    ///
    /// # Errors
    /// Fails if no permit can be acquired, i.e. the permit pool has been
    /// closed because the writer thread aborted.
    pub(crate) fn flush(&mut self) -> anyhow::Result<()> {
        if self.buf.is_empty() {
            return Ok(());
        }
        // Backpressure: block until the writer has drained enough
        // in-flight blocks. This bounds memory held by queued jobs.
        self.permit_pool.acquire()?;
        let data = std::mem::replace(&mut self.buf, self.pool.buffer_pool.checkout());
        // Recycled buffers may come back undersized; restore the standard
        // staging capacity up front. `Vec::reserve(additional)` guarantees
        // `capacity >= len + additional`, so the amount to request is
        // `target - len` — the previous code subtracted the current
        // *capacity*, which could leave the buffer short of `target`.
        let target = BGZF_MAX_BLOCK_SIZE + STAGING_PADDING;
        if self.buf.capacity() < target {
            self.buf.reserve(target.saturating_sub(self.buf.len()));
        }
        let serial = self.next_serial;
        self.next_serial += 1;
        self.pool.submit_compress(CompressJob { data, serial, result_tx: self.result_tx.clone() });
        Ok(())
    }

    /// Flushes only when a full block's worth of bytes is staged.
    #[inline]
    pub(crate) fn flush_if_full(&mut self) -> anyhow::Result<()> {
        if self.is_full() { self.flush() } else { Ok(()) }
    }

    /// Appends `data`, flushing a compression job each time the staging
    /// buffer reaches `BGZF_MAX_BLOCK_SIZE`, so no single submitted block
    /// exceeds the BGZF limit.
    pub(crate) fn write_chunked(&mut self, data: &[u8]) -> anyhow::Result<()> {
        let mut remaining = data;
        while !remaining.is_empty() {
            // `saturating_sub` guards against callers who pushed past the
            // block limit directly via `buf()`; the flush below empties
            // the buffer so the loop still makes progress.
            let space = BGZF_MAX_BLOCK_SIZE.saturating_sub(self.buf.len());
            let n = remaining.len().min(space);
            self.buf.extend_from_slice(&remaining[..n]);
            remaining = &remaining[n..];
            self.flush_if_full()?;
        }
        Ok(())
    }
}
/// Runs the blocking writer loop to completion on the current thread.
///
/// On failure, the permit pool is closed so producers blocked in
/// `PermitPool::acquire` wake up and observe the error instead of
/// deadlocking against a writer that will never release permits.
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn io_writer_loop(
    mut writer: BufWriter<std::fs::File>,
    result_rx: Receiver<CompressResult>,
    buffer_pool: BufferPool,
    permit_pool: Arc<PermitPool>,
) -> Result<()> {
    match io_writer_loop_inner(&mut writer, &result_rx, &buffer_pool, &permit_pool) {
        Ok(()) => Ok(()),
        Err(err) => {
            // Unblock any producer waiting for a permit before reporting.
            permit_pool.close();
            Err(err)
        }
    }
}
/// Receives compressed blocks from `result_rx`, writes them to `writer`
/// in serial order, and terminates the stream with the BGZF EOF marker.
///
/// Blocks that arrive out of order are parked in a `BTreeMap` keyed by
/// serial until all of their predecessors have been written. Exactly one
/// permit is released per block written, which unblocks producers waiting
/// in `StagingBuffer::flush`.
///
/// # Errors
/// Fails on any write error, or if the channel closes while a gap remains
/// in the serial sequence (the output would otherwise be silently
/// truncated).
fn io_writer_loop_inner(
    writer: &mut BufWriter<std::fs::File>,
    result_rx: &Receiver<CompressResult>,
    buffer_pool: &BufferPool,
    permit_pool: &Arc<PermitPool>,
) -> Result<()> {
    let mut next_expected: u64 = 0;
    // Out-of-order blocks awaiting their predecessors, keyed by serial.
    let mut reorder_buf: BTreeMap<u64, Vec<u8>> = BTreeMap::new();
    // The loop exits once every sender has been dropped and the channel
    // is drained (recv returns Err).
    while let Ok(result) = result_rx.recv() {
        // Return the worker's scratch buffer to the pool immediately,
        // even when this block must wait in the reorder map.
        buffer_pool.checkin(result.recycled_buf);
        if result.serial == next_expected {
            writer.write_all(&result.compressed)?;
            permit_pool.release();
            next_expected += 1;
            // Drain any parked successors that are now contiguous.
            while let Some(data) = reorder_buf.remove(&next_expected) {
                writer.write_all(&data)?;
                permit_pool.release();
                next_expected += 1;
            }
        } else {
            reorder_buf.insert(result.serial, result.compressed);
        }
    }
    // Channel closed: flush whatever is still parked, which must be a
    // contiguous run starting at `next_expected`.
    while let Some((&serial, _)) = reorder_buf.first_key_value() {
        if serial == next_expected {
            let data = reorder_buf.remove(&serial).expect("key just checked");
            writer.write_all(&data)?;
            permit_pool.release();
            next_expected += 1;
        } else {
            // A gap here means some producer failed before sending its
            // block; refuse to emit a silently truncated file.
            return Err(anyhow::anyhow!(
                "missing compressed block {next_expected}: next available is {serial}; \
                 the output would be silently truncated"
            ));
        }
    }
    writer.write_all(&BGZF_EOF)?;
    writer.flush()?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tempfile::TempDir;

    // Standard permit budget for tests: four in-flight blocks per worker.
    fn make_permit_pool(pool: &Arc<SortWorkerPool>) -> Arc<PermitPool> {
        Arc::new(PermitPool::new(pool.num_workers() * 4))
    }

    // Pushes `data` through the full staging -> compress -> ordered-write
    // pipeline and returns the raw bytes of the resulting BGZF file.
    fn roundtrip_data(data: &[u8]) -> Vec<u8> {
        let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
        let (result_tx, result_rx) = pool.compress_result_channel();
        let buffer_pool = pool.buffer_pool.clone();
        let permit_pool = make_permit_pool(&pool);
        let dir = TempDir::new().unwrap();
        let out_path = dir.path().join("out.bgzf");
        let out_file = std::fs::File::create(&out_path).unwrap();
        let writer = std::io::BufWriter::new(out_file);
        let pp = Arc::clone(&permit_pool);
        let io_handle =
            std::thread::spawn(move || io_writer_loop(writer, result_rx, buffer_pool, pp));
        let mut staging = StagingBuffer::new(Arc::clone(&pool), result_tx, permit_pool);
        staging.write_chunked(data).unwrap();
        staging.flush().unwrap();
        // Dropping the staging buffer drops its result sender; once the
        // workers finish their jobs the channel closes and the writer
        // loop exits, so the join below cannot hang.
        drop(staging);
        io_handle.join().unwrap().unwrap();
        if let Ok(p) = Arc::try_unwrap(pool) {
            p.shutdown();
        }
        std::fs::read(&out_path).unwrap()
    }

    // Flushing an empty staging buffer must not submit a compress job.
    #[test]
    fn test_staging_buffer_flush_empty_is_noop() {
        let pool = Arc::new(SortWorkerPool::new(1, 1, 6));
        let (result_tx, _result_rx) = pool.compress_result_channel();
        let permit_pool = make_permit_pool(&pool);
        let mut staging = StagingBuffer::new(Arc::clone(&pool), result_tx, permit_pool);
        staging.flush().unwrap();
        assert_eq!(
            pool.stats.compress_jobs_submitted.load(std::sync::atomic::Ordering::Relaxed),
            0
        );
        if let Ok(p) = Arc::try_unwrap(pool) {
            p.shutdown();
        }
    }

    // `is_full` is false when empty and true exactly at the block limit.
    #[test]
    fn test_staging_buffer_is_full() {
        let pool = Arc::new(SortWorkerPool::new(1, 1, 6));
        let (result_tx, _result_rx) = pool.compress_result_channel();
        let permit_pool = make_permit_pool(&pool);
        let mut staging = StagingBuffer::new(Arc::clone(&pool), result_tx, permit_pool);
        assert!(!staging.is_full(), "empty buffer should not be full");
        staging.buf().extend(vec![0u8; BGZF_MAX_BLOCK_SIZE]);
        assert!(staging.is_full(), "buffer at BGZF_MAX_BLOCK_SIZE should be full");
        if let Ok(p) = Arc::try_unwrap(pool) {
            p.shutdown();
        }
    }

    // Data larger than one BGZF block must be split across multiple
    // compression jobs rather than submitted as one oversized block.
    #[test]
    fn test_staging_buffer_write_chunked_large_data() {
        let large = vec![b'A'; BGZF_MAX_BLOCK_SIZE * 2 + 1000];
        let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
        let (result_tx, result_rx) = pool.compress_result_channel();
        let buffer_pool = pool.buffer_pool.clone();
        let permit_pool = make_permit_pool(&pool);
        let dir = TempDir::new().unwrap();
        let out_path = dir.path().join("large.bgzf");
        let out_file = std::fs::File::create(&out_path).unwrap();
        let writer = std::io::BufWriter::new(out_file);
        let pp = Arc::clone(&permit_pool);
        let io_handle =
            std::thread::spawn(move || io_writer_loop(writer, result_rx, buffer_pool, pp));
        let mut staging = StagingBuffer::new(Arc::clone(&pool), result_tx, permit_pool);
        staging.write_chunked(&large).unwrap();
        staging.flush().unwrap();
        drop(staging);
        io_handle.join().unwrap().unwrap();
        assert!(
            pool.stats.compress_jobs_submitted.load(std::sync::atomic::Ordering::Relaxed) >= 2,
            "expected multiple compress jobs for data > BGZF_MAX_BLOCK_SIZE"
        );
        if let Ok(p) = Arc::try_unwrap(pool) {
            p.shutdown();
        }
    }

    // Submitting serial 1 before serial 0 must still produce a complete,
    // EOF-terminated file: the writer parks the early block and emits
    // both in order once the gap is filled.
    #[test]
    fn test_io_writer_loop_reorders_out_of_order_blocks() {
        let data1 = b"first block data".to_vec();
        let data2 = b"second block data".to_vec();
        let pool = Arc::new(SortWorkerPool::new(2, 1, 6));
        let (result_tx, result_rx) = pool.compress_result_channel();
        let buffer_pool = pool.buffer_pool.clone();
        let permit_pool = Arc::new(PermitPool::new(4));
        let dir = TempDir::new().unwrap();
        let out_path = dir.path().join("reorder.bgzf");
        let out_file = std::fs::File::create(&out_path).unwrap();
        let writer = std::io::BufWriter::new(out_file);
        let pp = Arc::clone(&permit_pool);
        let io_handle =
            std::thread::spawn(move || io_writer_loop(writer, result_rx, buffer_pool, pp));
        // Acquire a permit before each submit, mirroring what
        // StagingBuffer::flush does.
        permit_pool.acquire().unwrap();
        pool.submit_compress(CompressJob { data: data2, serial: 1, result_tx: result_tx.clone() });
        permit_pool.acquire().unwrap();
        pool.submit_compress(CompressJob { data: data1, serial: 0, result_tx });
        io_handle.join().unwrap().unwrap();
        let bytes = std::fs::read(&out_path).unwrap();
        assert!(bytes.ends_with(&BGZF_EOF), "output should end with BGZF EOF marker");
        if let Ok(p) = Arc::try_unwrap(pool) {
            p.shutdown();
        }
    }

    #[test]
    fn test_roundtrip_small_data() {
        let data = b"hello world from bgzf_io";
        let output = roundtrip_data(data);
        assert!(output.ends_with(&BGZF_EOF), "must end with BGZF EOF");
        assert!(output.len() > BGZF_EOF.len());
    }

    // With no payload at all, the pipeline emits nothing but the marker.
    #[test]
    fn test_roundtrip_empty_data() {
        let output = roundtrip_data(b"");
        assert_eq!(output, BGZF_EOF.to_vec(), "empty input → only BGZF EOF marker");
    }
}