use std::sync::Arc;
use arrow::{compute::concat_batches, datatypes::Int32Type};
use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array};
use arrow_flight::{
decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream},
encode::FlightDataEncoderBuilder,
error::FlightError,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
#[tokio::test]
async fn test_empty() {
roundtrip(vec![]).await;
}
/// Round trips a single zero-row batch that still carries the primitive schema.
///
/// `roundtrip_with_encoder` removes zero-row batches from the expected output,
/// so this checks that the decoder yields no batches for such an input.
#[tokio::test]
async fn test_empty_batch() {
    let schema = make_primative_batch(5).schema();
    let zero_rows = RecordBatch::new_empty(schema);
    roundtrip(vec![zero_rows]).await;
}
/// An error yielded by the input stream must come out of the decoded stream
/// unchanged after passing through the encoder and the decoder.
#[tokio::test]
async fn test_error() {
    // The input stream holds a single item, and that item is an error.
    let failing_input =
        futures::stream::iter(vec![Err(FlightError::NotYetImplemented("foo".into()))]);
    let flight_data = FlightDataEncoderBuilder::default().build(failing_input);

    let outcome: Result<Vec<_>, _> =
        FlightRecordBatchStream::new_from_flight_data(flight_data)
            .try_collect()
            .await;

    let err = outcome.unwrap_err();
    assert_eq!(err.to_string(), r#"NotYetImplemented("foo")"#);
}
/// Round trips a single five-row batch of primitive columns.
#[tokio::test]
async fn test_primative_one() {
    let batch = make_primative_batch(5);
    roundtrip(vec![batch]).await;
}
#[tokio::test]
async fn test_primative_many() {
roundtrip(vec![
make_primative_batch(1),
make_primative_batch(7),
make_primative_batch(32),
])
.await;
}
/// Round trips a populated primitive batch followed by a zero-row batch that
/// shares its schema.
#[tokio::test]
async fn test_primative_empty() {
    let populated = make_primative_batch(5);
    let zero_rows = RecordBatch::new_empty(populated.schema());
    let input = vec![populated, zero_rows];
    roundtrip(input).await;
}
/// Round trips a single five-row dictionary batch; the expected output is the
/// dictionary-hydrated form computed by `roundtrip_dictionary`.
#[tokio::test]
async fn test_dictionary_one() {
    let batch = make_dictionary_batch(5);
    roundtrip_dictionary(vec![batch]).await;
}
#[tokio::test]
async fn test_dictionary_many() {
roundtrip_dictionary(vec![
make_dictionary_batch(5),
make_dictionary_batch(9),
make_dictionary_batch(5),
make_dictionary_batch(5),
])
.await;
}
/// Metadata configured on the encoder builder must be attached to the schema
/// message only; the record batch message that follows carries empty metadata.
#[tokio::test]
async fn test_app_metadata() {
    let app_metadata = Bytes::from("My Metadata");
    let batches = futures::stream::iter(vec![Ok(make_primative_batch(78))]);

    let flight_data = FlightDataEncoderBuilder::default()
        .with_metadata(app_metadata.clone())
        .build(batches);

    // `into_inner` exposes the underlying message-level stream so that each
    // decoded message (and its metadata) can be inspected individually.
    let message_stream =
        FlightRecordBatchStream::new_from_flight_data(flight_data).into_inner();
    let messages: Vec<_> = message_stream.try_collect().await.expect("encode fails");
    println!("{messages:#?}");

    // Exactly two messages are expected: the schema, then the single batch.
    assert_eq!(messages.len(), 2);
    let schema_message = &messages[0];
    let batch_message = &messages[1];

    assert_eq!(schema_message.app_metadata(), app_metadata);
    assert!(matches!(schema_message.payload, DecodedPayload::Schema(_)));

    assert_eq!(batch_message.app_metadata(), Bytes::new());
    assert!(matches!(
        batch_message.payload,
        DecodedPayload::RecordBatch(_)
    ));
}
/// With a maximum flight data size of 1, a five-row batch is expected to be
/// encoded as six messages: one schema message followed by five record batch
/// messages (presumably one per row).
#[tokio::test]
async fn test_max_message_size() {
    let batches = futures::stream::iter(vec![Ok(make_primative_batch(5))]);
    let flight_data = FlightDataEncoderBuilder::default()
        .with_max_flight_data_size(1)
        .build(batches);

    // Inspect the individual decoded messages rather than the record batches.
    let message_stream =
        FlightRecordBatchStream::new_from_flight_data(flight_data).into_inner();
    let messages: Vec<_> = message_stream.try_collect().await.expect("encode fails");
    println!("{messages:#?}");

    assert_eq!(messages.len(), 6);
    assert!(matches!(messages[0].payload, DecodedPayload::Schema(_)));
    // Everything after the schema message must be a record batch.
    for batch_message in &messages[1..] {
        assert!(matches!(
            batch_message.payload,
            DecodedPayload::RecordBatch(_)
        ));
    }
}
/// Round trips the same set of batches under several maximum flight data
/// sizes and checks that the concatenated output always equals the
/// concatenated input, however the encoder chose to split the data.
#[tokio::test]
async fn test_max_message_size_fuzz() {
    let input = vec![
        make_primative_batch(123),
        make_primative_batch(17),
        make_primative_batch(201),
        make_primative_batch(2),
        make_primative_batch(1),
        make_primative_batch(11),
        make_primative_batch(127),
    ];

    // The expected result does not depend on the size limit, so build it once
    // instead of on every loop iteration.
    let input_batch = concat_batches(&input[0].schema(), &input).unwrap();

    for max_message_size_bytes in [10, 1024, 2048, 6400, 3211212] {
        let encoder = FlightDataEncoderBuilder::default()
            .with_max_flight_data_size(max_message_size_bytes);

        let input_batch_stream = futures::stream::iter(input.clone()).map(Ok);
        let encode_stream = encoder.build(input_batch_stream);
        let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream);
        let output: Vec<_> = decode_stream.try_collect().await.expect("encode / decode");

        // Fail with a descriptive message rather than an index-out-of-bounds
        // panic on `output[0]` below.
        assert!(
            !output.is_empty(),
            "decoder produced no batches with max_flight_data_size = {max_message_size_bytes}"
        );

        // The batch boundaries may differ from the input, so compare the
        // concatenated contents rather than the individual batches.
        let output_batch = concat_batches(&output[0].schema(), &output).unwrap();
        assert_eq!(
            input_batch, output_batch,
            "round trip mismatch with max_flight_data_size = {max_message_size_bytes}"
        );
    }
}
/// Encoding a stream whose second batch has a different schema from the first
/// (one dictionary column instead of two primitive columns) must fail.
#[tokio::test]
async fn test_mismatched_record_batch_schema() {
    let mixed_batches = vec![Ok(make_primative_batch(5)), Ok(make_dictionary_batch(3))];
    let flight_data =
        FlightDataEncoderBuilder::default().build(futures::stream::iter(mixed_batches));

    // Only the encoder is driven here; no decoder is involved.
    let outcome: Result<Vec<_>, FlightError> = flight_data.try_collect().await;
    let message = outcome.unwrap_err().to_string();

    assert_eq!(
        message,
        "Arrow(InvalidArgumentError(\"number of columns(1) must match number of fields(2) in schema\"))"
    );
}
/// Chaining two independently encoded streams yields two schema messages;
/// `FlightRecordBatchStream` must reject the second one with a protocol error.
#[tokio::test]
async fn test_chained_streams_batch_decoder() {
    let primitive = make_primative_batch(5);
    let dictionary = make_dictionary_batch(3);

    // Each stream is produced by its own encoder.
    let first = FlightDataEncoderBuilder::default()
        .build(futures::stream::iter(vec![Ok(primitive.clone())]));
    let second = FlightDataEncoderBuilder::default()
        .build(futures::stream::iter(vec![Ok(dictionary.clone())]));

    let chained = first.chain(second);
    let batch_stream = FlightRecordBatchStream::new_from_flight_data(chained);
    let outcome: Result<Vec<_>, FlightError> = batch_stream.try_collect().await;

    let message = outcome.unwrap_err().to_string();
    assert_eq!(
        message,
        "ProtocolError(\"Unexpectedly saw multiple Schema messages in FlightData stream\")"
    );
}
/// The lower-level `FlightDataDecoder` accepts two chained encoded streams and
/// reports them as alternating schema / record batch messages.
#[tokio::test]
async fn test_chained_streams_data_decoder() {
    let primitive = make_primative_batch(5);
    let dictionary = make_dictionary_batch(3);

    // Each stream is produced by its own encoder.
    let first = FlightDataEncoderBuilder::default()
        .build(futures::stream::iter(vec![Ok(primitive.clone())]));
    let second = FlightDataEncoderBuilder::default()
        .build(futures::stream::iter(vec![Ok(dictionary.clone())]));

    let decoder = FlightDataDecoder::new(first.chain(second));
    let decoded_data: Vec<_> = decoder.try_collect().await.expect("encode / decode");
    println!("decoded data: {decoded_data:#?}");

    // Two (schema, record batch) pairs, one pair per chained stream.
    assert_eq!(decoded_data.len(), 4);
    for pair in decoded_data.chunks(2) {
        assert!(matches!(pair[0].payload, DecodedPayload::Schema(_)));
        assert!(matches!(pair[1].payload, DecodedPayload::RecordBatch(_)));
    }
}
/// Decoding must fail when the data that follows the first message of one
/// encoded stream actually comes from a stream encoded for a different batch.
#[tokio::test]
async fn test_mismatched_schema_message() {
    /// Splices the first message of the stream encoded from `batch1` onto the
    /// stream encoded from `batch2` with its first message removed, then
    /// asserts that decoding fails with an error containing `expected`.
    ///
    /// Other tests in this file show the first encoded message is the schema,
    /// so the splice pairs `batch1`'s schema with `batch2`'s data.
    async fn assert_decode_error(batch1: RecordBatch, batch2: RecordBatch, expected: &str) {
        let head = FlightDataEncoderBuilder::default()
            .build(futures::stream::iter(vec![Ok(batch1.clone())]))
            .take(1);
        let tail = FlightDataEncoderBuilder::default()
            .build(futures::stream::iter(vec![Ok(batch2.clone())]))
            .skip(1);

        let spliced = head.chain(tail);
        let batch_stream = FlightRecordBatchStream::new_from_flight_data(spliced);
        let outcome: Result<Vec<_>, FlightError> = batch_stream.try_collect().await;

        let err = outcome.unwrap_err().to_string();
        assert!(
            err.contains(expected),
            "could not find '{expected}' in '{err}'"
        );
    }

    // (batch supplying the first message, batch supplying the rest, expected error fragment)
    let cases = [
        (
            make_primative_batch(5),
            make_dictionary_batch(3),
            "Error decoding ipc RecordBatch: Io error: Invalid data for schema",
        ),
        (
            make_dictionary_batch(3),
            make_primative_batch(5),
            "Error decoding ipc RecordBatch: Invalid argument error",
        ),
    ];

    for (schema_source, data_source, expected) in cases {
        assert_decode_error(schema_source, data_source, expected).await;
    }
}
/// Builds a `RecordBatch` with `num_rows` rows and two nullable columns:
///
/// * `"i"`: `UInt8Array` holding the row index
/// * `"f"`: `Float64Array` holding `num_rows - index`
///
/// The row at index `num_rows / 2` is null in both columns.
///
/// Panics if a non-null row index does not fit in a `u8`.
fn make_primative_batch(num_rows: usize) -> RecordBatch {
    let null_row = num_rows / 2;

    let ints: UInt8Array = (0..num_rows)
        .map(|row| (row != null_row).then(|| u8::try_from(row).unwrap()))
        .collect();

    let floats: Float64Array = (0..num_rows)
        .map(|row| (row != null_row).then(|| (num_rows - row) as f64))
        .collect();

    let columns = [
        ("i", Arc::new(ints) as ArrayRef),
        ("f", Arc::new(floats) as ArrayRef),
    ];
    RecordBatch::try_from_iter(columns).unwrap()
}
/// Builds a `RecordBatch` with `num_rows` rows and one column `"a"`, a
/// `DictionaryArray<Int32Type>` of strings.
///
/// Row `i` holds `"value{i / 3}"`, so each distinct value spans up to three
/// consecutive rows; the row at index `num_rows / 2` is null.
fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
    let null_row = num_rows / 2;

    // Own the strings first so the dictionary array can be built from borrows.
    let strings: Vec<Option<String>> = (0..num_rows)
        .map(|row| (row != null_row).then(|| format!("value{}", row / 3)))
        .collect();

    let dictionary: DictionaryArray<Int32Type> =
        strings.iter().map(|entry| entry.as_deref()).collect();

    let columns = [("a", Arc::new(dictionary) as ArrayRef)];
    RecordBatch::try_from_iter(columns).unwrap()
}
/// Encodes `input` with a default encoder, decodes it again, and asserts that
/// the decoded batches equal the input (via `roundtrip_with_encoder`).
async fn roundtrip(input: Vec<RecordBatch>) {
    let encoder = FlightDataEncoderBuilder::default();
    // The output is expected to be the same as the input.
    let expected = input.clone();
    roundtrip_with_encoder(encoder, input, expected).await;
}
async fn roundtrip_dictionary(input: Vec<RecordBatch>) {
let schema = Arc::new(prepare_schema_for_flight(&input[0].schema()));
let expected_output: Vec<_> = input
.iter()
.map(|batch| prepare_batch_for_flight(batch, schema.clone()).unwrap())
.collect();
roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output)
.await
}
/// Encodes `input_batches` with `encoder`, decodes the resulting flight data
/// stream, and asserts that the decoded batches equal `expected_batches`.
///
/// Zero-row batches are removed from `expected_batches` before the comparison,
/// i.e. the decoded output is expected not to contain them.
///
/// Panics if encoding or decoding yields an error, or if the comparison fails.
async fn roundtrip_with_encoder(
    encoder: FlightDataEncoderBuilder,
    input_batches: Vec<RecordBatch>,
    expected_batches: Vec<RecordBatch>,
) {
    println!("Round tripping with encoder:\n{encoder:#?}");

    // `input_batches` is owned and not needed afterwards, so move it into the
    // stream instead of cloning it.
    let input_batch_stream = futures::stream::iter(input_batches).map(Ok);

    let encode_stream = encoder.build(input_batch_stream);
    let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream);
    let output_batches: Vec<_> =
        decode_stream.try_collect().await.expect("encode / decode");

    // Zero-row batches are not expected to appear in the decoded output.
    let expected_batches: Vec<_> = expected_batches
        .into_iter()
        .filter(|b| b.num_rows() > 0)
        .collect();

    assert_eq!(expected_batches, output_batches);
}
/// Returns a copy of `schema` in which every top-level dictionary field is
/// replaced by a field of the dictionary's value type. The replacement keeps
/// the original name, nullability and field metadata; all other fields are
/// cloned as they are.
fn prepare_schema_for_flight(schema: &Schema) -> Schema {
    let mut hydrated_fields = Vec::with_capacity(schema.fields().len());

    for field in schema.fields().iter() {
        let hydrated = if let DataType::Dictionary(_, value_type) = field.data_type() {
            let replacement = Field::new(
                field.name(),
                value_type.as_ref().clone(),
                field.is_nullable(),
            );
            replacement.with_metadata(field.metadata().clone())
        } else {
            field.clone()
        };
        hydrated_fields.push(hydrated);
    }

    Schema::new(hydrated_fields)
}
fn prepare_batch_for_flight(
batch: &RecordBatch,
schema: SchemaRef,
) -> Result<RecordBatch, FlightError> {
let columns = batch
.columns()
.iter()
.map(hydrate_dictionary)
.collect::<Result<Vec<_>, _>>()?;
Ok(RecordBatch::try_new(schema, columns)?)
}
/// Casts a dictionary array to its value type; any other array is returned as
/// a new reference to the same data.
///
/// Returns an error if the cast fails.
fn hydrate_dictionary(array: &ArrayRef) -> Result<ArrayRef, FlightError> {
    match array.data_type() {
        DataType::Dictionary(_, value) => Ok(arrow_cast::cast(array, value)?),
        _ => Ok(Arc::clone(array)),
    }
}