use arrow::array::{Float64Array, Int64Array, UInt32Array};
use arrow::record_batch::RecordBatch;
use std::sync::Arc;
use super::error::ArrowClassifyError;
use super::schema::result_schema;
use crate::HitResult;
/// Converts classification hits into an Arrow [`RecordBatch`].
///
/// The batch uses the schema returned by `result_schema()` and carries one
/// row per hit, laid out as three columns in this order:
///
/// 1. query id, stored in an [`Int64Array`]
/// 2. bucket id, stored in a [`UInt32Array`]
/// 3. score, stored in a [`Float64Array`]
///
/// An empty `hits` vector yields an empty batch that still has the result
/// schema attached.
///
/// # Errors
///
/// Returns an [`ArrowClassifyError`] converted from whatever error
/// `RecordBatch::try_new` reports when the columns are rejected.
pub fn hits_to_record_batch(hits: Vec<HitResult>) -> Result<RecordBatch, ArrowClassifyError> {
    let schema = result_schema();

    // Nothing to transpose: hand back a zero-row batch with the schema.
    if hits.is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    // Transpose the row-oriented hits into three column vectors in one pass.
    let (query_column, (bucket_column, score_column)): (Vec<_>, (Vec<_>, Vec<_>)) = hits
        .into_iter()
        .map(|hit| (hit.query_id, (hit.bucket_id, hit.score)))
        .unzip();

    let assembled = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int64Array::from(query_column)),
            Arc::new(UInt32Array::from(bucket_column)),
            Arc::new(Float64Array::from(score_column)),
        ],
    );

    assembled.map_err(ArrowClassifyError::from)
}
/// Builds a zero-row [`RecordBatch`] that carries the schema returned by
/// `result_schema()`.
pub fn empty_result_batch() -> RecordBatch {
    let schema = result_schema();
    RecordBatch::new_empty(schema)
}
#[cfg(test)]
mod tests {
    use super::super::schema::{COL_BUCKET_ID, COL_QUERY_ID, COL_SCORE};
    use super::*;
    use arrow::array::Array;
    use arrow::datatypes::DataType;

    /// Shorthand for building a single hit.
    fn hit(query_id: i64, bucket_id: u32, score: f64) -> HitResult {
        HitResult {
            query_id,
            bucket_id,
            score,
        }
    }

    /// Returns column `index` of `batch` downcast to the concrete array type
    /// `T`, panicking when the column has a different type.
    fn typed_column<T: 'static>(batch: &RecordBatch, index: usize) -> &T {
        batch.column(index).as_any().downcast_ref::<T>().unwrap()
    }

    /// Asserts that the three fields of the batch schema carry the expected
    /// column names, in order.
    fn assert_result_column_names(batch: &RecordBatch) {
        let schema = batch.schema();
        assert_eq!(schema.field(0).name(), COL_QUERY_ID);
        assert_eq!(schema.field(1).name(), COL_BUCKET_ID);
        assert_eq!(schema.field(2).name(), COL_SCORE);
    }

    #[test]
    fn test_hits_to_batch_basic() {
        let batch = hits_to_record_batch(vec![hit(1, 10, 0.95), hit(2, 20, 0.85)]).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 3);

        let query_col = typed_column::<Int64Array>(&batch, 0);
        assert_eq!(query_col.value(0), 1);
        assert_eq!(query_col.value(1), 2);

        let bucket_col = typed_column::<UInt32Array>(&batch, 1);
        assert_eq!(bucket_col.value(0), 10);
        assert_eq!(bucket_col.value(1), 20);

        // Scores are floats, so compare within a tolerance.
        let score_col = typed_column::<Float64Array>(&batch, 2);
        assert!((score_col.value(0) - 0.95).abs() < 1e-10);
        assert!((score_col.value(1) - 0.85).abs() < 1e-10);
    }

    #[test]
    fn test_hits_to_batch_empty() {
        let batch = hits_to_record_batch(Vec::new()).unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 3);
        assert_result_column_names(&batch);
    }

    #[test]
    fn test_hits_to_batch_large() {
        const ROWS: usize = 10000;

        let mut hits: Vec<HitResult> = Vec::with_capacity(ROWS);
        for i in 0..ROWS {
            hits.push(hit(i as i64, (i % 100) as u32, (i as f64) / 10000.0));
        }

        let batch = hits_to_record_batch(hits).unwrap();
        assert_eq!(batch.num_rows(), ROWS);

        let query_col = typed_column::<Int64Array>(&batch, 0);
        assert_eq!(query_col.value(0), 0);
        assert_eq!(query_col.value(ROWS - 1), 9999);
    }

    #[test]
    fn test_hits_to_batch_schema_correct() {
        let batch = hits_to_record_batch(vec![hit(1, 1, 0.5)]).unwrap();
        assert_result_column_names(&batch);

        let schema = batch.schema();
        let expected_types = [DataType::Int64, DataType::UInt32, DataType::Float64];
        for (index, expected) in expected_types.iter().enumerate() {
            assert_eq!(schema.field(index).data_type(), expected);
        }
        for index in 0..3 {
            assert!(!schema.field(index).is_nullable());
        }
    }

    #[test]
    fn test_empty_result_batch() {
        let batch = empty_result_batch();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 3);
        assert_eq!(batch.schema().field(0).name(), COL_QUERY_ID);
    }

    #[test]
    fn test_hits_to_batch_extreme_values() {
        let batch =
            hits_to_record_batch(vec![hit(i64::MIN, 0, 0.0), hit(i64::MAX, u32::MAX, 1.0)])
                .unwrap();

        let query_col = typed_column::<Int64Array>(&batch, 0);
        assert_eq!(query_col.value(0), i64::MIN);
        assert_eq!(query_col.value(1), i64::MAX);

        let bucket_col = typed_column::<UInt32Array>(&batch, 1);
        assert_eq!(bucket_col.value(0), 0);
        assert_eq!(bucket_col.value(1), u32::MAX);
    }
}