use arrow::array::{
Array, BinaryArray, Int64Array, LargeBinaryArray, LargeStringArray, StringArray,
};
use arrow::record_batch::RecordBatch;
use super::error::{ArrowClassifyError, MAX_SEQUENCE_LENGTH};
use super::schema::{validate_input_schema, COL_ID, COL_PAIR_SEQUENCE, COL_SEQUENCE};
use crate::QueryRecord;
/// Uniform, type-erased byte access over the Arrow array flavours that can
/// hold a sequence (`Binary`, `LargeBinary`, `Utf8`, `LargeUtf8`).
///
/// The trait lifetime `'a` is the lifetime of the borrowed array, so the slice
/// returned by [`value_at`](BinaryArrayAccess::value_at) is not tied to the
/// (usually short-lived) `&self` borrow of the accessor.
trait BinaryArrayAccess<'a>
where
    Self: Sized,
{
    /// Returns `true` when the slot at index `i` is null.
    fn is_null_at(&self, i: usize) -> bool;

    /// Returns the bytes stored at index `i`, borrowed for `'a`.
    fn value_at(&self, i: usize) -> &'a [u8];
}
impl<'a> BinaryArrayAccess<'a> for &'a BinaryArray {
    /// Null check delegated to the wrapped `BinaryArray`.
    #[inline]
    fn is_null_at(&self, i: usize) -> bool {
        let array: &'a BinaryArray = *self;
        array.is_null(i)
    }

    /// Slice borrowed straight from the array's data (no copy).
    #[inline]
    fn value_at(&self, i: usize) -> &'a [u8] {
        let array: &'a BinaryArray = *self;
        array.value(i)
    }
}
impl<'a> BinaryArrayAccess<'a> for &'a LargeBinaryArray {
    /// Null check delegated to the wrapped `LargeBinaryArray`.
    #[inline]
    fn is_null_at(&self, i: usize) -> bool {
        let array: &'a LargeBinaryArray = *self;
        array.is_null(i)
    }

    /// Slice borrowed straight from the array's data (no copy).
    #[inline]
    fn value_at(&self, i: usize) -> &'a [u8] {
        let array: &'a LargeBinaryArray = *self;
        array.value(i)
    }
}
impl<'a> BinaryArrayAccess<'a> for &'a StringArray {
    /// Null check delegated to the wrapped `StringArray`.
    #[inline]
    fn is_null_at(&self, i: usize) -> bool {
        let array: &'a StringArray = *self;
        array.is_null(i)
    }

    /// UTF-8 bytes of the string at `i`, borrowed from the array (no copy).
    #[inline]
    fn value_at(&self, i: usize) -> &'a [u8] {
        let array: &'a StringArray = *self;
        let text: &'a str = array.value(i);
        text.as_bytes()
    }
}
impl<'a> BinaryArrayAccess<'a> for &'a LargeStringArray {
    /// Null check delegated to the wrapped `LargeStringArray`.
    #[inline]
    fn is_null_at(&self, i: usize) -> bool {
        let array: &'a LargeStringArray = *self;
        array.is_null(i)
    }

    /// UTF-8 bytes of the string at `i`, borrowed from the array (no copy).
    #[inline]
    fn value_at(&self, i: usize) -> &'a [u8] {
        let array: &'a LargeStringArray = *self;
        let text: &'a str = array.value(i);
        text.as_bytes()
    }
}
/// Borrowed handle to a sequence column, whichever of the four accepted Arrow
/// array types (`Binary`, `LargeBinary`, `Utf8`, `LargeUtf8`) it turned out to be.
enum BinaryColumnRef<'a> {
    /// A `Binary` column.
    Binary(&'a BinaryArray),
    /// A `Utf8` column, read through its UTF-8 bytes.
    String(&'a StringArray),
    /// A `LargeBinary` column.
    LargeBinary(&'a LargeBinaryArray),
    /// A `LargeUtf8` column, read through its UTF-8 bytes.
    LargeString(&'a LargeStringArray),
}
impl<'a> BinaryColumnRef<'a> {
#[inline]
fn value(&self, i: usize) -> &'a [u8] {
match self {
BinaryColumnRef::Binary(arr) => arr.value_at(i),
BinaryColumnRef::LargeBinary(arr) => arr.value_at(i),
BinaryColumnRef::String(arr) => arr.value_at(i),
BinaryColumnRef::LargeString(arr) => arr.value_at(i),
}
}
#[inline]
fn is_null(&self, i: usize) -> bool {
match self {
BinaryColumnRef::Binary(arr) => arr.is_null_at(i),
BinaryColumnRef::LargeBinary(arr) => arr.is_null_at(i),
BinaryColumnRef::String(arr) => arr.is_null_at(i),
BinaryColumnRef::LargeString(arr) => arr.is_null_at(i),
}
}
}
/// Resolves column `col_idx` of `batch` to a [`BinaryColumnRef`].
///
/// The downcasts are tried in a fixed order (`Binary`, `LargeBinary`, `Utf8`,
/// `LargeUtf8`); the first that succeeds wins.
///
/// # Errors
///
/// Returns [`ArrowClassifyError::TypeError`] naming the schema field when the
/// column is none of the four accepted types.
fn get_binary_column(
    batch: &RecordBatch,
    col_idx: usize,
) -> Result<BinaryColumnRef<'_>, ArrowClassifyError> {
    let column = batch.column(col_idx);
    let any = column.as_any();

    any.downcast_ref::<BinaryArray>()
        .map(BinaryColumnRef::Binary)
        .or_else(move || {
            any.downcast_ref::<LargeBinaryArray>()
                .map(BinaryColumnRef::LargeBinary)
        })
        .or_else(move || any.downcast_ref::<StringArray>().map(BinaryColumnRef::String))
        .or_else(move || {
            any.downcast_ref::<LargeStringArray>()
                .map(BinaryColumnRef::LargeString)
        })
        .ok_or_else(|| {
            // Only build the schema-backed error when every downcast failed.
            let schema = batch.schema();
            ArrowClassifyError::TypeError {
                column: schema.field(col_idx).name().clone(),
                expected: "Binary, LargeBinary, Utf8, or LargeUtf8".into(),
                actual: format!("{:?}", column.data_type()),
            }
        })
}
/// Converts a batch that uses the crate's standard column names into borrowed
/// query records.
///
/// The schema is first checked with `validate_input_schema`. Conversion then
/// reads `COL_ID` and `COL_SEQUENCE`, plus `COL_PAIR_SEQUENCE` when the batch
/// has it.
///
/// # Errors
///
/// Propagates schema-validation errors and any per-row conversion error from
/// [`batch_to_records_with_columns`].
pub fn batch_to_records(batch: &RecordBatch) -> Result<Vec<QueryRecord<'_>>, ArrowClassifyError> {
    let schema = batch.schema();
    validate_input_schema(&*schema)?;
    let pair_column = Some(COL_PAIR_SEQUENCE);
    batch_to_records_with_columns(batch, COL_ID, COL_SEQUENCE, pair_column)
}
/// Converts a [`RecordBatch`] into borrowed query records using caller-chosen
/// column names.
///
/// Every returned record borrows its sequence bytes directly from the batch's
/// Arrow data; nothing is copied, so the records cannot outlive `batch`.
///
/// * `id_col`: name of an `Int64` column holding read ids. Nulls are rejected.
/// * `seq_col`: name of a `Binary`, `LargeBinary`, `Utf8` or `LargeUtf8` column
///   holding the primary sequence. Nulls are rejected.
/// * `pair_col`: optional name of a mate-sequence column with the same accepted
///   types.
///   - If the name is `None` or the column is absent from the schema, every
///     record is single-end.
///   - Null entries in a present pair column yield `None` for that row.
///
/// Columns are resolved and type-checked before any row is read. A mis-shaped
/// schema is therefore reported even when the batch is empty.
///
/// # Errors
///
/// * [`ArrowClassifyError::ColumnNotFound`] if `id_col` or `seq_col` is missing.
/// * [`ArrowClassifyError::TypeError`] if a resolved column has an unsupported type.
/// * [`ArrowClassifyError::NullError`] for a null id or primary sequence.
/// * [`ArrowClassifyError::SequenceTooLong`] if a primary or pair sequence is
///   longer than `MAX_SEQUENCE_LENGTH` bytes.
pub fn batch_to_records_with_columns<'a>(
    batch: &'a RecordBatch,
    id_col: &str,
    seq_col: &str,
    pair_col: Option<&str>,
) -> Result<Vec<QueryRecord<'a>>, ArrowClassifyError> {
    // Deliberately no early return for empty batches. Resolving columns first
    // keeps schema errors independent of how many rows happen to be present.
    let schema = batch.schema();
    let id_idx = schema
        .index_of(id_col)
        .map_err(|_| ArrowClassifyError::ColumnNotFound(id_col.into()))?;
    let seq_idx = schema
        .index_of(seq_col)
        .map_err(|_| ArrowClassifyError::ColumnNotFound(seq_col.into()))?;
    // The pair column is optional. An unknown name simply means single-end,
    // which lets `batch_to_records` always ask for it.
    let pair_target = pair_col.and_then(|col| schema.index_of(col).ok().map(|idx| (col, idx)));

    let id_column = batch.column(id_idx);
    let ids = id_column
        .as_any()
        .downcast_ref::<Int64Array>()
        .ok_or_else(|| ArrowClassifyError::TypeError {
            column: id_col.into(),
            expected: "Int64".into(),
            actual: format!("{:?}", id_column.data_type()),
        })?;
    let seqs = get_binary_column(batch, seq_idx).map_err(|e| relabel_type_error(e, seq_col))?;
    let pairs = pair_target
        .map(|(col, idx)| get_binary_column(batch, idx).map_err(|e| relabel_type_error(e, col)))
        .transpose()?;

    let num_rows = batch.num_rows();
    let mut records = Vec::with_capacity(num_rows);
    for row in 0..num_rows {
        // Check order (id null, seq null, seq length, pair length) fixes which
        // error wins when a row has several problems.
        if ids.is_null(row) {
            return Err(ArrowClassifyError::NullError {
                column: id_col.into(),
                row,
            });
        }
        let id = ids.value(row);
        if seqs.is_null(row) {
            return Err(ArrowClassifyError::NullError {
                column: seq_col.into(),
                row,
            });
        }
        let seq: &'a [u8] = check_sequence_length(row, seqs.value(row))?;
        let pair: Option<&'a [u8]> = match pairs.as_ref() {
            Some(p) if !p.is_null(row) => Some(check_sequence_length(row, p.value(row))?),
            _ => None,
        };
        records.push((id, seq, pair));
    }
    Ok(records)
}

/// Replaces the `column` of a [`ArrowClassifyError::TypeError`] with `column`.
/// Every other error passes through untouched.
fn relabel_type_error(err: ArrowClassifyError, column: &str) -> ArrowClassifyError {
    match err {
        ArrowClassifyError::TypeError {
            expected, actual, ..
        } => ArrowClassifyError::TypeError {
            column: column.into(),
            expected,
            actual,
        },
        other => other,
    }
}

/// Returns `seq` unchanged. Fails with [`ArrowClassifyError::SequenceTooLong`]
/// for `row` when `seq` is longer than `MAX_SEQUENCE_LENGTH` bytes.
fn check_sequence_length(row: usize, seq: &[u8]) -> Result<&[u8], ArrowClassifyError> {
    if seq.len() > MAX_SEQUENCE_LENGTH {
        return Err(ArrowClassifyError::SequenceTooLong {
            row,
            length: seq.len(),
            max_length: MAX_SEQUENCE_LENGTH,
        });
    }
    Ok(seq)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{BinaryArray, Int64Array, LargeBinaryArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    /// Builds a single-end batch with the standard `COL_ID` / `COL_SEQUENCE` columns.
    fn make_test_batch(ids: &[i64], seqs: &[&[u8]]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]));
        let id_array = Int64Array::from(ids.to_vec());
        let seq_array = BinaryArray::from_iter_values(seqs.iter().copied());
        RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap()
    }

    /// Builds a paired-end batch with the standard column names; `pairs` may contain nulls.
    fn make_test_batch_paired(ids: &[i64], seqs: &[&[u8]], pairs: &[Option<&[u8]>]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
            Field::new(COL_PAIR_SEQUENCE, DataType::Binary, true),
        ]));
        let id_array = Int64Array::from(ids.to_vec());
        let seq_array = BinaryArray::from_iter_values(seqs.iter().copied());
        let pair_array = BinaryArray::from_iter(pairs.iter().copied());
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(seq_array),
                Arc::new(pair_array),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_batch_to_records_single_end() {
        let batch = make_test_batch(&[1, 2, 3], &[b"ACGT", b"TGCA", b"GGCC"]);
        let records = batch_to_records(&batch).unwrap();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].0, 1);
        assert_eq!(records[0].1, b"ACGT");
        assert!(records[0].2.is_none());
        assert_eq!(records[1].0, 2);
        assert_eq!(records[1].1, b"TGCA");
        assert_eq!(records[2].0, 3);
        assert_eq!(records[2].1, b"GGCC");
    }

    #[test]
    fn test_batch_to_records_paired_end() {
        let batch = make_test_batch_paired(
            &[1, 2],
            &[b"ACGT", b"TGCA"],
            &[Some(b"AAAA" as &[u8]), Some(b"TTTT")],
        );
        let records = batch_to_records(&batch).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].0, 1);
        assert_eq!(records[0].1, b"ACGT");
        assert_eq!(records[0].2, Some(b"AAAA" as &[u8]));
        assert_eq!(records[1].0, 2);
        assert_eq!(records[1].1, b"TGCA");
        assert_eq!(records[1].2, Some(b"TTTT" as &[u8]));
    }

    #[test]
    fn test_batch_to_records_null_pair_handling() {
        let batch = make_test_batch_paired(
            &[1, 2, 3],
            &[b"ACGT", b"TGCA", b"GGCC"],
            &[Some(b"AAAA" as &[u8]), None, Some(b"CCCC")],
        );
        let records = batch_to_records(&batch).unwrap();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].2, Some(b"AAAA" as &[u8]));
        assert!(records[1].2.is_none());
        assert_eq!(records[2].2, Some(b"CCCC" as &[u8]));
    }

    #[test]
    fn test_batch_to_records_large_binary_array() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::LargeBinary, false),
        ]));
        let id_array = Int64Array::from(vec![1, 2]);
        let seq_array = LargeBinaryArray::from_iter_values([b"ACGT" as &[u8], b"TGCA"]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let records = batch_to_records(&batch).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].1, b"ACGT");
        assert_eq!(records[1].1, b"TGCA");
    }

    #[test]
    fn test_batch_to_records_empty_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]));
        let batch = RecordBatch::new_empty(schema);
        let records = batch_to_records(&batch).unwrap();
        assert!(records.is_empty());
    }

    #[test]
    fn test_batch_to_records_zero_copy_verification() {
        let seq1 = b"ACGTACGTACGT";
        let seq2 = b"TGCATGCATGCA";
        let batch = make_test_batch(&[1, 2], &[seq1 as &[u8], seq2]);
        let records = batch_to_records(&batch).unwrap();
        let record_ptr_1 = records[0].1.as_ptr();
        let record_ptr_2 = records[1].1.as_ptr();
        let seq_col = batch.column(1);
        let binary_arr = seq_col.as_any().downcast_ref::<BinaryArray>().unwrap();
        let arrow_ptr_1 = binary_arr.value(0).as_ptr();
        let arrow_ptr_2 = binary_arr.value(1).as_ptr();
        assert_eq!(
            record_ptr_1, arrow_ptr_1,
            "Record should point directly into Arrow buffer"
        );
        assert_eq!(
            record_ptr_2, arrow_ptr_2,
            "Record should point directly into Arrow buffer"
        );
    }

    #[test]
    fn test_batch_to_records_null_id_error() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, true),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]));
        let id_array = Int64Array::from(vec![Some(1), None, Some(3)]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8], b"TGCA", b"GGCC"]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let result = batch_to_records(&batch);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ArrowClassifyError::NullError { column, row } if column == COL_ID && row == 1
        ));
    }

    #[test]
    fn test_batch_to_records_null_sequence_error() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, true),
        ]));
        let id_array = Int64Array::from(vec![1, 2, 3]);
        let seq_array = BinaryArray::from_iter([
            Some(b"ACGT" as &[u8]),
            Some(b"TGCA"),
            None,
        ]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let result = batch_to_records(&batch);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ArrowClassifyError::NullError { column, row } if column == COL_SEQUENCE && row == 2
        ));
    }

    #[test]
    fn test_batch_to_records_with_columns_custom_names() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("read_id", DataType::Int64, false),
            Field::new("sequence1", DataType::Binary, false),
            Field::new("sequence2", DataType::Binary, true),
        ]));
        let id_array = Int64Array::from(vec![1, 2]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8], b"TGCA"]);
        let pair_array = BinaryArray::from_iter([Some(b"AAAA" as &[u8]), Some(b"TTTT")]);
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(seq_array),
                Arc::new(pair_array),
            ],
        )
        .unwrap();
        let records =
            batch_to_records_with_columns(&batch, "read_id", "sequence1", Some("sequence2"))
                .unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].0, 1);
        assert_eq!(records[0].1, b"ACGT");
        assert_eq!(records[0].2, Some(b"AAAA" as &[u8]));
        assert_eq!(records[1].0, 2);
        assert_eq!(records[1].1, b"TGCA");
        assert_eq!(records[1].2, Some(b"TTTT" as &[u8]));
    }

    #[test]
    fn test_batch_to_records_with_columns_single_end() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("my_id", DataType::Int64, false),
            Field::new("my_seq", DataType::Binary, false),
        ]));
        let id_array = Int64Array::from(vec![1, 2, 3]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8], b"TGCA", b"GGCC"]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let records = batch_to_records_with_columns(&batch, "my_id", "my_seq", None).unwrap();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].0, 1);
        assert_eq!(records[0].1, b"ACGT");
        assert!(records[0].2.is_none());
        assert_eq!(records[2].0, 3);
        assert_eq!(records[2].1, b"GGCC");
    }

    #[test]
    fn test_batch_to_records_with_columns_absent_pair_is_single_end() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::Binary, false),
        ]));
        let id_array = Int64Array::from(vec![1, 2]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8], b"TGCA"]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        // A requested pair column that the schema lacks is not an error.
        let records = batch_to_records_with_columns(&batch, "id", "seq", Some("mate")).unwrap();
        assert_eq!(records.len(), 2);
        assert!(records.iter().all(|r| r.2.is_none()));
    }

    #[test]
    fn test_batch_to_records_with_columns_empty_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::Binary, false),
        ]));
        let batch = RecordBatch::new_empty(schema);
        let records = batch_to_records_with_columns(&batch, "id", "seq", None).unwrap();
        assert!(records.is_empty());
    }

    #[test]
    fn test_batch_to_records_with_columns_zero_copy() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id_col", DataType::Int64, false),
            Field::new("seq_col", DataType::Binary, false),
        ]));
        let id_array = Int64Array::from(vec![1, 2]);
        let seq_array = BinaryArray::from_iter_values([b"ACGTACGTACGT" as &[u8], b"TGCATGCATGCA"]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let records = batch_to_records_with_columns(&batch, "id_col", "seq_col", None).unwrap();
        let record_ptr = records[0].1.as_ptr();
        let seq_col = batch.column(1);
        let binary_arr = seq_col.as_any().downcast_ref::<BinaryArray>().unwrap();
        let arrow_ptr = binary_arr.value(0).as_ptr();
        assert_eq!(
            record_ptr, arrow_ptr,
            "Record should point directly into Arrow buffer"
        );
    }

    #[test]
    fn test_batch_to_records_with_columns_missing_id() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("wrong_id", DataType::Int64, false),
            Field::new("sequence", DataType::Binary, false),
        ]));
        let id_array = Int64Array::from(vec![1]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8]]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let result = batch_to_records_with_columns(&batch, "id", "sequence", None);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ArrowClassifyError::ColumnNotFound(col) if col == "id"
        ));
    }

    #[test]
    fn test_batch_to_records_with_columns_missing_sequence() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("wrong_seq", DataType::Binary, false),
        ]));
        let id_array = Int64Array::from(vec![1]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8]]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let result = batch_to_records_with_columns(&batch, "id", "sequence", None);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ArrowClassifyError::ColumnNotFound(col) if col == "sequence"
        ));
    }

    #[test]
    fn test_batch_to_records_with_columns_wrong_id_type() {
        use arrow::array::StringArray;
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("sequence", DataType::Binary, false),
        ]));
        let id_array = StringArray::from(vec!["1"]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8]]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let result = batch_to_records_with_columns(&batch, "id", "sequence", None);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ArrowClassifyError::TypeError { column, .. } if column == "id"
        ));
    }

    #[test]
    fn test_batch_to_records_with_columns_wrong_sequence_type() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::Int64, false),
        ]));
        let id_array = Int64Array::from(vec![1]);
        let seq_array = Int64Array::from(vec![42]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let result = batch_to_records_with_columns(&batch, "id", "seq", None);
        assert!(matches!(
            result.unwrap_err(),
            ArrowClassifyError::TypeError { column, .. } if column == "seq"
        ));
    }

    #[test]
    fn test_batch_to_records_with_columns_wrong_pair_type() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::Binary, false),
            Field::new("pair", DataType::Int64, true),
        ]));
        let id_array = Int64Array::from(vec![1]);
        let seq_array = BinaryArray::from_iter_values([b"ACGT" as &[u8]]);
        let pair_array = Int64Array::from(vec![Some(7)]);
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(seq_array),
                Arc::new(pair_array),
            ],
        )
        .unwrap();
        let result = batch_to_records_with_columns(&batch, "id", "seq", Some("pair"));
        assert!(matches!(
            result.unwrap_err(),
            ArrowClassifyError::TypeError { column, .. } if column == "pair"
        ));
    }

    #[test]
    fn test_batch_to_records_with_columns_large_binary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::LargeBinary, false),
            Field::new("pair", DataType::LargeBinary, true),
        ]));
        let id_array = Int64Array::from(vec![1, 2]);
        let seq_array = LargeBinaryArray::from_iter_values([b"ACGT" as &[u8], b"TGCA"]);
        let pair_array = LargeBinaryArray::from_iter([Some(b"AAAA" as &[u8]), None]);
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(seq_array),
                Arc::new(pair_array),
            ],
        )
        .unwrap();
        let records = batch_to_records_with_columns(&batch, "id", "seq", Some("pair")).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].1, b"ACGT");
        assert_eq!(records[0].2, Some(b"AAAA" as &[u8]));
        assert_eq!(records[1].1, b"TGCA");
        assert!(records[1].2.is_none());
    }

    #[test]
    fn test_batch_to_records_with_columns_string_arrays() {
        use arrow::array::StringArray;
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::Utf8, false),
            Field::new("pair", DataType::Utf8, true),
        ]));
        let id_array = Int64Array::from(vec![1, 2]);
        let seq_array = StringArray::from(vec!["ACGT", "TGCA"]);
        let pair_array = StringArray::from(vec![Some("AAAA"), None]);
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(seq_array),
                Arc::new(pair_array),
            ],
        )
        .unwrap();
        let records = batch_to_records_with_columns(&batch, "id", "seq", Some("pair")).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].0, 1);
        assert_eq!(records[0].1, b"ACGT");
        assert_eq!(records[0].2, Some(b"AAAA" as &[u8]));
        assert_eq!(records[1].0, 2);
        assert_eq!(records[1].1, b"TGCA");
        assert!(records[1].2.is_none());
    }

    #[test]
    fn test_batch_to_records_with_columns_string_zero_copy() {
        use arrow::array::StringArray;
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::Utf8, false),
        ]));
        let id_array = Int64Array::from(vec![1]);
        let seq_array = StringArray::from(vec!["ACGTACGTACGT"]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let records = batch_to_records_with_columns(&batch, "id", "seq", None).unwrap();
        let record_ptr = records[0].1.as_ptr();
        let seq_col = batch.column(1);
        let string_arr = seq_col.as_any().downcast_ref::<StringArray>().unwrap();
        let arrow_ptr = string_arr.value(0).as_bytes().as_ptr();
        assert_eq!(
            record_ptr, arrow_ptr,
            "Record should point directly into Arrow buffer"
        );
    }

    #[test]
    fn test_batch_to_records_with_columns_large_string_arrays() {
        use arrow::array::LargeStringArray;
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("seq", DataType::LargeUtf8, false),
        ]));
        let id_array = Int64Array::from(vec![1, 2]);
        let seq_array = LargeStringArray::from(vec!["ACGT", "TGCA"]);
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(seq_array)]).unwrap();
        let records = batch_to_records_with_columns(&batch, "id", "seq", None).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].1, b"ACGT");
        assert_eq!(records[1].1, b"TGCA");
    }

    #[test]
    fn test_batch_to_records_invalid_schema() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Utf8, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]));
        let batch = RecordBatch::new_empty(schema);
        let result = batch_to_records(&batch);
        assert!(result.is_err());
    }
}