use anyhow::{anyhow, Context, Result};
use arrow::array::{Array, LargeStringArray, RecordBatch, StringArray};
use arrow::datatypes::DataType;
use parquet::arrow::arrow_reader::{
ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReaderBuilder,
};
use rayon::prelude::*;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::mpsc::{self, Receiver, RecvTimeoutError, SyncSender};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
use rype::{FirstErrorCapture, QueryRecord};
use super::fastx_io::OwnedFastxRecord;
/// Reports whether `path` names a Parquet file, judged solely by its extension.
///
/// The comparison ignores ASCII case, so `reads.PARQUET` qualifies. Only the
/// final extension is considered: `reads.parquet.gz` does not qualify, and a
/// path without an extension yields `false`.
pub fn is_parquet_input(path: &Path) -> bool {
    matches!(path.extension(), Some(ext) if ext.eq_ignore_ascii_case("parquet"))
}
/// Synchronous reader that pulls sequencing records out of a Parquet file,
/// one Arrow record batch at a time.
///
/// Built with `ParquetInputReader::new` and drained with `next_batch`.
#[allow(dead_code)]
pub struct ParquetInputReader {
/// Underlying Arrow record-batch iterator over the file.
reader: parquet::arrow::arrow_reader::ParquetRecordBatchReader,
/// Whether the input was classified as paired-end when the reader was built.
is_paired: bool,
/// Batch currently being consumed; `None` if the file yielded no batch.
current_batch: Option<RecordBatch>,
/// Index of the next unread row inside `current_batch`.
current_idx: usize,
/// Running count of records handed out across all calls to `next_batch`.
global_record_id: i64,
}
#[allow(dead_code)]
impl ParquetInputReader {
/// Returns `true` when `dt` is one of the Arrow string types accepted for
/// input columns (`Utf8` or `LargeUtf8`).
fn is_string_type(dt: &DataType) -> bool {
    *dt == DataType::Utf8 || *dt == DataType::LargeUtf8
}
pub fn new(path: &Path) -> Result<Self> {
let file =
File::open(path).with_context(|| format!("Failed to open Parquet file: {:?}", path))?;
let builder = ParquetRecordBatchReaderBuilder::try_new(file)
.context("Failed to create Parquet reader")?;
let schema = builder.schema();
let read_id_field = schema
.fields()
.iter()
.find(|f| f.name() == "read_id")
.ok_or_else(|| anyhow!("Parquet input missing required column 'read_id'"))?;
if !Self::is_string_type(read_id_field.data_type()) {
return Err(anyhow!(
"Column 'read_id' must be string type (Utf8 or LargeUtf8), got {:?}",
read_id_field.data_type()
));
}
let sequence1_field = schema
.fields()
.iter()
.find(|f| f.name() == "sequence1")
.ok_or_else(|| anyhow!("Parquet input missing required column 'sequence1'"))?;
if !Self::is_string_type(sequence1_field.data_type()) {
return Err(anyhow!(
"Column 'sequence1' must be string type (Utf8 or LargeUtf8), got {:?}",
sequence1_field.data_type()
));
}
let has_sequence2 =
if let Some(field) = schema.fields().iter().find(|f| f.name() == "sequence2") {
if !Self::is_string_type(field.data_type()) {
return Err(anyhow!(
"Column 'sequence2' must be string type (Utf8 or LargeUtf8), got {:?}",
field.data_type()
));
}
true
} else {
false
};
let mut reader = builder.build().context("Failed to build Parquet reader")?;
let first_batch = reader.next();
let (is_paired, current_batch) = match first_batch {
Some(Ok(batch)) => {
let is_paired = if has_sequence2 {
if let Some(col) = batch.column_by_name("sequence2") {
col.null_count() < col.len()
} else {
false
}
} else {
false
};
(is_paired, Some(batch))
}
Some(Err(e)) => return Err(anyhow!("Error reading first Parquet batch: {}", e)),
None => (false, None), };
log::info!(
"Parquet input '{}': detected as {} data",
path.display(),
if is_paired {
"paired-end"
} else {
"single-end"
}
);
Ok(Self {
reader,
is_paired,
current_batch,
current_idx: 0,
global_record_id: 0,
})
}
/// Returns whether the input was classified as paired-end when the reader
/// was constructed.
#[allow(dead_code)]
pub fn is_paired(&self) -> bool {
self.is_paired
}
pub fn next_batch(
&mut self,
batch_size: usize,
) -> Result<Option<(Vec<OwnedFastxRecord>, Vec<String>)>> {
let mut records = Vec::with_capacity(batch_size);
let mut headers = Vec::with_capacity(batch_size);
while records.len() < batch_size {
if let Some(ref batch) = self.current_batch {
if self.current_idx < batch.num_rows() {
let read_id_col = batch
.column_by_name("read_id")
.ok_or_else(|| anyhow!("Missing read_id column"))?;
let sequence1_col = batch
.column_by_name("sequence1")
.ok_or_else(|| anyhow!("Missing sequence1 column"))?;
let read_id_arr = read_id_col
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow!("read_id column is not a string array"))?;
let sequence1_arr = sequence1_col
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow!("sequence1 column is not a string array"))?;
let sequence2_arr = if self.is_paired {
batch
.column_by_name("sequence2")
.and_then(|col| col.as_any().downcast_ref::<StringArray>())
} else {
None
};
let idx = self.current_idx;
let read_id = read_id_arr.value(idx).to_string();
let sequence1 = sequence1_arr.value(idx).as_bytes().to_vec();
let sequence2 = sequence2_arr
.filter(|arr| !arr.is_null(idx))
.map(|arr| arr.value(idx).as_bytes().to_vec());
records.push(OwnedFastxRecord::new(
records.len() as i64,
sequence1,
None, sequence2,
None, ));
headers.push(read_id);
self.global_record_id += 1;
self.current_idx += 1;
continue;
}
}
match self.reader.next() {
Some(Ok(batch)) => {
self.current_batch = Some(batch);
self.current_idx = 0;
}
Some(Err(e)) => return Err(anyhow!("Error reading Parquet batch: {}", e)),
None => {
break;
}
}
}
if records.is_empty() {
Ok(None)
} else {
Ok(Some((records, headers)))
}
}
}
/// One unit of work handed from the prefetch thread to the consumer.
pub enum ParquetBatch {
/// An Arrow record batch passed through unchanged, with the headers
/// extracted from its `read_id` column.
Arrow(RecordBatch, Vec<String>),
/// Records already copied out of Arrow (used when trimming or length
/// filtering is active), with the headers of the records that were kept.
Owned(Vec<OwnedFastxRecord>, Vec<String>),
}
impl ParquetBatch {
    /// Unwraps the `Arrow` variant into its record batch and headers.
    ///
    /// # Panics
    ///
    /// Panics if `self` is the `Owned` variant.
    pub fn into_arrow(self) -> (RecordBatch, Vec<String>) {
        if let ParquetBatch::Arrow(batch, headers) = self {
            return (batch, headers);
        }
        panic!("Expected ParquetBatch::Arrow but got Owned")
    }

    /// Unwraps the `Owned` variant into its records and headers.
    ///
    /// # Panics
    ///
    /// Panics if `self` is the `Arrow` variant.
    #[allow(dead_code)]
    pub fn into_owned(self) -> (Vec<OwnedFastxRecord>, Vec<String>) {
        if let ParquetBatch::Owned(records, headers) = self {
            return (records, headers);
        }
        panic!("Expected ParquetBatch::Owned but got Arrow")
    }
}
/// Message carried over the prefetch channel: `Ok(Some(batch))` for data,
/// `Ok(None)` once the reader thread has sent everything, `Err` on failure.
type ParquetBatchResult = Result<Option<ParquetBatch>>;
/// How long `PrefetchingParquetReader::next_batch` waits for the reader
/// thread before reporting a timeout (300 seconds).
const DEFAULT_PARQUET_PREFETCH_TIMEOUT: Duration = Duration::from_secs(300);
/// Parquet reader that decodes batches on a background thread and hands
/// them to the consumer through a bounded channel.
pub struct PrefetchingParquetReader {
/// Consumer end of the channel fed by the reader thread.
receiver: Receiver<ParquetBatchResult>,
/// Handle of the reader thread; taken (set to `None`) once it has been joined.
prefetch_thread: Option<JoinHandle<()>>,
/// Error storage shared with the reader thread, consulted when the channel
/// times out or disconnects.
error_capture: Arc<FirstErrorCapture>,
/// `true` when the file's schema contains a `sequence2` column.
is_paired: bool,
/// Maximum wait per `next_batch` call.
timeout: Duration,
}
impl PrefetchingParquetReader {
/// Creates a reader with sequential row-group reading and no trimming or
/// length filtering; shorthand for `with_parallel_row_groups` with all
/// optional settings left as `None`.
#[allow(dead_code)]
pub fn new(path: &Path, batch_size: usize) -> Result<Self> {
Self::with_parallel_row_groups(path, batch_size, None, None, None)
}
/// Validates the file's schema, then starts a background thread that
/// streams its batches through a bounded channel (capacity 4).
///
/// * `batch_size` is forwarded to the reader thread, which currently does
///   not use it.
/// * `parallel_row_groups`: `Some(n)` with `n > 0` selects the
///   row-group-parallel reader working on `n` row groups at a time; `None`
///   or `Some(0)` selects the sequential reader.
/// * `trim_to` / `minimum_length`: when either is set, the thread emits
///   `ParquetBatch::Owned` batches produced by
///   `batch_to_owned_records_trimmed`; otherwise it emits
///   `ParquetBatch::Arrow`.
///
/// The reader reports itself as paired whenever the schema has a
/// `sequence2` column.
///
/// # Errors
///
/// Fails if the file cannot be opened, the Parquet reader cannot be
/// created, or the schema does not pass `validate_and_get_projection`.
pub fn with_parallel_row_groups(
    path: &Path,
    batch_size: usize,
    parallel_row_groups: Option<usize>,
    trim_to: Option<usize>,
    minimum_length: Option<usize>,
) -> Result<Self> {
    let path = path.to_path_buf();
    let error_capture = Arc::new(FirstErrorCapture::new());
    let thread_error = Arc::clone(&error_capture);
    let (sender, receiver) = mpsc::sync_channel::<ParquetBatchResult>(4);

    // Schema problems are reported here, on the caller's thread, before any
    // background work starts.
    let file = File::open(&path)
        .with_context(|| format!("Failed to open Parquet file: {:?}", path))?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)
        .context("Failed to create Parquet reader")?;
    let schema = builder.schema().clone();
    let (col_indices, has_sequence2) = Self::validate_and_get_projection(&schema)?;

    let prefetch_thread = match parallel_row_groups {
        Some(parallel_rg) if parallel_rg > 0 => {
            log::info!(
                "Using parallel Parquet row group reading (parallelism={})",
                parallel_rg
            );
            thread::spawn(move || {
                Self::reader_thread_parallel(
                    path,
                    batch_size,
                    col_indices,
                    parallel_rg,
                    trim_to,
                    minimum_length,
                    sender,
                    thread_error,
                );
            })
        }
        _ => thread::spawn(move || {
            Self::reader_thread(
                path,
                batch_size,
                col_indices,
                trim_to,
                minimum_length,
                sender,
                thread_error,
            );
        }),
    };

    Ok(Self {
        receiver,
        prefetch_thread: Some(prefetch_thread),
        error_capture,
        is_paired: has_sequence2,
        timeout: DEFAULT_PARQUET_PREFETCH_TIMEOUT,
    })
}
/// Checks that `schema` has the columns this reader needs and returns the
/// indices of the columns to project.
///
/// `read_id` and `sequence1` are required; `sequence2` is optional. Every
/// one of them that is present must be `Utf8` or `LargeUtf8`. The returned
/// vector holds the `read_id` index, the `sequence1` index and, if present,
/// the `sequence2` index, in that order. The boolean tells whether
/// `sequence2` exists.
///
/// # Errors
///
/// Fails when a required column is missing or any of the three columns has
/// a non-string type.
fn validate_and_get_projection(
    schema: &arrow::datatypes::Schema,
) -> Result<(Vec<usize>, bool)> {
    /// Looks up `name`; `Ok(None)` if absent, an error if present but not
    /// string-typed.
    fn string_column(schema: &arrow::datatypes::Schema, name: &str) -> Result<Option<usize>> {
        match schema.fields().iter().position(|f| f.name() == name) {
            None => Ok(None),
            Some(idx) => {
                let dt = schema.field(idx).data_type();
                if matches!(dt, DataType::Utf8 | DataType::LargeUtf8) {
                    Ok(Some(idx))
                } else {
                    Err(anyhow!(
                        "Column '{}' must be string type (Utf8 or LargeUtf8), got {:?}",
                        name,
                        dt
                    ))
                }
            }
        }
    }

    let require = |name: &str| -> Result<usize> {
        string_column(schema, name)?
            .ok_or_else(|| anyhow!("Parquet input missing required column '{}'", name))
    };

    let read_id_idx = require("read_id")?;
    let sequence1_idx = require("sequence1")?;
    let sequence2_idx = string_column(schema, "sequence2")?;

    let mut col_indices = vec![read_id_idx, sequence1_idx];
    col_indices.extend(sequence2_idx);
    Ok((col_indices, sequence2_idx.is_some()))
}
/// Reader-thread body for sequential streaming.
///
/// Opens `path`, projects the columns in `col_indices` and sends every
/// decoded batch through `sender`: as `ParquetBatch::Owned` (after
/// `batch_to_owned_records_trimmed`) when `trim_to` or `minimum_length` is
/// set, otherwise as `ParquetBatch::Arrow`. `Ok(None)` is sent after the
/// last batch. If the receiver has gone away the thread stops quietly.
///
/// On failure the error is sent through the channel; only if that send
/// fails is the message stored in `error_capture`.
///
/// `_batch_size` is accepted but not used.
fn reader_thread(
    path: PathBuf,
    _batch_size: usize,
    col_indices: Vec<usize>,
    trim_to: Option<usize>,
    minimum_length: Option<usize>,
    sender: SyncSender<ParquetBatchResult>,
    error_capture: Arc<FirstErrorCapture>,
) {
    let needs_trim_filter = trim_to.is_some() || minimum_length.is_some();

    // All fallible work lives in this closure so every failure funnels
    // through the single reporting step at the bottom.
    let stream = || -> Result<(), String> {
        let file =
            File::open(&path).map_err(|e| format!("Failed to open Parquet file: {}", e))?;
        let builder = ParquetRecordBatchReaderBuilder::try_new(file)
            .map_err(|e| format!("Failed to create Parquet reader: {}", e))?;
        let projection =
            parquet::arrow::ProjectionMask::roots(builder.parquet_schema(), col_indices);
        let reader = builder
            .with_projection(projection)
            .build()
            .map_err(|e| format!("Failed to build Parquet reader: {}", e))?;

        for batch_result in reader {
            let batch =
                batch_result.map_err(|e| format!("Error reading Parquet batch: {}", e))?;
            let headers = Self::extract_headers(&batch)
                .map_err(|e| format!("Error extracting headers: {}", e))?;

            let parquet_batch = if needs_trim_filter {
                let (records, filtered_headers) =
                    batch_to_owned_records_trimmed(&batch, &headers, trim_to, minimum_length, 0)
                        .map_err(|e| format!("Error trimming batch: {}", e))?;
                ParquetBatch::Owned(records, filtered_headers)
            } else {
                ParquetBatch::Arrow(batch, headers)
            };

            if sender.send(Ok(Some(parquet_batch))).is_err() {
                // Receiver dropped: nobody is left to notify.
                return Ok(());
            }
        }

        let _ = sender.send(Ok(None));
        Ok(())
    };

    if let Err(err_msg) = stream() {
        if sender.send(Err(anyhow!("{}", &err_msg))).is_err() {
            error_capture.store_msg(&err_msg);
        }
    }
}
/// Reader-thread body for row-group-parallel streaming.
///
/// Row groups are processed in chunks of `parallel_rg`. Each row group of a
/// chunk is decoded on the rayon pool through its own file handle; the
/// chunk's batches are then sent in row-group order before the next chunk
/// starts. `Ok(None)` is sent after the last chunk, or immediately when the
/// file has no row groups. If the receiver has gone away the thread stops
/// quietly.
///
/// The Parquet footer is parsed once and shared with every per-row-group
/// reader instead of being re-read for each row group.
///
/// Error reporting: a failure while opening the file or loading its
/// metadata is sent through the channel (stored in `error_capture` only if
/// that send fails). A failure inside a chunk is stored in `error_capture`
/// and the thread exits, which closes the channel.
///
/// `_batch_size` is accepted but not used. `parallel_rg` must be non-zero
/// (`step_by(0)` panics); the caller in this file only passes values `> 0`.
#[allow(clippy::too_many_arguments)]
fn reader_thread_parallel(
    path: PathBuf,
    _batch_size: usize,
    col_indices: Vec<usize>,
    parallel_rg: usize,
    trim_to: Option<usize>,
    minimum_length: Option<usize>,
    sender: SyncSender<ParquetBatchResult>,
    error_capture: Arc<FirstErrorCapture>,
) {
    let needs_trim_filter = trim_to.is_some() || minimum_length.is_some();

    // Send an error to the consumer (falling back to the shared capture when
    // the channel is closed) and leave the thread.
    macro_rules! send_error {
        ($msg:expr) => {{
            let err_msg = $msg;
            if sender.send(Err(anyhow!("{}", &err_msg))).is_err() {
                error_capture.store_msg(&err_msg);
            }
            return;
        }};
    }

    let file = match File::open(&path) {
        Ok(f) => f,
        Err(e) => {
            send_error!(format!("Failed to open Parquet file: {}", e));
        }
    };
    let metadata = match ArrowReaderMetadata::load(&file, ArrowReaderOptions::default()) {
        Ok(m) => m,
        Err(e) => {
            send_error!(format!("Failed to load Parquet metadata: {}", e));
        }
    };
    let num_row_groups = metadata.metadata().num_row_groups();
    if num_row_groups == 0 {
        let _ = sender.send(Ok(None));
        return;
    }
    // Workers open their own handles; this one was only needed for the footer.
    drop(file);

    log::debug!(
        "Parallel Parquet reader: {} row groups, parallelism={}",
        num_row_groups,
        parallel_rg
    );

    let col_indices = Arc::new(col_indices);
    for chunk_start in (0..num_row_groups).step_by(parallel_rg) {
        let chunk_end = (chunk_start + parallel_rg).min(num_row_groups);
        let rg_indices: Vec<usize> = (chunk_start..chunk_end).collect();

        #[allow(clippy::type_complexity)]
        let chunk_results: Result<Vec<(usize, Vec<ParquetBatch>)>, String> = rg_indices
            .into_par_iter()
            .map(|rg_idx| {
                let file = File::open(&path)
                    .map_err(|e| format!("Failed to open file for RG {}: {}", rg_idx, e))?;
                // Reuse the already-parsed footer rather than decoding it
                // again for every row group.
                let builder =
                    ParquetRecordBatchReaderBuilder::new_with_metadata(file, metadata.clone());
                let projection = parquet::arrow::ProjectionMask::roots(
                    builder.parquet_schema(),
                    col_indices.iter().copied(),
                );
                let reader = builder
                    .with_row_groups(vec![rg_idx])
                    .with_projection(projection)
                    .build()
                    .map_err(|e| format!("Failed to build reader for RG {}: {}", rg_idx, e))?;

                let mut batches = Vec::new();
                for batch_result in reader {
                    let batch = batch_result.map_err(|e| {
                        format!("Error reading batch from RG {}: {}", rg_idx, e)
                    })?;
                    let headers = Self::extract_headers(&batch).map_err(|e| {
                        format!("Error extracting headers from RG {}: {}", rg_idx, e)
                    })?;
                    if needs_trim_filter {
                        let (records, filtered_headers) = batch_to_owned_records_trimmed(
                            &batch,
                            &headers,
                            trim_to,
                            minimum_length,
                            0,
                        )
                        .map_err(|e| {
                            format!("Error trimming batch from RG {}: {}", rg_idx, e)
                        })?;
                        batches.push(ParquetBatch::Owned(records, filtered_headers));
                    } else {
                        batches.push(ParquetBatch::Arrow(batch, headers));
                    }
                }
                Ok((rg_idx, batches))
            })
            .collect();

        let mut sorted_results = match chunk_results {
            Ok(results) => results,
            Err(e) => {
                error_capture.store_msg(&e);
                return;
            }
        };
        // Emit row groups in file order.
        sorted_results.sort_by_key(|(idx, _)| *idx);
        for (_, batches) in sorted_results {
            for parquet_batch in batches {
                if sender.send(Ok(Some(parquet_batch))).is_err() {
                    // Receiver dropped: stop reading.
                    return;
                }
            }
        }
    }
    let _ = sender.send(Ok(None));
}
/// Copies the `read_id` column of `batch` into one owned `String` per row.
///
/// Works for both `Utf8` and `LargeUtf8` columns.
///
/// # Errors
///
/// Fails if the batch has no `read_id` column or the column is neither of
/// the two string array types.
fn extract_headers(batch: &RecordBatch) -> Result<Vec<String>> {
    let col = batch
        .column_by_name("read_id")
        .ok_or_else(|| anyhow!("Missing read_id column"))?;
    let rows = 0..batch.num_rows();
    let any = col.as_any();

    if let Some(arr) = any.downcast_ref::<StringArray>() {
        Ok(rows.map(|i| arr.value(i).to_string()).collect())
    } else if let Some(arr) = any.downcast_ref::<LargeStringArray>() {
        Ok(rows.map(|i| arr.value(i).to_string()).collect())
    } else {
        Err(anyhow!(
            "read_id column is not a string type: {:?}",
            col.data_type()
        ))
    }
}
/// Returns `true` when the file's schema contains a `sequence2` column.
#[allow(dead_code)]
pub fn is_paired(&self) -> bool {
self.is_paired
}
/// Waits for the next message from the reader thread.
///
/// Returns `Ok(Some(batch))` for data and `Ok(None)` once the thread has
/// signalled the end of the file.
///
/// # Errors
///
/// * Any error the reader thread sent through the channel.
/// * On timeout or disconnect, an error stored in the shared capture is
///   reported as `Reader thread error: …`.
/// * Otherwise a timeout error after the configured wait, or, when the
///   channel is disconnected, an error that says whether the thread
///   panicked, exited without a message, or had already been joined. The
///   thread is joined in this path.
pub fn next_batch(&mut self) -> Result<Option<ParquetBatch>> {
    let failure = match self.receiver.recv_timeout(self.timeout) {
        Ok(result) => return result,
        Err(failure) => failure,
    };

    // Whatever went wrong with the channel, an error captured by the reader
    // thread is the more useful explanation.
    if let Some(err) = self.error_capture.get() {
        return Err(anyhow!("Reader thread error: {}", err));
    }

    match failure {
        RecvTimeoutError::Timeout => Err(anyhow!(
            "Timeout waiting for next batch ({}s) - reader thread may be stalled",
            self.timeout.as_secs()
        )),
        RecvTimeoutError::Disconnected => {
            let reason = match self.prefetch_thread.take() {
                None => "Prefetch channel closed",
                Some(handle) => match handle.join() {
                    Ok(()) => "Prefetch thread exited unexpectedly",
                    Err(_) => "Prefetch thread panicked",
                },
            };
            Err(anyhow!("{}", reason))
        }
    }
}
pub fn finish(&mut self) -> Result<()> {
if let Some(handle) = self.prefetch_thread.take() {
handle
.join()
.map_err(|_| anyhow!("Prefetch thread panicked"))?;
}
Ok(())
}
}
/// Borrowed view of a string column that hides whether it is stored as
/// `Utf8` (`StringArray`) or `LargeUtf8` (`LargeStringArray`).
enum SequenceColumnRef<'a> {
/// Column backed by a `StringArray`.
String(&'a StringArray),
/// Column backed by a `LargeStringArray`.
LargeString(&'a LargeStringArray),
}
impl<'a> SequenceColumnRef<'a> {
#[inline]
fn value(&self, i: usize) -> &'a [u8] {
match self {
SequenceColumnRef::String(arr) => arr.value(i).as_bytes(),
SequenceColumnRef::LargeString(arr) => arr.value(i).as_bytes(),
}
}
#[inline]
fn is_null(&self, i: usize) -> bool {
match self {
SequenceColumnRef::String(arr) => arr.is_null(i),
SequenceColumnRef::LargeString(arr) => arr.is_null(i),
}
}
}
/// Looks up column `col_name` in `batch` and wraps it as a
/// [`SequenceColumnRef`].
///
/// # Errors
///
/// Fails if the column does not exist or is neither a `StringArray` nor a
/// `LargeStringArray`.
fn get_string_column<'a>(batch: &'a RecordBatch, col_name: &str) -> Result<SequenceColumnRef<'a>> {
    let col = batch
        .column_by_name(col_name)
        .ok_or_else(|| anyhow!("Missing {} column", col_name))?;
    let any = col.as_any();

    any.downcast_ref::<StringArray>()
        .map(SequenceColumnRef::String)
        .or_else(|| {
            any.downcast_ref::<LargeStringArray>()
                .map(SequenceColumnRef::LargeString)
        })
        .ok_or_else(|| {
            anyhow!(
                "{} column is not a string type: {:?}",
                col_name,
                col.data_type()
            )
        })
}
/// Converts `batch` into borrowed query records whose ids start at 0.
///
/// Shorthand for `batch_to_records_parquet_with_offset(batch, 0)`.
#[allow(dead_code)]
pub fn batch_to_records_parquet(batch: &RecordBatch) -> Result<Vec<QueryRecord<'_>>> {
batch_to_records_parquet_with_offset(batch, 0)
}
/// Converts `batch` into query records that borrow their sequence bytes
/// from the batch.
///
/// Row `i` gets the query id `id_offset + i`. The mate sequence is `None`
/// when the batch has no `sequence2` column or the cell is null. A batch
/// with zero rows yields an empty vector without inspecting any column.
///
/// # Errors
///
/// Fails if `sequence1` is missing or not a string column, if `sequence2`
/// is present but not a string column, or if a query id overflows `usize`
/// or does not fit in `i64`.
pub fn batch_to_records_parquet_with_offset(
    batch: &RecordBatch,
    id_offset: usize,
) -> Result<Vec<QueryRecord<'_>>> {
    let row_count = batch.num_rows();
    if row_count == 0 {
        return Ok(Vec::new());
    }

    let primary = get_string_column(batch, "sequence1")?;
    let mate = match batch.column_by_name("sequence2") {
        Some(_) => Some(get_string_column(batch, "sequence2")?),
        None => None,
    };

    let mut records = Vec::with_capacity(row_count);
    for row in 0..row_count {
        let raw_id = id_offset.checked_add(row).ok_or_else(|| {
            anyhow!("Query ID overflow: offset {} + index {}", id_offset, row)
        })?;
        let query_id = i64::try_from(raw_id)
            .map_err(|_| anyhow!("Query ID {} exceeds i64::MAX", raw_id))?;

        let mate_seq = match &mate {
            Some(col) if !col.is_null(row) => Some(col.value(row)),
            _ => None,
        };
        records.push((query_id, primary.value(row), mate_seq));
    }
    Ok(records)
}
/// Copies the rows of `batch` into owned records, applying length filtering
/// and trimming on the way.
///
/// * A row is dropped when `sequence1` is shorter than `minimum_length` or
///   shorter than `trim_to` (whichever are set). `sequence2` never causes a
///   row to be dropped.
/// * With `trim_to` set, `sequence1` is cut to exactly `trim_to` bytes and
///   `sequence2` to at most `trim_to` bytes.
/// * Query ids are `id_offset`, `id_offset + 1`, … over the rows that were
///   kept, so dropped rows leave no gaps.
///
/// `headers[i]` must be the header of row `i`. The returned header vector
/// holds the headers of the kept rows, in the same order as the records.
///
/// # Errors
///
/// Fails if `sequence1` is missing or not a string column, if `sequence2`
/// is present but not a string column, if a kept row has no entry in
/// `headers`, or if a query id overflows `usize` or does not fit in `i64`.
pub fn batch_to_owned_records_trimmed(
    batch: &RecordBatch,
    headers: &[String],
    trim_to: Option<usize>,
    minimum_length: Option<usize>,
    id_offset: usize,
) -> Result<(Vec<OwnedFastxRecord>, Vec<String>)> {
    let num_rows = batch.num_rows();
    if num_rows == 0 {
        return Ok((Vec::new(), Vec::new()));
    }

    let seq_col = get_string_column(batch, "sequence1")?;
    let pair_col = batch
        .column_by_name("sequence2")
        .map(|_| get_string_column(batch, "sequence2"))
        .transpose()?;

    // A read survives only if it satisfies both thresholds, i.e. the larger
    // of the two; an unset threshold counts as 0.
    let required_len = minimum_length.unwrap_or(0).max(trim_to.unwrap_or(0));
    let truncate = |seq: &[u8]| -> Vec<u8> {
        match trim_to {
            Some(trim_len) => seq[..trim_len.min(seq.len())].to_vec(),
            None => seq.to_vec(),
        }
    };

    let mut records = Vec::with_capacity(num_rows);
    let mut out_headers = Vec::with_capacity(num_rows);
    for i in 0..num_rows {
        let seq1 = seq_col.value(i);
        if seq1.len() < required_len {
            continue;
        }

        // Checked lookup: a short `headers` slice is a caller error, not a
        // reason to panic.
        let header = headers.get(i).ok_or_else(|| {
            anyhow!(
                "Header count {} is smaller than batch row count {}",
                headers.len(),
                num_rows
            )
        })?;

        let sum = id_offset.checked_add(records.len()).ok_or_else(|| {
            anyhow!(
                "Query ID overflow: offset {} + count {}",
                id_offset,
                records.len()
            )
        })?;
        let query_id =
            i64::try_from(sum).map_err(|_| anyhow!("Query ID {} exceeds i64::MAX", sum))?;

        let seq1_owned = truncate(seq1);
        let seq2_owned = pair_col
            .as_ref()
            .filter(|p| !p.is_null(i))
            .map(|p| truncate(p.value(i)));

        records.push(OwnedFastxRecord::new(
            query_id, seq1_owned, None, seq2_owned, None,
        ));
        out_headers.push(header.clone());
    }

    debug_assert_eq!(
        records.len(),
        out_headers.len(),
        "Records and headers must stay synchronized"
    );
    Ok((records, out_headers))
}
/// Outcome of one `accumulate_owned_batches` call.
pub struct TrimmedBatchResult {
/// Accumulated records; query ids are renumbered to run consecutively
/// across the merged batches.
pub records: Vec<OwnedFastxRecord>,
/// Headers matching `records`, in the same order.
pub headers: Vec<String>,
/// Number of batches received from the reader during this call.
pub rg_count: usize,
/// `true` when the reader signalled the end of the input during this call.
pub reached_end: bool,
}
pub fn accumulate_owned_batches(
reader: &mut PrefetchingParquetReader,
target_batch_size: usize,
) -> Result<TrimmedBatchResult> {
let mut records: Vec<OwnedFastxRecord> = Vec::new();
let mut headers: Vec<String> = Vec::new();
let mut reached_end = false;
let mut rg_count = 0usize;
while records.len() < target_batch_size {
match reader.next_batch()? {
Some(ParquetBatch::Owned(mut batch_records, batch_headers)) => {
rg_count += 1;
let offset = records.len() as i64;
for rec in &mut batch_records {
rec.query_id += offset;
}
records.extend(batch_records);
headers.extend(batch_headers);
}
Some(ParquetBatch::Arrow(..)) => {
unreachable!("Expected Owned variant when trim/filter is active");
}
None => {
reached_end = true;
break;
}
}
}
Ok(TrimmedBatchResult {
records,
headers,
rg_count,
reached_end,
})
}
/// Flattens several `(batch, headers)` pairs into one list of borrowed
/// query records plus one list of borrowed headers.
///
/// Query ids run consecutively across the batches: each batch starts at the
/// total row count of the batches before it.
///
/// # Errors
///
/// Propagates any error from `batch_to_records_parquet_with_offset`.
pub fn stacked_batches_to_records<'a>(
    batches: &'a [(RecordBatch, Vec<String>)],
) -> Result<(Vec<QueryRecord<'a>>, Vec<&'a str>)> {
    let capacity = batches
        .iter()
        .map(|(batch, _)| batch.num_rows())
        .sum::<usize>();
    let mut stacked_records = Vec::with_capacity(capacity);
    let mut stacked_headers = Vec::with_capacity(capacity);

    let mut rows_seen = 0usize;
    for (batch, headers) in batches.iter() {
        stacked_records.extend(batch_to_records_parquet_with_offset(batch, rows_seen)?);
        stacked_headers.extend(headers.iter().map(String::as_str));
        rows_seen += batch.num_rows();
    }

    Ok((stacked_records, stacked_headers))
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::LargeStringArray;
use arrow::datatypes::{DataType, Field, Schema};
use std::path::Path;
use std::sync::Arc;
// Extension matching ignores case and looks only at the final extension.
#[test]
fn test_is_parquet_input() {
assert!(is_parquet_input(Path::new("input.parquet")));
assert!(is_parquet_input(Path::new("input.PARQUET")));
assert!(is_parquet_input(Path::new("/path/to/input.parquet")));
assert!(!is_parquet_input(Path::new("input.fastq")));
assert!(!is_parquet_input(Path::new("input.fasta")));
assert!(!is_parquet_input(Path::new("input.parquet.gz")));
}
// Builds a batch with a single non-nullable LargeUtf8 `sequence1` column.
fn make_test_batch(seqs: &[&str]) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"sequence1",
DataType::LargeUtf8,
false,
)]));
let seq_array = LargeStringArray::from_iter_values(seqs.iter().copied());
RecordBatch::try_new(schema, vec![Arc::new(seq_array)]).unwrap()
}
// Builds a batch with LargeUtf8 `sequence1` and nullable `sequence2` columns.
fn make_test_batch_paired(seqs1: &[&str], seqs2: &[Option<&str>]) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("sequence1", DataType::LargeUtf8, false),
Field::new("sequence2", DataType::LargeUtf8, true),
]));
let seq1_array = LargeStringArray::from_iter_values(seqs1.iter().copied());
let seq2_array = LargeStringArray::from_iter(seqs2.iter().copied());
RecordBatch::try_new(schema, vec![Arc::new(seq1_array), Arc::new(seq2_array)]).unwrap()
}
// Without trim or minimum length, every row is copied unchanged.
#[test]
fn test_batch_to_owned_records_trimmed_no_trim() {
let seqs = vec!["ACGTACGTACGT", "GGGGCCCCAAAA"];
let batch = make_test_batch(&seqs);
let headers = vec!["read1".to_string(), "read2".to_string()];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, None, None, 0).unwrap();
assert_eq!(records.len(), 2);
assert_eq!(out_headers.len(), 2);
assert_eq!(records[0].seq1, b"ACGTACGTACGT");
assert_eq!(records[1].seq1, b"GGGGCCCCAAAA");
assert_eq!(out_headers[0], "read1");
assert_eq!(out_headers[1], "read2");
}
// `trim_to` cuts each sufficiently long read to the requested prefix.
#[test]
fn test_batch_to_owned_records_trimmed_with_trim() {
let seqs = vec!["ACGTACGTACGT", "GGGGCCCCAAAA"];
let batch = make_test_batch(&seqs);
let headers = vec!["read1".to_string(), "read2".to_string()];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(4), None, 0).unwrap();
assert_eq!(records.len(), 2);
assert_eq!(out_headers.len(), 2);
assert_eq!(records[0].seq1, b"ACGT");
assert_eq!(records[1].seq1, b"GGGG");
}
// Reads shorter than `trim_to` are dropped together with their header.
#[test]
fn test_batch_to_owned_records_trimmed_skip_short_reads() {
let seqs = vec!["ACGTACGTACGT", "GGGG"];
let batch = make_test_batch(&seqs);
let headers = vec!["long_read".to_string(), "short_read".to_string()];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(8), None, 0).unwrap();
assert_eq!(records.len(), 1, "Short read should be skipped");
assert_eq!(out_headers.len(), 1);
assert_eq!(records[0].seq1, b"ACGTACGT");
assert_eq!(out_headers[0], "long_read");
}
// Query ids start at `id_offset` and increase by one per kept record.
#[test]
fn test_batch_to_owned_records_trimmed_query_id_with_offset() {
let seqs = vec!["ACGTACGTACGT", "GGGGCCCCAAAA"];
let batch = make_test_batch(&seqs);
let headers = vec!["read1".to_string(), "read2".to_string()];
let (records, _) =
batch_to_owned_records_trimmed(&batch, &headers, None, None, 100).unwrap();
assert_eq!(records[0].query_id, 100, "First query_id should be offset");
assert_eq!(
records[1].query_id, 101,
"Second query_id should be offset+1"
);
}
// Dropped reads leave no gap in the query id sequence.
#[test]
fn test_batch_to_owned_records_trimmed_query_id_with_skipped_reads() {
let seqs = vec!["ACGTACGTACGT", "GG", "TTTTTTTTTTTT"];
let batch = make_test_batch(&seqs);
let headers = vec![
"long1".to_string(),
"short".to_string(),
"long2".to_string(),
];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(8), None, 0).unwrap();
assert_eq!(records.len(), 2);
assert_eq!(records[0].query_id, 0);
assert_eq!(records[1].query_id, 1);
assert_eq!(out_headers[0], "long1");
assert_eq!(out_headers[1], "long2");
}
// With paired data, both mates are trimmed to `trim_to`.
#[test]
fn test_batch_to_owned_records_trimmed_paired_sequences() {
let seqs1 = vec!["ACGTACGTACGT", "GGGGCCCCAAAA"];
let seqs2: Vec<Option<&str>> = vec![Some("TTTTTTTTTTTT"), Some("CCCCCCCCCCCC")];
let batch = make_test_batch_paired(&seqs1, &seqs2);
let headers = vec!["read1".to_string(), "read2".to_string()];
let (records, _) =
batch_to_owned_records_trimmed(&batch, &headers, Some(4), None, 0).unwrap();
assert_eq!(records.len(), 2);
assert_eq!(records[0].seq1, b"ACGT");
assert_eq!(records[0].seq2.as_ref().unwrap(), b"TTTT");
assert_eq!(records[1].seq1, b"GGGG");
assert_eq!(records[1].seq2.as_ref().unwrap(), b"CCCC");
}
// A zero-row batch produces empty outputs.
#[test]
fn test_batch_to_owned_records_trimmed_empty_batch() {
let seqs: Vec<&str> = vec![];
let batch = make_test_batch(&seqs);
let headers: Vec<String> = vec![];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(100), None, 0).unwrap();
assert!(records.is_empty());
assert!(out_headers.is_empty());
}
// When every read is shorter than `trim_to`, nothing is returned.
#[test]
fn test_batch_to_owned_records_trimmed_all_reads_too_short() {
let seqs = vec!["ACGT", "GGGG", "TTTT"];
let batch = make_test_batch(&seqs);
let headers = vec![
"read1".to_string(),
"read2".to_string(),
"read3".to_string(),
];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(100), None, 0).unwrap();
assert!(records.is_empty(), "All reads should be skipped");
assert!(out_headers.is_empty());
}
// Records and headers stay aligned when rows are dropped in between.
#[test]
fn test_batch_to_owned_records_trimmed_records_headers_synchronized() {
let seqs = vec!["ACGTACGTACGT", "GG", "TTTTTTTTTTTT", "AA"];
let batch = make_test_batch(&seqs);
let headers = vec![
"keep1".to_string(),
"skip1".to_string(),
"keep2".to_string(),
"skip2".to_string(),
];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(8), None, 0).unwrap();
assert_eq!(
records.len(),
out_headers.len(),
"Records and headers must be synchronized"
);
assert_eq!(out_headers[0], "keep1");
assert_eq!(out_headers[1], "keep2");
}
// `minimum_length` alone drops shorter reads and keeps the rest untrimmed;
// a read exactly at the threshold is kept.
#[test]
fn test_batch_to_owned_records_with_minimum_length() {
let s0 = "A".repeat(30);
let s1 = "G".repeat(80);
let s2 = "T".repeat(50);
let seqs_ref = vec![s0.as_str(), s1.as_str(), s2.as_str()];
let batch = make_test_batch(&seqs_ref);
let headers = vec![
"short30".to_string(),
"long80".to_string(),
"exact50".to_string(),
];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, None, Some(50), 0).unwrap();
assert_eq!(records.len(), 2, "30bp read should be skipped");
assert_eq!(out_headers.len(), 2);
assert_eq!(records[0].seq1.len(), 80);
assert_eq!(records[1].seq1.len(), 50);
assert_eq!(out_headers[0], "long80");
assert_eq!(out_headers[1], "exact50");
}
// With `minimum_length` < `trim_to`, reads must still reach `trim_to` to be
// kept (the 60bp read is dropped), and survivors are cut to `trim_to`.
#[test]
fn test_batch_to_owned_records_minimum_length_before_trim() {
let s0 = "A".repeat(40);
let s1 = "G".repeat(100);
let s2 = "T".repeat(60);
let s3 = "C".repeat(80);
let seqs_ref = vec![s0.as_str(), s1.as_str(), s2.as_str(), s3.as_str()];
let batch = make_test_batch(&seqs_ref);
let headers = vec![
"short40".to_string(),
"long100".to_string(),
"mid60".to_string(),
"mid80".to_string(),
];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(70), Some(50), 0).unwrap();
assert_eq!(records.len(), 2);
assert_eq!(out_headers[0], "long100");
assert_eq!(out_headers[1], "mid80");
assert_eq!(records[0].seq1.len(), 70, "100bp trimmed to 70");
assert_eq!(records[1].seq1.len(), 70, "80bp trimmed to 70");
}
// The minimum-length filter looks at `sequence1` only: a short mate does
// not cause the pair to be dropped.
#[test]
fn test_batch_to_owned_records_minimum_length_with_paired() {
let s1_short = "A".repeat(30);
let s1_long = "G".repeat(80);
let s2_long = "T".repeat(100);
let s2_short = "C".repeat(20);
let seqs1 = vec![s1_short.as_str(), s1_long.as_str()];
let seqs2: Vec<Option<&str>> = vec![Some(s2_long.as_str()), Some(s2_short.as_str())];
let batch = make_test_batch_paired(&seqs1, &seqs2);
let headers = vec!["pair_short_r1".to_string(), "pair_long_r1".to_string()];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, None, Some(50), 0).unwrap();
assert_eq!(records.len(), 1, "Pair with R1=30bp should be skipped");
assert_eq!(out_headers[0], "pair_long_r1");
assert_eq!(records[0].seq1.len(), 80);
assert_eq!(records[0].seq2.as_ref().unwrap().len(), 20);
}
// With `minimum_length` > `trim_to`, the filter uses the original length
// and the survivors are then trimmed.
#[test]
fn test_batch_to_owned_records_minimum_length_gt_trim_to() {
let s0 = "A".repeat(80); let s1 = "G".repeat(120); let s2 = "T".repeat(100); let s3 = "C".repeat(40); let seqs_ref = vec![s0.as_str(), s1.as_str(), s2.as_str(), s3.as_str()];
let batch = make_test_batch(&seqs_ref);
let headers = vec![
"r80".to_string(),
"r120".to_string(),
"r100".to_string(),
"r40".to_string(),
];
let (records, out_headers) =
batch_to_owned_records_trimmed(&batch, &headers, Some(50), Some(100), 0).unwrap();
assert_eq!(records.len(), 2, "Only reads >= 100bp should survive");
assert_eq!(out_headers[0], "r120");
assert_eq!(out_headers[1], "r100");
assert_eq!(records[0].seq1.len(), 50, "120bp trimmed to 50");
assert_eq!(records[1].seq1.len(), 50, "100bp trimmed to 50");
assert_eq!(records[0].query_id, 0);
assert_eq!(records[1].query_id, 1);
}
// `into_arrow` returns the payload of an `Arrow` batch.
#[test]
fn test_parquet_batch_enum_arrow_into_arrow() {
let seqs = vec!["ACGTACGT"];
let batch = make_test_batch(&seqs);
let headers = vec!["read1".to_string()];
let pb = ParquetBatch::Arrow(batch, headers);
let (record_batch, hdrs) = pb.into_arrow();
assert_eq!(record_batch.num_rows(), 1);
assert_eq!(hdrs.len(), 1);
assert_eq!(hdrs[0], "read1");
}
// `into_owned` returns the payload of an `Owned` batch.
#[test]
fn test_parquet_batch_enum_owned_into_owned() {
let records = vec![OwnedFastxRecord::new(
0,
b"ACGTACGT".to_vec(),
None,
None,
None,
)];
let headers = vec!["read1".to_string()];
let pb = ParquetBatch::Owned(records, headers);
let (owned_records, hdrs) = pb.into_owned();
assert_eq!(owned_records.len(), 1);
assert_eq!(owned_records[0].seq1, b"ACGTACGT");
assert_eq!(hdrs.len(), 1);
assert_eq!(hdrs[0], "read1");
}
// Calling `into_arrow` on an `Owned` batch panics with a fixed message.
#[test]
#[should_panic(expected = "Expected ParquetBatch::Arrow but got Owned")]
fn test_parquet_batch_enum_owned_into_arrow_panics() {
let records = vec![OwnedFastxRecord::new(0, b"ACGT".to_vec(), None, None, None)];
let headers = vec!["read1".to_string()];
let pb = ParquetBatch::Owned(records, headers);
let _ = pb.into_arrow(); }
// Calling `into_owned` on an `Arrow` batch panics with a fixed message.
#[test]
#[should_panic(expected = "Expected ParquetBatch::Owned but got Arrow")]
fn test_parquet_batch_enum_arrow_into_owned_panics() {
let seqs = vec!["ACGT"];
let batch = make_test_batch(&seqs);
let headers = vec!["read1".to_string()];
let pb = ParquetBatch::Arrow(batch, headers);
let _ = pb.into_owned(); }
use parquet::arrow::ArrowWriter;
use tempfile::tempdir;
// Writes `test.parquet` into `dir` with LargeUtf8 `read_id` and `sequence1`
// columns, all rows in a single written batch, and returns its path.
fn write_test_parquet(dir: &std::path::Path, ids: &[&str], seqs: &[&str]) -> PathBuf {
let path = dir.join("test.parquet");
let schema = Arc::new(arrow::datatypes::Schema::new(vec![
Field::new("read_id", DataType::LargeUtf8, false),
Field::new("sequence1", DataType::LargeUtf8, false),
]));
let file = File::create(&path).unwrap();
let mut writer = ArrowWriter::try_new(file, schema.clone(), None).unwrap();
let id_array = LargeStringArray::from_iter_values(ids.iter().copied());
let seq_array = LargeStringArray::from_iter_values(seqs.iter().copied());
let batch =
RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
path
}
// Sequential reader with trim/filter set: batches arrive as `Owned`,
// already filtered and trimmed.
#[test]
fn test_prefetching_parquet_reader_trims_in_reader_thread() {
let dir = tempdir().unwrap();
let s0 = "A".repeat(30);
let s1 = "G".repeat(100);
let s2 = "T".repeat(60);
let path = write_test_parquet(dir.path(), &["r0", "r1", "r2"], &[&s0, &s1, &s2]);
let mut reader = PrefetchingParquetReader::with_parallel_row_groups(
&path,
1000,
None, Some(50),
Some(40),
)
.unwrap();
let mut all_records = Vec::new();
let mut all_headers = Vec::new();
while let Some(batch) = reader.next_batch().unwrap() {
let (records, headers) = batch.into_owned();
all_records.extend(records);
all_headers.extend(headers);
}
assert_eq!(all_records.len(), 2, "30bp read should be filtered out");
assert_eq!(all_headers.len(), 2);
assert_eq!(all_headers[0], "r1");
assert_eq!(all_headers[1], "r2");
assert_eq!(
all_records[0].seq1.len(),
50,
"100bp should be trimmed to 50"
);
assert_eq!(
all_records[1].seq1.len(),
50,
"60bp should be trimmed to 50"
);
reader.finish().unwrap();
}
// Same scenario as the sequential trim test, through the row-group-parallel
// reader (parallelism 2).
#[test]
fn test_prefetching_parquet_reader_parallel_trims_in_reader_thread() {
let dir = tempdir().unwrap();
let s0 = "A".repeat(30);
let s1 = "G".repeat(100);
let s2 = "T".repeat(60);
let path = write_test_parquet(dir.path(), &["r0", "r1", "r2"], &[&s0, &s1, &s2]);
let mut reader = PrefetchingParquetReader::with_parallel_row_groups(
&path,
1000,
Some(2), Some(50),
Some(40),
)
.unwrap();
let mut all_records = Vec::new();
let mut all_headers = Vec::new();
while let Some(batch) = reader.next_batch().unwrap() {
let (records, headers) = batch.into_owned();
all_records.extend(records);
all_headers.extend(headers);
}
assert_eq!(all_records.len(), 2, "30bp read should be filtered out");
assert_eq!(all_headers.len(), 2);
assert_eq!(all_headers[0], "r1");
assert_eq!(all_headers[1], "r2");
assert_eq!(all_records[0].seq1.len(), 50);
assert_eq!(all_records[1].seq1.len(), 50);
reader.finish().unwrap();
}
// Without trim/filter, the reader passes Arrow batches through with one
// header per row.
#[test]
fn test_prefetching_parquet_reader_no_filter_returns_arrow() {
let dir = tempdir().unwrap();
let path = write_test_parquet(
dir.path(),
&["r0", "r1", "r2"],
&["ACGTACGT", "GGGGCCCC", "TTTTAAAA"],
);
let mut reader =
PrefetchingParquetReader::with_parallel_row_groups(&path, 1000, None, None, None)
.unwrap();
let mut total_rows = 0;
while let Some(batch) = reader.next_batch().unwrap() {
let (record_batch, headers) = batch.into_arrow();
total_rows += record_batch.num_rows();
assert_eq!(record_batch.num_rows(), headers.len());
}
assert_eq!(total_rows, 3);
reader.finish().unwrap();
}
}