use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use std::sync::Arc;
use super::error::ArrowClassifyError;
/// Required input column; [`validate_input_schema`] requires it to be `Int64`.
pub const COL_ID: &str = "id";
/// Required input column; [`validate_input_schema`] requires `Binary` or `LargeBinary`.
pub const COL_SEQUENCE: &str = "sequence";
/// Optional input column; when present, [`validate_input_schema`] requires
/// `Binary` or `LargeBinary`.
pub const COL_PAIR_SEQUENCE: &str = "pair_sequence";
/// Output column (`Int64`, non-nullable) shared by [`result_schema`] and
/// [`log_ratio_result_schema`].
pub const COL_QUERY_ID: &str = "query_id";
/// Output column (`UInt32`, non-nullable) in [`result_schema`].
pub const COL_BUCKET_ID: &str = "bucket_id";
/// Output column (`Float64`, non-nullable) in [`result_schema`].
pub const COL_SCORE: &str = "score";
/// Output column (`Float64`, non-nullable) in [`log_ratio_result_schema`].
pub const COL_LOG_RATIO: &str = "log_ratio";
/// Output column (`Int32`, non-nullable) in [`log_ratio_result_schema`].
pub const COL_FAST_PATH: &str = "fast_path";
/// Builds the Arrow schema for score results.
///
/// Columns, in order, all non-nullable:
/// - [`COL_QUERY_ID`]: `Int64`
/// - [`COL_BUCKET_ID`]: `UInt32`
/// - [`COL_SCORE`]: `Float64`
pub fn result_schema() -> SchemaRef {
    let fields = vec![
        Field::new(COL_QUERY_ID, DataType::Int64, false),
        Field::new(COL_BUCKET_ID, DataType::UInt32, false),
        Field::new(COL_SCORE, DataType::Float64, false),
    ];
    let schema = Schema::new(fields);
    Arc::new(schema)
}
/// Builds the Arrow schema for log-ratio results.
///
/// Columns, in order, all non-nullable:
/// - [`COL_QUERY_ID`]: `Int64`
/// - [`COL_LOG_RATIO`]: `Float64`
/// - [`COL_FAST_PATH`]: `Int32`
pub fn log_ratio_result_schema() -> SchemaRef {
    let fields = vec![
        Field::new(COL_QUERY_ID, DataType::Int64, false),
        Field::new(COL_LOG_RATIO, DataType::Float64, false),
        Field::new(COL_FAST_PATH, DataType::Int32, false),
    ];
    let schema = Schema::new(fields);
    Arc::new(schema)
}
/// Returns `true` if `dt` is one of the accepted byte-string column types:
/// `Binary` or `LargeBinary`.
fn is_binary_type(dt: &DataType) -> bool {
    *dt == DataType::Binary || *dt == DataType::LargeBinary
}
pub fn validate_input_schema(schema: &Schema) -> Result<(), ArrowClassifyError> {
match schema.column_with_name(COL_ID) {
Some((_, field)) => {
if field.data_type() != &DataType::Int64 {
return Err(ArrowClassifyError::TypeError {
column: COL_ID.into(),
expected: "Int64".into(),
actual: format!("{:?}", field.data_type()),
});
}
}
None => {
return Err(ArrowClassifyError::ColumnNotFound(COL_ID.into()));
}
}
match schema.column_with_name(COL_SEQUENCE) {
Some((_, field)) => {
if !is_binary_type(field.data_type()) {
return Err(ArrowClassifyError::TypeError {
column: COL_SEQUENCE.into(),
expected: "Binary or LargeBinary".into(),
actual: format!("{:?}", field.data_type()),
});
}
}
None => {
return Err(ArrowClassifyError::ColumnNotFound(COL_SEQUENCE.into()));
}
}
if let Some((_, field)) = schema.column_with_name(COL_PAIR_SEQUENCE) {
if !is_binary_type(field.data_type()) {
return Err(ArrowClassifyError::TypeError {
column: COL_PAIR_SEQUENCE.into(),
expected: "Binary or LargeBinary".into(),
actual: format!("{:?}", field.data_type()),
});
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};

    /// Smallest valid input: only the two required columns.
    fn make_valid_schema() -> Schema {
        Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ])
    }

    /// Valid input that also carries the optional pair column.
    fn make_valid_schema_with_pair() -> Schema {
        Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
            Field::new(COL_PAIR_SEQUENCE, DataType::Binary, true),
        ])
    }

    #[test]
    fn test_input_schema_valid() {
        let schema = make_valid_schema();
        assert!(validate_input_schema(&schema).is_ok());
    }

    #[test]
    fn test_input_schema_valid_with_large_binary() {
        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::LargeBinary, false),
        ]);
        assert!(validate_input_schema(&schema).is_ok());
    }

    #[test]
    fn test_input_schema_missing_id_column() {
        let schema = Schema::new(vec![Field::new(COL_SEQUENCE, DataType::Binary, false)]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::ColumnNotFound(col)) if col == COL_ID
        ));
    }

    #[test]
    fn test_input_schema_wrong_id_type() {
        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Utf8, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::TypeError { column, .. }) if column == COL_ID
        ));
    }

    #[test]
    fn test_input_schema_type_error_payload() {
        // The error must name the column, the accepted type(s), and the type found.
        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Utf8, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
        ]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::TypeError { column, expected, actual })
                if column == COL_ID && expected == "Int64" && actual == "Utf8"
        ));

        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Utf8, false),
        ]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::TypeError { column, expected, actual })
                if column == COL_SEQUENCE
                    && expected == "Binary or LargeBinary"
                    && actual == "Utf8"
        ));
    }

    #[test]
    fn test_input_schema_missing_sequence_column() {
        let schema = Schema::new(vec![Field::new(COL_ID, DataType::Int64, false)]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::ColumnNotFound(col)) if col == COL_SEQUENCE
        ));
    }

    #[test]
    fn test_input_schema_empty_reports_id_first() {
        // With every required column missing, the id column is reported first.
        let schema = Schema::empty();
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::ColumnNotFound(col)) if col == COL_ID
        ));
    }

    #[test]
    fn test_input_schema_id_type_checked_before_sequence_presence() {
        // A bad id type wins over a missing sequence column.
        let schema = Schema::new(vec![Field::new(COL_ID, DataType::Utf8, false)]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::TypeError { column, .. }) if column == COL_ID
        ));
    }

    #[test]
    fn test_input_schema_wrong_sequence_type() {
        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Utf8, false),
        ]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::TypeError { column, .. }) if column == COL_SEQUENCE
        ));
    }

    #[test]
    fn test_input_schema_with_optional_pair() {
        let schema = make_valid_schema_with_pair();
        assert!(validate_input_schema(&schema).is_ok());
    }

    #[test]
    fn test_input_schema_with_optional_pair_large_binary() {
        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
            Field::new(COL_PAIR_SEQUENCE, DataType::LargeBinary, true),
        ]);
        assert!(validate_input_schema(&schema).is_ok());
    }

    #[test]
    fn test_input_schema_wrong_pair_type() {
        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
            Field::new(COL_PAIR_SEQUENCE, DataType::Utf8, true),
        ]);
        assert!(matches!(
            validate_input_schema(&schema),
            Err(ArrowClassifyError::TypeError { column, .. }) if column == COL_PAIR_SEQUENCE
        ));
    }

    #[test]
    fn test_input_schema_extra_columns_allowed() {
        let schema = Schema::new(vec![
            Field::new(COL_ID, DataType::Int64, false),
            Field::new(COL_SEQUENCE, DataType::Binary, false),
            Field::new("extra_column", DataType::Utf8, true),
            Field::new("another_extra", DataType::Float64, true),
        ]);
        assert!(validate_input_schema(&schema).is_ok());
    }

    #[test]
    fn test_log_ratio_output_schema_structure() {
        let schema = log_ratio_result_schema();
        assert_eq!(schema.fields().len(), 3);

        // Column order is part of the output contract.
        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        assert_eq!(names, [COL_QUERY_ID, COL_LOG_RATIO, COL_FAST_PATH]);

        let query_id_field = schema.field_with_name(COL_QUERY_ID).unwrap();
        assert_eq!(query_id_field.data_type(), &DataType::Int64);
        assert!(!query_id_field.is_nullable());

        let log_ratio_field = schema.field_with_name(COL_LOG_RATIO).unwrap();
        assert_eq!(log_ratio_field.data_type(), &DataType::Float64);
        assert!(!log_ratio_field.is_nullable());

        let fast_path_field = schema.field_with_name(COL_FAST_PATH).unwrap();
        assert_eq!(fast_path_field.data_type(), &DataType::Int32);
        assert!(!fast_path_field.is_nullable());
    }

    #[test]
    fn test_output_schema_structure() {
        let schema = result_schema();
        assert_eq!(schema.fields().len(), 3);

        // Column order is part of the output contract.
        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        assert_eq!(names, [COL_QUERY_ID, COL_BUCKET_ID, COL_SCORE]);

        let query_id_field = schema.field_with_name(COL_QUERY_ID).unwrap();
        assert_eq!(query_id_field.data_type(), &DataType::Int64);
        assert!(!query_id_field.is_nullable());

        let bucket_id_field = schema.field_with_name(COL_BUCKET_ID).unwrap();
        assert_eq!(bucket_id_field.data_type(), &DataType::UInt32);
        assert!(!bucket_id_field.is_nullable());

        let score_field = schema.field_with_name(COL_SCORE).unwrap();
        assert_eq!(score_field.data_type(), &DataType::Float64);
        assert!(!score_field.is_nullable());
    }
}