use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
use crate::{error::Result, FlightData, SchemaAsIpc};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use bytes::Bytes;
use futures::{ready, stream::BoxStream, Stream, StreamExt};
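/// Builder for configuring a [`FlightDataEncoder`], which converts a stream of
/// [`RecordBatch`]es into a stream of IPC-encoded [`FlightData`] messages
/// suitable for sending over gRPC (for example from an Arrow Flight `do_get`
/// implementation).
///
/// A minimal usage sketch (assuming an existing `input` stream of
/// `Result<RecordBatch>`):
/// ```ignore
/// // `input` is any `Stream<Item = Result<RecordBatch>> + Send + 'static`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
///     .with_metadata(Bytes::from("application specific metadata"))
///     .build(input);
/// ```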
#[derive(Debug)]
pub struct FlightDataEncoderBuilder {
/// Target maximum size, in bytes, of each encoded [`FlightData`].
max_flight_data_size: usize,
/// Arrow IPC write options used for encoding.
options: IpcWriteOptions,
/// Application-specific metadata attached to the schema message.
app_metadata: Bytes,
/// Optional schema; if `None`, it is inferred from the first batch.
schema: Option<SchemaRef>,
}
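/// Default target size, in bytes, for each encoded [`FlightData`] message
/// (2 MiB). This is kept well below the 4 MiB message-size limit commonly
/// configured for gRPC, because the splitting logic only estimates the
/// encoded size.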
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152;
impl Default for FlightDataEncoderBuilder {
fn default() -> Self {
Self {
max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
options: IpcWriteOptions::default(),
app_metadata: Bytes::new(),
schema: None,
}
}
}
impl FlightDataEncoderBuilder {
pub fn new() -> Self {
Self::default()
}
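/// Set the (approximate) maximum size, in bytes, of each encoded
/// [`FlightData`]. Larger batches are split into multiple messages based on
/// an estimate of their size, so individual messages may still exceed this
/// value.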
pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
self.max_flight_data_size = max_flight_data_size;
self
}
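/// Set `app_metadata` to be attached to the first [`FlightData`] message
/// (the schema message).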
pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
self.app_metadata = app_metadata;
self
}
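/// Set the [`IpcWriteOptions`] used to encode the batches.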
pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
self.options = options;
self
}
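/// Provide the schema explicitly. If not set, the schema is taken from the
/// first [`RecordBatch`] in the input stream.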
pub fn with_schema(mut self, schema: SchemaRef) -> Self {
self.schema = Some(schema);
self
}
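/// Consume the builder and return a [`FlightDataEncoder`] wrapping the
/// provided stream of [`RecordBatch`]es.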
pub fn build<S>(self, input: S) -> FlightDataEncoder
where
S: Stream<Item = Result<RecordBatch>> + Send + 'static,
{
let Self {
max_flight_data_size,
options,
app_metadata,
schema,
} = self;
FlightDataEncoder::new(
input.boxed(),
schema,
max_flight_data_size,
options,
app_metadata,
)
}
}
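/// Stream that encodes record batches into [`FlightData`] messages: a schema
/// message first, followed by dictionary and data messages for each batch.
/// Created via [`FlightDataEncoderBuilder`].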
pub struct FlightDataEncoder {
/// Input stream of record batches to encode.
inner: BoxStream<'static, Result<RecordBatch>>,
/// Schema of the output, set once the schema message has been queued.
schema: Option<SchemaRef>,
/// Target maximum size, in bytes, of each encoded [`FlightData`].
max_flight_data_size: usize,
/// Underlying IPC encoder.
encoder: FlightIpcEncoder,
/// Application metadata to attach to the schema message, taken when sent.
app_metadata: Option<Bytes>,
/// Encoded messages waiting to be yielded from `poll_next`.
queue: VecDeque<FlightData>,
/// True once the input stream is exhausted or has returned an error.
done: bool,
}
impl FlightDataEncoder {
fn new(
inner: BoxStream<'static, Result<RecordBatch>>,
schema: Option<SchemaRef>,
max_flight_data_size: usize,
options: IpcWriteOptions,
app_metadata: Bytes,
) -> Self {
let mut encoder = Self {
inner,
schema: None,
max_flight_data_size,
encoder: FlightIpcEncoder::new(options),
app_metadata: Some(app_metadata),
queue: VecDeque::new(),
done: false,
};
if let Some(schema) = schema {
encoder.encode_schema(&schema);
}
encoder
}
fn queue_message(&mut self, data: FlightData) {
self.queue.push_back(data);
}
fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
for data in datas {
self.queue_message(data)
}
}
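/// Encode the schema message (with any pending `app_metadata`), queue it,
/// and record the prepared schema for subsequent batches.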
fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
let schema = Arc::new(prepare_schema_for_flight(schema));
let mut schema_flight_data = self.encoder.encode_schema(&schema);
if let Some(app_metadata) = self.app_metadata.take() {
schema_flight_data.app_metadata = app_metadata;
}
self.queue_message(schema_flight_data);
self.schema = Some(schema.clone());
schema
}
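/// Encode a batch: lazily queue the schema message if needed, hydrate any
/// dictionary columns, split the batch to fit `max_flight_data_size`, and
/// queue the resulting dictionary and data messages.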
fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
let schema = match &self.schema {
Some(schema) => schema.clone(),
None => self.encode_schema(&batch.schema()),
};
let batch = prepare_batch_for_flight(&batch, schema)?;
for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
let (flight_dictionaries, flight_batch) =
self.encoder.encode_batch(&batch)?;
self.queue_messages(flight_dictionaries);
self.queue_message(flight_batch);
}
Ok(())
}
}
impl Stream for FlightDataEncoder {
type Item = Result<FlightData>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
loop {
// Nothing left to yield once the input is exhausted and the queue is drained.
if self.done && self.queue.is_empty() {
return Poll::Ready(None);
}
// Return any already-encoded message before polling the input again.
if let Some(data) = self.queue.pop_front() {
return Poll::Ready(Some(Ok(data)));
}
// Queue is empty: poll the input for the next batch to encode.
let batch = ready!(self.inner.poll_next_unpin(cx));
match batch {
None => {
self.done = true;
assert!(self.queue.is_empty());
return Poll::Ready(None);
}
Some(Err(e)) => {
self.done = true;
self.queue.clear();
return Poll::Ready(Some(Err(e)));
}
Some(Ok(batch)) => {
if let Err(e) = self.encode_batch(batch) {
self.done = true;
self.queue.clear();
return Poll::Ready(Some(Err(e)));
}
}
}
}
}
}
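/// Prepare a schema for transport: dictionary fields are replaced by fields
/// of their value type, since the encoder hydrates dictionary columns to
/// their value types before sending.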
fn prepare_schema_for_flight(schema: &Schema) -> Schema {
let fields = schema
.fields()
.iter()
.map(|field| match field.data_type() {
DataType::Dictionary(_, value_type) => Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
_ => field.clone(),
})
.collect();
Schema::new(fields)
}
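/// Split a batch into smaller batches so that each one encodes to roughly
/// `max_flight_data_size` bytes or less, using the in-memory buffer size as
/// an estimate of the encoded size.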
fn split_batch_for_grpc_response(
batch: RecordBatch,
max_flight_data_size: usize,
) -> Vec<RecordBatch> {
// Estimate the encoded size from the in-memory buffer sizes; this is only
// an approximation of the eventual IPC-encoded size.
let size = batch
.columns()
.iter()
.map(|col| col.get_buffer_memory_size())
.sum::<usize>();
// Ceiling division: number of output batches needed to stay near the
// target size, always at least one.
let n_batches = (size / max_flight_data_size
+ usize::from(size % max_flight_data_size != 0))
.max(1);
let rows_per_batch = (batch.num_rows() / n_batches).max(1);
let mut out = Vec::with_capacity(n_batches + 1);
let mut offset = 0;
while offset < batch.num_rows() {
let length = (rows_per_batch).min(batch.num_rows() - offset);
out.push(batch.slice(offset, length));
offset += length;
}
out
}
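/// Wrapper around the Arrow IPC machinery ([`IpcDataGenerator`] and
/// [`DictionaryTracker`]) that produces [`FlightData`] messages.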
struct FlightIpcEncoder {
options: IpcWriteOptions,
data_gen: IpcDataGenerator,
dictionary_tracker: DictionaryTracker,
}
impl FlightIpcEncoder {
fn new(options: IpcWriteOptions) -> Self {
let error_on_replacement = true;
Self {
options,
data_gen: IpcDataGenerator::default(),
dictionary_tracker: DictionaryTracker::new(error_on_replacement),
}
}
fn encode_schema(&self, schema: &Schema) -> FlightData {
SchemaAsIpc::new(schema, &self.options).into()
}
fn encode_batch(
&mut self,
batch: &RecordBatch,
) -> Result<(Vec<FlightData>, FlightData)> {
let (encoded_dictionaries, encoded_batch) = self.data_gen.encoded_batch(
batch,
&mut self.dictionary_tracker,
&self.options,
)?;
let flight_dictionaries =
encoded_dictionaries.into_iter().map(Into::into).collect();
let flight_batch = encoded_batch.into();
Ok((flight_dictionaries, flight_batch))
}
}
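/// Prepare a batch for transport by hydrating its dictionary columns and
/// rebuilding it against the prepared (dictionary-free) schema.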
fn prepare_batch_for_flight(
batch: &RecordBatch,
schema: SchemaRef,
) -> Result<RecordBatch> {
let columns = batch
.columns()
.iter()
.map(hydrate_dictionary)
.collect::<Result<Vec<_>>>()?;
Ok(RecordBatch::try_new(schema, columns)?)
}
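/// Cast a dictionary array to its value type ("hydrating" it); other arrays
/// are returned unchanged.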
fn hydrate_dictionary(array: &ArrayRef) -> Result<ArrayRef> {
let arr = if let DataType::Dictionary(_, value) = array.data_type() {
arrow_cast::cast(array, value)?
} else {
Arc::clone(array)
};
Ok(arr)
}
#[cfg(test)]
mod tests {
use arrow::{
array::{UInt32Array, UInt8Array},
compute::concat_batches,
datatypes::Int32Type,
};
use arrow_array::{
DictionaryArray, Int16Array, Int32Array, Int64Array, StringArray, UInt64Array,
};
use super::*;
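/// Verify encoded body sizes after `prepare_batch_for_flight`: a slice of
/// all but one row encodes to the same body size as the full batch (IPC
/// buffers are padded), while a single-row slice encodes to a smaller body.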
#[test]
fn test_encode_flight_data() {
let options = arrow::ipc::writer::IpcWriteOptions::default();
let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
.expect("cannot create record batch");
let schema = batch.schema();
let (_, baseline_flight_batch) = make_flight_data(&batch, &options);
let big_batch = batch.slice(0, batch.num_rows() - 1);
let optimized_big_batch =
prepare_batch_for_flight(&big_batch, Arc::clone(&schema))
.expect("failed to optimize");
let (_, optimized_big_flight_batch) =
make_flight_data(&optimized_big_batch, &options);
assert_eq!(
baseline_flight_batch.data_body.len(),
optimized_big_flight_batch.data_body.len()
);
let small_batch = batch.slice(0, 1);
let optimized_small_batch =
prepare_batch_for_flight(&small_batch, Arc::clone(&schema))
.expect("failed to optimize");
let (_, optimized_small_flight_batch) =
make_flight_data(&optimized_small_batch, &options);
assert!(
baseline_flight_batch.data_body.len()
> optimized_small_flight_batch.data_body.len()
);
}
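/// Encode a batch directly with [`IpcDataGenerator`] (no size-based
/// splitting), used as a baseline by `test_encode_flight_data`.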
pub fn make_flight_data(
batch: &RecordBatch,
options: &IpcWriteOptions,
) -> (Vec<FlightData>, FlightData) {
let data_gen = IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(false);
let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, &mut dictionary_tracker, options)
.expect("DictionaryTracker configured above to not error on replacement");
let flight_dictionaries =
encoded_dictionaries.into_iter().map(Into::into).collect();
let flight_batch = encoded_batch.into();
(flight_dictionaries, flight_batch)
}
#[test]
fn test_split_batch_for_grpc_response() {
let max_flight_data_size = 1024;
let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
assert_eq!(split.len(), 1);
assert_eq!(batch, split[0]);
let n_rows = max_flight_data_size + 1;
assert!(n_rows % 2 == 1, "should be an odd number");
let c =
UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
assert_eq!(split.len(), 3);
assert_eq!(
split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
n_rows
);
assert_eq!(concat_batches(&batch.schema(), &split).unwrap(), batch);
}
#[test]
fn test_split_batch_for_grpc_response_sizes() {
verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]);
verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]);
verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]);
verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
verify_split(10, 1024, vec![10]);
}
fn verify_split(
num_input_rows: u64,
max_flight_data_size_bytes: usize,
expected_sizes: Vec<usize>,
) {
let array: UInt64Array = (0..num_input_rows).collect();
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
.expect("cannot create record batch");
let input_rows = batch.num_rows();
let split =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
let sizes: Vec<_> = split.iter().map(|batch| batch.num_rows()).collect();
let output_rows: usize = sizes.iter().sum();
assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
}
#[tokio::test]
async fn flight_data_size_even() {
let s1 =
StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024));
let i1 = Int16Array::from_iter_values(0..1024);
let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024));
let i2 = Int64Array::from_iter_values(0..1024);
let batch = RecordBatch::try_from_iter(vec![
("s1", Arc::new(s1) as _),
("i1", Arc::new(i1) as _),
("s2", Arc::new(s2) as _),
("i2", Arc::new(i2) as _),
])
.unwrap();
verify_encoded_split(batch, 112).await;
}
#[tokio::test]
async fn flight_data_size_uneven_variable_lengths() {
let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
let batch =
RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();
verify_encoded_split(batch, 4304).await;
}
#[tokio::test]
async fn flight_data_size_large_row() {
let array1 = StringArray::from_iter_values(vec![
"*".repeat(500),
"*".repeat(500),
"*".repeat(500),
"*".repeat(500),
]);
let array2 = StringArray::from_iter_values(vec![
"*".to_string(),
"*".repeat(1000),
"*".repeat(2000),
"*".repeat(4000),
]);
let array3 = StringArray::from_iter_values(vec![
"*".to_string(),
"*".to_string(),
"*".repeat(1000),
"*".repeat(2000),
]);
let batch = RecordBatch::try_from_iter(vec![
("a1", Arc::new(array1) as _),
("a2", Arc::new(array2) as _),
("a3", Arc::new(array3) as _),
])
.unwrap();
verify_encoded_split(batch, 5800).await;
}
#[tokio::test]
async fn flight_data_size_string_dictionary() {
let array: DictionaryArray<Int32Type> = (1..1024)
.map(|i| match i % 3 {
0 => Some("value0"),
1 => Some("value1"),
_ => None,
})
.collect();
let batch =
RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
verify_encoded_split(batch, 160).await;
}
#[tokio::test]
async fn flight_data_size_large_dictionary() {
let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
let array: DictionaryArray<Int32Type> =
values.iter().map(|s| Some(s.as_str())).collect();
let batch =
RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
verify_encoded_split(batch, 3328).await;
}
#[tokio::test]
async fn flight_data_size_large_dictionary_repeated_non_uniform() {
let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
let array = DictionaryArray::<Int32Type>::try_new(&keys, &values).unwrap();
let batch =
RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
verify_encoded_split(batch, 5280).await;
}
#[tokio::test]
async fn flight_data_size_multiple_dictionaries() {
let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();
let array1: DictionaryArray<Int32Type> =
values1.iter().map(|s| Some(s.as_str())).collect();
let array2: DictionaryArray<Int32Type> =
values2.iter().map(|s| Some(s.as_str())).collect();
let array3: DictionaryArray<Int32Type> =
values3.iter().map(|s| Some(s.as_str())).collect();
let batch = RecordBatch::try_from_iter(vec![
("a1", Arc::new(array1) as _),
("a2", Arc::new(array2) as _),
("a3", Arc::new(array3) as _),
])
.unwrap();
verify_encoded_split(batch, 4128).await;
}
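/// Approximate the size of a [`FlightData`] message as sent on the wire:
/// descriptor, app metadata, data header, and data body.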
fn flight_data_size(d: &FlightData) -> usize {
let flight_descriptor_size = d
.flight_descriptor
.as_ref()
.map(|descriptor| {
let path_len: usize =
descriptor.path.iter().map(|p| p.as_bytes().len()).sum();
std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
})
.unwrap_or(0);
flight_descriptor_size
+ d.app_metadata.len()
+ d.data_body.len()
+ d.data_header.len()
}
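/// Encode `batch` with several `max_flight_data_size` settings and assert
/// that no encoded message exceeds the target by more than
/// `allowed_overage` bytes, and that the allowance is tight (the maximum
/// observed overage equals `allowed_overage`).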
async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
let num_rows = batch.num_rows();
let mut max_overage_seen = 0;
for max_flight_data_size in [1024, 2021, 5000] {
println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");
let mut stream = FlightDataEncoderBuilder::new()
.with_max_flight_data_size(max_flight_data_size)
.build(futures::stream::iter([Ok(batch.clone())]));
let mut i = 0;
while let Some(data) = stream.next().await.transpose().unwrap() {
let actual_data_size = flight_data_size(&data);
let actual_overage = if actual_data_size > max_flight_data_size {
actual_data_size - max_flight_data_size
} else {
0
};
assert!(
actual_overage <= allowed_overage,
"encoded data[{i}]: actual size {actual_data_size}, \
actual_overage: {actual_overage} \
allowed_overage: {allowed_overage}"
);
i += 1;
max_overage_seen = max_overage_seen.max(actual_overage)
}
}
assert_eq!(
allowed_overage, max_overage_seen,
"Specified overage was too high"
);
}
}