use std::{collections::HashMap, sync::Arc};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_cast::pretty::pretty_format_batches;
use arrow_flight::flight_descriptor::DescriptorType;
use arrow_flight::FlightDescriptor;
use arrow_flight::{
decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream},
encode::FlightDataEncoderBuilder,
error::FlightError,
};
use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
mod common;
use common::utils::{make_dictionary_batch, make_primitive_batch, make_view_batches};
#[tokio::test]
async fn test_empty() {
roundtrip(vec![]).await;
}
#[tokio::test]
async fn test_empty_batch() {
let batch = make_primitive_batch(5);
let empty = RecordBatch::new_empty(batch.schema());
roundtrip(vec![empty]).await;
}
#[tokio::test]
async fn test_error() {
let input_batch_stream =
futures::stream::iter(vec![Err(FlightError::NotYetImplemented("foo".into()))]);
let encoder = FlightDataEncoderBuilder::default();
let encode_stream = encoder.build(input_batch_stream);
let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream);
let result: Result<Vec<_>, _> = decode_stream.try_collect().await;
let result = result.unwrap_err();
assert_eq!(result.to_string(), "Not yet implemented: foo");
}
#[tokio::test]
async fn test_primitive_one() {
roundtrip(vec![make_primitive_batch(5)]).await;
}
#[tokio::test]
async fn test_schema_metadata() {
let batch = make_primitive_batch(5);
let metadata = HashMap::from([("some_key".to_owned(), "some_value".to_owned())]);
let schema = Arc::new(batch.schema().as_ref().clone().with_metadata(metadata));
let batch = RecordBatch::try_new(schema, batch.columns().to_vec()).unwrap();
roundtrip(vec![batch]).await;
}
#[tokio::test]
async fn test_primitive_many() {
roundtrip(vec![
make_primitive_batch(1),
make_primitive_batch(7),
make_primitive_batch(32),
])
.await;
}
#[tokio::test]
async fn test_primitive_empty() {
let batch = make_primitive_batch(5);
let empty = RecordBatch::new_empty(batch.schema());
roundtrip(vec![batch, empty]).await;
}
#[tokio::test]
async fn test_dictionary_one() {
roundtrip_dictionary(vec![make_dictionary_batch(5)]).await;
}
#[tokio::test]
async fn test_dictionary_many() {
roundtrip_dictionary(vec![
make_dictionary_batch(5),
make_dictionary_batch(9),
make_dictionary_batch(5),
make_dictionary_batch(5),
])
.await;
}
#[tokio::test]
async fn test_view_types_one() {
roundtrip(vec![make_view_batches(5)]).await;
}
#[tokio::test]
async fn test_view_types_many() {
roundtrip(vec![
make_view_batches(5),
make_view_batches(9),
make_view_batches(5),
make_view_batches(5),
])
.await;
}
#[tokio::test]
async fn test_zero_batches_no_schema() {
let stream = FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![]));
let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream);
assert!(decoder.schema().is_none());
assert!(decoder.next().await.is_none());
assert!(decoder.schema().is_none());
}
#[tokio::test]
async fn test_zero_batches_schema_specified() {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
let stream = FlightDataEncoderBuilder::default()
.with_schema(schema.clone())
.build(futures::stream::iter(vec![]));
let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream);
assert!(decoder.schema().is_none());
assert!(decoder.next().await.is_none());
assert_eq!(decoder.schema(), Some(&schema));
}
#[tokio::test]
async fn test_with_flight_descriptor() {
let stream = futures::stream::iter(vec![Ok(make_dictionary_batch(5))]);
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
let descriptor = Some(FlightDescriptor {
r#type: DescriptorType::Path.into(),
path: vec!["table_name".to_string()],
cmd: Bytes::default(),
});
let encoder = FlightDataEncoderBuilder::default()
.with_schema(schema.clone())
.with_flight_descriptor(descriptor.clone());
let mut encoder = encoder.build(stream);
let first_batch = encoder.next().await.unwrap().unwrap();
assert_eq!(first_batch.flight_descriptor, descriptor);
}
#[tokio::test]
async fn test_zero_batches_dictionary_schema_specified() {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new_dictionary("b", DataType::Int32, DataType::Utf8, false),
]));
let expected_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Utf8, false),
]));
let stream = FlightDataEncoderBuilder::default()
.with_schema(schema.clone())
.build(futures::stream::iter(vec![]));
let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream);
assert!(decoder.schema().is_none());
assert!(decoder.next().await.is_none());
assert_eq!(decoder.schema(), Some(&expected_schema));
}
#[tokio::test]
async fn test_app_metadata() {
let input_batch_stream = futures::stream::iter(vec![Ok(make_primitive_batch(78))]);
let app_metadata = Bytes::from("My Metadata");
let encoder = FlightDataEncoderBuilder::default().with_metadata(app_metadata.clone());
let encode_stream = encoder.build(input_batch_stream);
let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner();
let mut messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails");
println!("{messages:#?}");
assert_eq!(messages.len(), 2);
let message2 = messages.pop().unwrap();
let message1 = messages.pop().unwrap();
assert_eq!(message1.app_metadata(), app_metadata);
assert!(matches!(message1.payload, DecodedPayload::Schema(_)));
assert_eq!(message2.app_metadata(), Bytes::new());
assert!(matches!(message2.payload, DecodedPayload::RecordBatch(_)));
}
#[tokio::test]
async fn test_max_message_size() {
let input_batch_stream = futures::stream::iter(vec![Ok(make_primitive_batch(5))]);
let encoder = FlightDataEncoderBuilder::default().with_max_flight_data_size(1);
let encode_stream = encoder.build(input_batch_stream);
let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner();
let messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails");
println!("{messages:#?}");
assert_eq!(messages.len(), 6);
assert!(matches!(messages[0].payload, DecodedPayload::Schema(_)));
for message in messages.iter().skip(1) {
assert!(matches!(message.payload, DecodedPayload::RecordBatch(_)));
}
}
#[tokio::test]
async fn test_max_message_size_fuzz() {
let input = vec![
make_primitive_batch(123),
make_primitive_batch(17),
make_primitive_batch(201),
make_primitive_batch(2),
make_primitive_batch(1),
make_primitive_batch(11),
make_primitive_batch(127),
];
for max_message_size_bytes in [10, 1024, 2048, 6400, 3211212] {
let encoder =
FlightDataEncoderBuilder::default().with_max_flight_data_size(max_message_size_bytes);
let input_batch_stream = futures::stream::iter(input.clone()).map(Ok);
let encode_stream = encoder.build(input_batch_stream);
let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream);
let output: Vec<_> = decode_stream.try_collect().await.expect("encode / decode");
for b in &output {
assert_eq!(b.schema(), input[0].schema());
}
let a = pretty_format_batches(&input).unwrap().to_string();
let b = pretty_format_batches(&output).unwrap().to_string();
assert_eq!(a, b);
}
}
#[tokio::test]
async fn test_mismatched_record_batch_schema() {
let input_batch_stream = futures::stream::iter(vec![
Ok(make_primitive_batch(5)),
Ok(make_dictionary_batch(3)),
]);
let encoder = FlightDataEncoderBuilder::default();
let encode_stream = encoder.build(input_batch_stream);
let result: Result<Vec<_>, FlightError> = encode_stream.try_collect().await;
let err = result.unwrap_err();
assert_eq!(
err.to_string(),
"Arrow error: Invalid argument error: number of columns(1) must match number of fields(2) in schema"
);
}
#[tokio::test]
async fn test_chained_streams_batch_decoder() {
let batch1 = make_primitive_batch(5);
let batch2 = make_dictionary_batch(3);
let encode_stream1 =
FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())]));
let encode_stream2 =
FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())]));
let encode_stream = encode_stream1.chain(encode_stream2);
let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream);
let result: Result<Vec<_>, FlightError> = decode_stream.try_collect().await;
let err = result.unwrap_err();
assert_eq!(
err.to_string(),
"Protocol error: Unexpectedly saw multiple Schema messages in FlightData stream"
);
}
#[tokio::test]
async fn test_chained_streams_data_decoder() {
let batch1 = make_primitive_batch(5);
let batch2 = make_dictionary_batch(3);
let encode_stream1 =
FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())]));
let encode_stream2 =
FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())]));
let encode_stream = encode_stream1.chain(encode_stream2);
let decode_stream = FlightDataDecoder::new(encode_stream);
let decoded_data: Vec<_> = decode_stream.try_collect().await.expect("encode / decode");
println!("decoded data: {decoded_data:#?}");
assert_eq!(decoded_data.len(), 4);
assert!(matches!(decoded_data[0].payload, DecodedPayload::Schema(_)));
assert!(matches!(
decoded_data[1].payload,
DecodedPayload::RecordBatch(_)
));
assert!(matches!(decoded_data[2].payload, DecodedPayload::Schema(_)));
assert!(matches!(
decoded_data[3].payload,
DecodedPayload::RecordBatch(_)
));
}
#[tokio::test]
async fn test_mismatched_schema_message() {
async fn do_test(batch1: RecordBatch, batch2: RecordBatch, expected: &str) {
let encode_stream1 = FlightDataEncoderBuilder::default()
.build(futures::stream::iter(vec![Ok(batch1.clone())]))
.take(1);
let encode_stream2 = FlightDataEncoderBuilder::default()
.build(futures::stream::iter(vec![Ok(batch2.clone())]))
.skip(1);
let encode_stream = encode_stream1.chain(encode_stream2);
let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream);
let result: Result<Vec<_>, FlightError> = decode_stream.try_collect().await;
let err = result.unwrap_err().to_string();
assert!(
err.contains(expected),
"could not find '{expected}' in '{err}'"
);
}
do_test(
make_primitive_batch(5),
make_dictionary_batch(3),
"Error decoding ipc RecordBatch: Schema error: Invalid data for schema",
)
.await;
do_test(
make_dictionary_batch(3),
make_primitive_batch(5),
"Error decoding ipc RecordBatch: Invalid argument error",
)
.await;
}
async fn roundtrip(input: Vec<RecordBatch>) {
let expected_output = input.clone();
roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await
}
async fn roundtrip_dictionary(input: Vec<RecordBatch>) {
let schema = Arc::new(prepare_schema_for_flight(input[0].schema_ref()));
let expected_output: Vec<_> = input
.iter()
.map(|batch| prepare_batch_for_flight(batch, schema.clone()).unwrap())
.collect();
roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await
}
async fn roundtrip_with_encoder(
encoder: FlightDataEncoderBuilder,
input_batches: Vec<RecordBatch>,
expected_batches: Vec<RecordBatch>,
) {
println!("Round tripping with encoder:\n{encoder:#?}");
let input_batch_stream = futures::stream::iter(input_batches.clone()).map(Ok);
let encode_stream = encoder.build(input_batch_stream);
let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream);
let output_batches: Vec<_> = decode_stream.try_collect().await.expect("encode / decode");
let expected_batches: Vec<_> = expected_batches
.into_iter()
.filter(|b| b.num_rows() > 0)
.collect();
assert_eq!(expected_batches, output_batches);
}
fn prepare_schema_for_flight(schema: &Schema) -> Schema {
let fields: Fields = schema
.fields()
.iter()
.map(|field| match field.data_type() {
DataType::Dictionary(_, value_type) => Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
_ => field.as_ref().clone(),
})
.collect();
Schema::new(fields)
}
fn prepare_batch_for_flight(
batch: &RecordBatch,
schema: SchemaRef,
) -> Result<RecordBatch, FlightError> {
let columns = batch
.columns()
.iter()
.map(hydrate_dictionary)
.collect::<Result<Vec<_>, _>>()?;
Ok(RecordBatch::try_new(schema, columns)?)
}
fn hydrate_dictionary(array: &ArrayRef) -> Result<ArrayRef, FlightError> {
let arr = if let DataType::Dictionary(_, value) = array.data_type() {
arrow_cast::cast(array, value)?
} else {
Arc::clone(array)
};
Ok(arr)
}