use std::collections::HashSet;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use super::error::ArrowClassifyError;
use super::input::batch_to_records;
use super::output::hits_to_record_batch;
use super::schema::result_schema;
use crate::{classify_batch_sharded_merge_join, ShardedInvertedIndex};
/// Classifier that scores Arrow record batches against a borrowed
/// [`ShardedInvertedIndex`], one batch at a time.
///
/// The classifier does not own the index or the optional negative set; both
/// are borrowed for `'a` and must outlive the classifier.
pub struct ShardedStreamClassifier<'a> {
    /// Borrowed sharded inverted index that every batch is classified against.
    index: &'a ShardedInvertedIndex,
    /// Optional borrowed set of `u64` values, forwarded unchanged to the
    /// classification routine.
    /// NOTE(review): presumably minimizers to exclude before scoring — confirm
    /// against `classify_batch_sharded_merge_join`.
    negative_mins: Option<&'a HashSet<u64>>,
    /// Score threshold forwarded unchanged to the classification routine.
    threshold: f64,
    /// Result schema obtained from `result_schema()` at construction time and
    /// handed out by `output_schema()`.
    output_schema: SchemaRef,
}
impl<'a> ShardedStreamClassifier<'a> {
pub fn new(
index: &'a ShardedInvertedIndex,
negative_mins: Option<&'a HashSet<u64>>,
threshold: f64,
) -> Self {
Self {
index,
negative_mins,
threshold,
output_schema: result_schema(),
}
}
pub fn output_schema(&self) -> SchemaRef {
self.output_schema.clone()
}
pub fn classify_batch(&self, batch: &RecordBatch) -> Result<RecordBatch, ArrowClassifyError> {
let records = batch_to_records(batch)?;
let hits = classify_batch_sharded_merge_join(
self.index,
self.negative_mins,
&records,
self.threshold,
None,
)
.map_err(|e| ArrowClassifyError::Classification(e.to_string()))?;
hits_to_record_batch(hits)
}
pub fn classify_iter<I>(
&self,
input: I,
) -> impl Iterator<Item = Result<RecordBatch, ArrowClassifyError>> + '_
where
I: Iterator<Item = Result<RecordBatch, arrow::error::ArrowError>> + 'a,
{
input.map(move |batch_result| {
let batch = batch_result.map_err(ArrowClassifyError::from)?;
self.classify_batch(&batch)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        create_parquet_inverted_index, extract_into, BucketData, MinimizerWorkspace,
        ParquetWriteOptions,
    };
    use arrow::array::{BinaryArray, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::tempdir;

    use super::super::schema::{COL_ID, COL_SEQUENCE};

    /// Item type consumed by `classify_iter`.
    type InputItem = Result<RecordBatch, arrow::error::ArrowError>;

    /// Two-column (id, sequence) schema shared by every query batch in these tests.
    fn query_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]))
    }

    /// Deterministic sequence of `len` bases cycling through A, C, G, T,
    /// starting `seed` positions into the cycle.
    fn generate_sequence(len: usize, seed: u8) -> Vec<u8> {
        const BASES: [u8; 4] = *b"ACGT";
        let shift = usize::from(seed);
        (0..len)
            .map(|pos| BASES[(pos + shift) % BASES.len()])
            .collect()
    }

    /// Builds a query batch from parallel slices of ids and sequences.
    fn make_test_batch(ids: &[i64], seqs: &[&[u8]]) -> RecordBatch {
        let id_column = Int64Array::from(ids.to_vec());
        let seq_column = BinaryArray::from_iter_values(seqs.iter().copied());
        RecordBatch::try_new(
            query_schema(),
            vec![Arc::new(id_column), Arc::new(seq_column)],
        )
        .unwrap()
    }

    /// Writes a single-bucket index into a temp dir and opens it.
    ///
    /// The `TempDir` is returned alongside the index so the files stay on disk
    /// for as long as the caller keeps it alive.
    fn create_test_parquet_index() -> (tempfile::TempDir, ShardedInvertedIndex) {
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("test.ryxdi");

        let reference = generate_sequence(100, 0);
        let mut workspace = MinimizerWorkspace::new();
        extract_into(&reference, 16, 5, 0x12345, &mut workspace);

        let mut minimizers: Vec<u64> = workspace.buffer.drain(..).collect();
        minimizers.sort_unstable();
        minimizers.dedup();

        let bucket = BucketData {
            bucket_id: 1,
            bucket_name: "test_bucket".to_string(),
            sources: vec!["ref1".to_string()],
            minimizers,
        };
        let write_options = ParquetWriteOptions::default();
        create_parquet_inverted_index(
            &path,
            vec![bucket],
            16,
            5,
            0x12345,
            None,
            Some(&write_options),
            None,
        )
        .unwrap();

        let opened = ShardedInvertedIndex::open(&path).unwrap();
        (tmp, opened)
    }

    #[test]
    fn test_stream_classifier_single_batch() {
        let (_dir, index) = create_test_parquet_index();
        let classifier = ShardedStreamClassifier::new(&index, None, 0.0);

        let query = generate_sequence(100, 0);
        let input = make_test_batch(&[1], &[query.as_slice()]);

        let hits = classifier.classify_batch(&input).unwrap();
        assert!(hits.num_rows() > 0, "Should have classification results");
    }

    #[test]
    fn test_stream_classifier_multiple_batches() {
        let (_dir, index) = create_test_parquet_index();
        let classifier = ShardedStreamClassifier::new(&index, None, 0.0);

        let first = generate_sequence(100, 0);
        let second = generate_sequence(100, 1);
        let inputs: Vec<InputItem> = vec![
            Ok(make_test_batch(&[1], &[first.as_slice()])),
            Ok(make_test_batch(&[2], &[second.as_slice()])),
        ];

        let outputs: Vec<_> = classifier.classify_iter(inputs.into_iter()).collect();
        assert_eq!(outputs.len(), 2);
        assert!(outputs.iter().all(|outcome| outcome.is_ok()));
    }

    #[test]
    fn test_stream_classifier_empty_input() {
        let (_dir, index) = create_test_parquet_index();
        let classifier = ShardedStreamClassifier::new(&index, None, 0.1);

        let inputs = Vec::<InputItem>::new();
        let produced = classifier.classify_iter(inputs.into_iter()).count();
        assert!(produced == 0);
    }

    #[test]
    fn test_stream_classifier_empty_batch() {
        let (_dir, index) = create_test_parquet_index();
        let classifier = ShardedStreamClassifier::new(&index, None, 0.1);

        let no_rows = RecordBatch::new_empty(query_schema());
        let hits = classifier.classify_batch(&no_rows).unwrap();
        assert_eq!(hits.num_rows(), 0);
    }

    #[test]
    fn test_stream_classifier_output_schema() {
        let (_dir, index) = create_test_parquet_index();
        let classifier = ShardedStreamClassifier::new(&index, None, 0.1);

        let schema = classifier.output_schema();
        let expected_names = ["query_id", "bucket_id", "score"];
        assert_eq!(schema.fields().len(), expected_names.len());
        for (position, expected) in expected_names.iter().copied().enumerate() {
            assert_eq!(schema.field(position).name(), expected);
        }
    }

    #[test]
    fn test_stream_classifier_threshold_filtering() {
        let (_dir, index) = create_test_parquet_index();
        let query = generate_sequence(100, 0);
        let input = make_test_batch(&[1], &[query.as_slice()]);

        // A threshold above 1.0 is expected to reject every hit.
        let strict = ShardedStreamClassifier::new(&index, None, 1.1);
        let strict_hits = strict.classify_batch(&input).unwrap();
        assert_eq!(
            strict_hits.num_rows(),
            0,
            "High threshold should filter all"
        );

        // A zero threshold is expected to let at least one hit through.
        let lenient = ShardedStreamClassifier::new(&index, None, 0.0);
        let lenient_hits = lenient.classify_batch(&input).unwrap();
        assert!(lenient_hits.num_rows() > 0, "Zero threshold should pass some");
    }

    #[test]
    fn test_stream_classifier_error_propagation() {
        let (_dir, index) = create_test_parquet_index();
        let classifier = ShardedStreamClassifier::new(&index, None, 0.1);

        let failure = arrow::error::ArrowError::InvalidArgumentError("test error".into());
        let inputs: Vec<InputItem> = vec![Err(failure)];

        let outputs: Vec<_> = classifier.classify_iter(inputs.into_iter()).collect();
        assert_eq!(outputs.len(), 1);
        assert!(outputs[0].is_err());
    }
}