mod error;
pub mod extraction;
mod input;
mod output;
mod schema;
mod stream;
pub use error::ArrowClassifyError;
pub use extraction::{
extract_minimizer_set_batch, extract_strand_minimizers_batch, minimizer_set_schema,
strand_minimizers_schema,
};
pub use input::{batch_to_records, batch_to_records_with_columns};
pub use output::{empty_result_batch, hits_to_record_batch};
pub use schema::{
log_ratio_result_schema, result_schema, validate_input_schema, COL_BUCKET_ID, COL_FAST_PATH,
COL_ID, COL_LOG_RATIO, COL_PAIR_SEQUENCE, COL_QUERY_ID, COL_SCORE, COL_SEQUENCE,
};
pub use stream::ShardedStreamClassifier;
use std::collections::HashSet;
use arrow::record_batch::RecordBatch;
use crate::{classify_batch_sharded_merge_join, ShardedInvertedIndex};
/// Classifies every query sequence in an Arrow [`RecordBatch`] against a sharded inverted
/// index and returns all hits produced by the merge-join classifier as a new `RecordBatch`.
///
/// # Arguments
/// * `sharded` - the sharded inverted index to classify against.
/// * `negative_mins` - optional minimizer set forwarded unchanged to the classifier.
///   NOTE(review): presumably minimizers to exclude from scoring — confirm in
///   `classify_batch_sharded_merge_join`.
/// * `batch` - input rows, decoded with [`batch_to_records`].
/// * `threshold` - score threshold forwarded unchanged to the classifier.
///
/// # Errors
/// Returns [`ArrowClassifyError`] when the input batch cannot be decoded, when
/// classification fails (surfaced as `ArrowClassifyError::Classification`), or when the
/// hits cannot be encoded by [`hits_to_record_batch`].
///
/// Use [`classify_arrow_batch_sharded_best_hit`] to keep only the best hits instead.
pub fn classify_arrow_batch_sharded(
    sharded: &ShardedInvertedIndex,
    negative_mins: Option<&HashSet<u64>>,
    batch: &RecordBatch,
    threshold: f64,
) -> Result<RecordBatch, ArrowClassifyError> {
    // Keep every hit; no best-hit filtering on this entry point.
    let best_hit_only = false;
    classify_arrow_batch_sharded_internal(sharded, negative_mins, batch, threshold, best_hit_only)
}
/// Like [`classify_arrow_batch_sharded`], but the classifier's hits are passed through
/// `crate::classify::filter_best_hits` before being encoded into the output
/// [`RecordBatch`].
///
/// # Arguments
/// * `sharded` - the sharded inverted index to classify against.
/// * `negative_mins` - optional minimizer set forwarded unchanged to the classifier.
/// * `batch` - input rows, decoded with [`batch_to_records`].
/// * `threshold` - score threshold forwarded unchanged to the classifier.
///
/// # Errors
/// Fails with [`ArrowClassifyError`] under the same conditions as
/// [`classify_arrow_batch_sharded`]: input decoding, classification, or output encoding.
pub fn classify_arrow_batch_sharded_best_hit(
    sharded: &ShardedInvertedIndex,
    negative_mins: Option<&HashSet<u64>>,
    batch: &RecordBatch,
    threshold: f64,
) -> Result<RecordBatch, ArrowClassifyError> {
    // Same pipeline as the all-hits entry point; only the filtering flag differs.
    let best_hit_only = true;
    classify_arrow_batch_sharded_internal(sharded, negative_mins, batch, threshold, best_hit_only)
}
/// Shared implementation behind the public Arrow classification entry points.
///
/// The pipeline has four steps:
/// 1. Decode `batch` into query records.
/// 2. Run `classify_batch_sharded_merge_join`.
/// 3. Optionally reduce the hits with `crate::classify::filter_best_hits`.
/// 4. Encode the hits back into a `RecordBatch`.
///
/// Classifier errors are stringified into `ArrowClassifyError::Classification`, so the
/// original error type does not leak into this module's API.
fn classify_arrow_batch_sharded_internal(
    sharded: &ShardedInvertedIndex,
    negative_mins: Option<&HashSet<u64>>,
    batch: &RecordBatch,
    threshold: f64,
    best_hit_only: bool,
) -> Result<RecordBatch, ArrowClassifyError> {
    let queries = batch_to_records(batch)?;

    // The trailing `None` is passed through as in the original call site.
    // NOTE(review): confirm its meaning in `classify_batch_sharded_merge_join`.
    let mut hits =
        classify_batch_sharded_merge_join(sharded, negative_mins, &queries, threshold, None)
            .map_err(|err| ArrowClassifyError::Classification(err.to_string()))?;

    if best_hit_only {
        hits = crate::classify::filter_best_hits(hits);
    }

    hits_to_record_batch(hits)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        create_parquet_inverted_index, extract_into, BucketData, MinimizerWorkspace,
        ParquetWriteOptions,
    };
    use arrow::array::{BinaryArray, Float64Array, Int64Array, UInt32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::tempdir;

    /// Builds a deterministic, periodic `ACGT` sequence of length `len`;
    /// `seed` rotates the starting base.
    fn generate_sequence(len: usize, seed: u8) -> Vec<u8> {
        let bases = [b'A', b'C', b'G', b'T'];
        (0..len).map(|i| bases[(i + seed as usize) % 4]).collect()
    }

    /// Well-formed input schema: a non-null `Int64` id column and a non-null `Binary`
    /// sequence column.
    fn input_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]))
    }

    /// Builds an input batch pairing `ids[i]` with `seqs[i]`.
    fn make_test_batch(ids: &[i64], seqs: &[&[u8]]) -> RecordBatch {
        assert_eq!(ids.len(), seqs.len(), "ids and seqs must have equal length");
        let id_array = Int64Array::from(ids.to_vec());
        let seq_array = BinaryArray::from_iter_values(seqs.iter().copied());
        RecordBatch::try_new(input_schema(), vec![Arc::new(id_array), Arc::new(seq_array)])
            .unwrap()
    }

    /// Writes a single-bucket Parquet inverted index built from `generate_sequence(100, 0)`
    /// and opens it.
    ///
    /// The returned `TempDir` must be kept alive for as long as the index is used, since
    /// dropping it deletes the index files.
    fn create_test_parquet_index() -> (tempfile::TempDir, ShardedInvertedIndex) {
        let dir = tempdir().unwrap();
        let index_path = dir.path().join("test.ryxdi");

        // The (16, 5, 0x12345) parameters passed to `extract_into` must match the ones
        // passed to `create_parquet_inverted_index` below.
        let mut ws = MinimizerWorkspace::new();
        let ref_seq = generate_sequence(100, 0);
        extract_into(&ref_seq, 16, 5, 0x12345, &mut ws);
        let mut mins: Vec<u64> = ws.buffer.drain(..).collect();
        mins.sort();
        mins.dedup();

        let buckets = vec![BucketData {
            bucket_id: 1,
            bucket_name: "test_bucket".to_string(),
            sources: vec!["ref1".to_string()],
            minimizers: mins,
        }];
        let options = ParquetWriteOptions::default();
        create_parquet_inverted_index(
            &index_path,
            buckets,
            16,
            5,
            0x12345,
            None,
            Some(&options),
            None,
        )
        .unwrap();

        let index = ShardedInvertedIndex::open(&index_path).unwrap();
        (dir, index)
    }

    #[test]
    fn test_classify_arrow_batch_sharded_basic() {
        let (_dir, index) = create_test_parquet_index();
        let query_seq = generate_sequence(100, 0);
        let batch = make_test_batch(&[101], &[&query_seq]);

        let result = classify_arrow_batch_sharded(&index, None, &batch, 0.0).unwrap();

        assert!(result.num_rows() > 0, "Should have classification results");
        assert_eq!(result.num_columns(), 3);
        assert_eq!(result.schema().field(0).name(), COL_QUERY_ID);
        assert_eq!(result.schema().field(1).name(), COL_BUCKET_ID);
        assert_eq!(result.schema().field(2).name(), COL_SCORE);
    }

    #[test]
    fn test_classify_arrow_batch_sharded_no_matches() {
        let (_dir, index) = create_test_parquet_index();
        let query_seq = vec![b'N'; 100];
        let batch = make_test_batch(&[101], &[&query_seq]);

        let result = classify_arrow_batch_sharded(&index, None, &batch, 0.5).unwrap();

        // Only the output shape is checked here, not the row count.
        assert_eq!(result.num_columns(), 3);
    }

    #[test]
    fn test_classify_arrow_batch_sharded_multiple_queries() {
        let (_dir, index) = create_test_parquet_index();
        let query1 = generate_sequence(100, 0);
        let query2 = generate_sequence(100, 1);
        let batch = make_test_batch(&[1, 2], &[&query1, &query2]);

        let result = classify_arrow_batch_sharded(&index, None, &batch, 0.0).unwrap();
        assert!(result.num_rows() >= 1);

        let query_ids = result
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        let bucket_ids = result
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let scores = result
            .column(2)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        for i in 0..result.num_rows() {
            assert!(query_ids.value(i) == 1 || query_ids.value(i) == 2);
            // The index contains exactly one bucket, with id 1.
            assert_eq!(bucket_ids.value(i), 1);
            assert!(scores.value(i) >= 0.0 && scores.value(i) <= 1.0);
        }
    }

    #[test]
    fn test_classify_arrow_batch_sharded_best_hit_subset_of_all_hits() {
        let (_dir, index) = create_test_parquet_index();
        let query1 = generate_sequence(100, 0);
        let query2 = generate_sequence(100, 1);
        let batch = make_test_batch(&[1, 2], &[&query1, &query2]);

        let all = classify_arrow_batch_sharded(&index, None, &batch, 0.0).unwrap();
        let best = classify_arrow_batch_sharded_best_hit(&index, None, &batch, 0.0).unwrap();

        // Query 1 is the reference sequence itself, so best-hit filtering must keep
        // something, and it can only ever remove rows.
        assert!(best.num_rows() >= 1, "Best-hit filtering dropped every hit");
        assert!(best.num_rows() <= all.num_rows());
        assert_eq!(best.schema(), all.schema());
    }

    #[test]
    fn test_classify_arrow_batch_sharded_empty() {
        let (_dir, index) = create_test_parquet_index();
        let empty_batch = RecordBatch::new_empty(input_schema());

        let result = classify_arrow_batch_sharded(&index, None, &empty_batch, 0.1).unwrap();

        assert_eq!(result.num_rows(), 0);
    }

    #[test]
    fn test_classify_arrow_batch_sharded_invalid_schema() {
        let (_dir, index) = create_test_parquet_index();
        // The id column has the wrong type (Utf8 instead of Int64).
        let schema = Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Utf8, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]));
        let batch = RecordBatch::new_empty(schema);

        let result = classify_arrow_batch_sharded(&index, None, &batch, 0.1);

        assert!(result.is_err());
    }
}