use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};

use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};

use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};

use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
use bytes::Bytes;
use futures::{ready, stream::BoxStream, Stream, StreamExt};

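/// Builder for a [`FlightDataEncoder`], which converts a stream of
/// [`RecordBatch`]es into a stream of [`FlightData`] messages.
///
/// A minimal usage sketch (`input` is hypothetical; any
/// `Stream<Item = Result<RecordBatch>> + Send + 'static` will do):
///
/// ```ignore
/// use futures::stream;
///
/// // `batch` is assumed to be an existing `RecordBatch`
/// let input = stream::iter([Ok(batch)]);
///
/// let flight_data_stream = FlightDataEncoderBuilder::new()
///     // Target roughly 1 MiB per message (an illustrative value)
///     .with_max_flight_data_size(1024 * 1024)
///     .build(input);
/// ```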
#[derive(Debug)]
pub struct FlightDataEncoderBuilder {
    max_flight_data_size: usize,
    options: IpcWriteOptions,
    app_metadata: Bytes,
    schema: Option<SchemaRef>,
    descriptor: Option<FlightDescriptor>,
    dictionary_handling: DictionaryHandling,
}

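/// Default target size for encoded [`FlightData`] messages: 2 MiB. This is a
/// target rather than a hard limit, chosen to leave headroom below the
/// message-size ceilings commonly configured for gRPC transports (often
/// 4 MiB by default).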
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152;

impl Default for FlightDataEncoderBuilder {
    fn default() -> Self {
        Self {
            max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
            options: IpcWriteOptions::default(),
            app_metadata: Bytes::new(),
            schema: None,
            descriptor: None,
            dictionary_handling: DictionaryHandling::Hydrate,
        }
    }
}

impl FlightDataEncoderBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the (approximate) maximum size, in bytes, of each encoded
    /// [`FlightData`] message. This is a hint, not a hard limit: a message
    /// may exceed it when a single row is larger than the target.
    pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
        self.max_flight_data_size = max_flight_data_size;
        self
    }

    /// Set how dictionary-encoded data is sent; see [`DictionaryHandling`].
    pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
        self.dictionary_handling = dictionary_handling;
        self
    }

    /// Set `app_metadata` to attach to the first (schema) message of the stream.
    pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
        self.app_metadata = app_metadata;
        self
    }

    /// Set the [`IpcWriteOptions`] used to encode the batches.
    pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
        self.options = options;
        self
    }

    /// Declare the schema of the stream up front, so it is sent before any
    /// data even if the input stream is empty.
    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Set the [`FlightDescriptor`] attached to the first message of the stream.
    pub fn with_flight_descriptor(mut self, descriptor: Option<FlightDescriptor>) -> Self {
        self.descriptor = descriptor;
        self
    }

    pub fn build<S>(self, input: S) -> FlightDataEncoder
    where
        S: Stream<Item = Result<RecordBatch>> + Send + 'static,
    {
        let Self {
            max_flight_data_size,
            options,
            app_metadata,
            schema,
            descriptor,
            dictionary_handling,
        } = self;

        FlightDataEncoder::new(
            input.boxed(),
            schema,
            max_flight_data_size,
            options,
            app_metadata,
            descriptor,
            dictionary_handling,
        )
    }
}

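/// Stream that encodes its input [`RecordBatch`]es as [`FlightData`]
/// messages; created by [`FlightDataEncoderBuilder::build`].
///
/// Because it implements `Stream<Item = Result<FlightData>>`, it composes
/// with the usual `futures` combinators. A minimal consumption sketch,
/// assuming `encoder` was built as shown above:
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let messages: Vec<FlightData> = encoder.try_collect().await?;
/// ```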
pub struct FlightDataEncoder {
    /// Input stream of record batches to encode
    inner: BoxStream<'static, Result<RecordBatch>>,
    /// The schema of the encoded output, once known
    schema: Option<SchemaRef>,
    /// Target (not hard limit) size of each encoded `FlightData` message
    max_flight_data_size: usize,
    /// Underlying IPC encoder
    encoder: FlightIpcEncoder,
    /// Application metadata to attach to the first (schema) message, if any
    app_metadata: Option<Bytes>,
    /// Messages that have been encoded but not yet yielded
    queue: VecDeque<FlightData>,
    /// Set to `true` when the input stream is exhausted or errored
    done: bool,
    /// Descriptor to attach to the first outgoing message, if any
    descriptor: Option<FlightDescriptor>,
    /// How dictionary-encoded data should be sent
    dictionary_handling: DictionaryHandling,
}

impl FlightDataEncoder {
    fn new(
        inner: BoxStream<'static, Result<RecordBatch>>,
        schema: Option<SchemaRef>,
        max_flight_data_size: usize,
        options: IpcWriteOptions,
        app_metadata: Bytes,
        descriptor: Option<FlightDescriptor>,
        dictionary_handling: DictionaryHandling,
    ) -> Self {
        let mut encoder = Self {
            inner,
            schema: None,
            max_flight_data_size,
            encoder: FlightIpcEncoder::new(
                options,
                dictionary_handling != DictionaryHandling::Resend,
            ),
            app_metadata: Some(app_metadata),
            queue: VecDeque::new(),
            done: false,
            descriptor,
            dictionary_handling,
        };

        // If the schema is known up front, encode it immediately so the
        // first message of the stream carries it
        if let Some(schema) = schema {
            encoder.encode_schema(&schema);
        }

        encoder
    }

    /// Report the schema of the encoded stream, if it is known.
    ///
    /// The schema is known if it was supplied via
    /// [`FlightDataEncoderBuilder::with_schema`] or once the first input
    /// batch has been encoded.
    pub fn known_schema(&self) -> Option<SchemaRef> {
        self.schema.clone()
    }

    /// Enqueue a message, attaching the flight descriptor (if any) to the
    /// first message only.
    fn queue_message(&mut self, mut data: FlightData) {
        if let Some(descriptor) = self.descriptor.take() {
            data.flight_descriptor = Some(descriptor);
        }
        self.queue.push_back(data);
    }

    fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
        for data in datas {
            self.queue_message(data)
        }
    }

    /// Encode the schema as the next message in the stream, returning the
    /// (possibly transformed) schema that was actually sent.
    fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
        let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
        let schema = Arc::new(prepare_schema_for_flight(
            schema,
            &mut self.encoder.dictionary_tracker,
            send_dictionaries,
        ));
        let mut schema_flight_data = self.encoder.encode_schema(&schema);

        // Attach any application metadata to the schema message
        if let Some(app_metadata) = self.app_metadata.take() {
            schema_flight_data.app_metadata = app_metadata;
        }
        self.queue_message(schema_flight_data);
        self.schema = Some(schema.clone());
        schema
    }

    /// Encode a batch as one or more messages, sending the schema first if
    /// it has not been sent yet.
    fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
        let schema = match &self.schema {
            Some(schema) => schema.clone(),
            None => self.encode_schema(batch.schema_ref()),
        };

        let batch = match self.dictionary_handling {
            DictionaryHandling::Resend => batch,
            DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
        };

        for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
            let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;

            self.queue_messages(flight_dictionaries);
            self.queue_message(flight_batch);
        }

        Ok(())
    }
}

impl Stream for FlightDataEncoder {
    type Item = Result<FlightData>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            if self.done && self.queue.is_empty() {
                return Poll::Ready(None);
            }

            // Drain any messages that are already queued
            if let Some(data) = self.queue.pop_front() {
                return Poll::Ready(Some(Ok(data)));
            }

            let batch = ready!(self.inner.poll_next_unpin(cx));

            match batch {
                None => {
                    // Input stream is complete
                    self.done = true;
                    assert!(self.queue.is_empty());
                    return Poll::Ready(None);
                }
                Some(Err(e)) => {
                    // Forward the error and stop producing output
                    self.done = true;
                    self.queue.clear();
                    return Poll::Ready(Some(Err(e)));
                }
                Some(Ok(batch)) => {
                    // Encode the batch; its messages are picked up on the
                    // next loop iteration
                    if let Err(e) = self.encode_batch(batch) {
                        self.done = true;
                        self.queue.clear();
                        return Poll::Ready(Some(Err(e)));
                    }
                }
            }
        }
    }
}

/// Controls how dictionary-encoded data is sent over Flight.
#[derive(Debug, PartialEq)]
pub enum DictionaryHandling {
    /// Expand dictionary-encoded arrays into their value types before
    /// sending, so no dictionary messages are required. Simpler for
    /// receivers, but transmits duplicated values.
    Hydrate,
    /// Send dictionary batches alongside the record batches, preserving
    /// dictionary encoding on the wire.
    Resend,
}

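/// Recursively rewrites a field for transport over Flight: when hydrating,
/// dictionary fields at any nesting level are replaced by their value types;
/// when resending dictionaries, they are assigned dictionary ids via the
/// tracker. Nested types (lists, structs, unions, maps) are rewritten by
/// recursing into their children; field metadata is preserved throughout.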
fn prepare_field_for_flight(
    field: &FieldRef,
    dictionary_tracker: &mut DictionaryTracker,
    send_dictionaries: bool,
) -> Field {
    match field.data_type() {
        DataType::List(inner) => Field::new_list(
            field.name(),
            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        DataType::LargeList(inner) => Field::new_large_list(
            field.name(),
            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        DataType::Struct(fields) => {
            let new_fields: Vec<Field> = fields
                .iter()
                .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
                .collect();
            Field::new_struct(field.name(), new_fields, field.is_nullable())
                .with_metadata(field.metadata().clone())
        }
        DataType::Union(fields, mode) => {
            let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
                .iter()
                .map(|(type_id, f)| {
                    (
                        type_id,
                        prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
                    )
                })
                .unzip();

            Field::new_union(field.name(), type_ids, new_fields, *mode)
        }
        DataType::Dictionary(_, value_type) => {
            if !send_dictionaries {
                Field::new(
                    field.name(),
                    value_type.as_ref().clone(),
                    field.is_nullable(),
                )
                .with_metadata(field.metadata().clone())
            } else {
                let dict_id = dictionary_tracker.set_dict_id(field.as_ref());

                Field::new_dict(
                    field.name(),
                    field.data_type().clone(),
                    field.is_nullable(),
                    dict_id,
                    field.dict_is_ordered().unwrap_or_default(),
                )
                .with_metadata(field.metadata().clone())
            }
        }
        DataType::Map(inner, sorted) => Field::new(
            field.name(),
            DataType::Map(
                prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
                *sorted,
            ),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        _ => field.as_ref().clone(),
    }
}

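/// Rewrites a schema for transport over Flight. Top-level dictionary fields
/// are hydrated or assigned dictionary ids directly; nested fields are
/// delegated to `prepare_field_for_flight`. Schema-level metadata is
/// preserved.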
fn prepare_schema_for_flight(
    schema: &Schema,
    dictionary_tracker: &mut DictionaryTracker,
    send_dictionaries: bool,
) -> Schema {
    let fields: Fields = schema
        .fields()
        .iter()
        .map(|field| match field.data_type() {
            DataType::Dictionary(_, value_type) => {
                if !send_dictionaries {
                    Field::new(
                        field.name(),
                        value_type.as_ref().clone(),
                        field.is_nullable(),
                    )
                    .with_metadata(field.metadata().clone())
                } else {
                    let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
                    Field::new_dict(
                        field.name(),
                        field.data_type().clone(),
                        field.is_nullable(),
                        dict_id,
                        field.dict_is_ordered().unwrap_or_default(),
                    )
                    .with_metadata(field.metadata().clone())
                }
            }
            tpe if tpe.is_nested() => {
                prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
            }
            _ => field.as_ref().clone(),
        })
        .collect();

    Schema::new(fields).with_metadata(schema.metadata().clone())
}

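/// Splits `batch` into smaller batches so that each encoded message lands
/// near the `max_flight_data_size` target. This is a heuristic: it estimates
/// the batch size from its in-memory buffers (not its encoded IPC size),
/// ceil-divides to get a batch count, and slices the rows into roughly equal
/// chunks, so encoded messages can still overshoot the target (see
/// `verify_encoded_split` in the tests below).
///
/// Worked example with illustrative numbers: a 6-row batch whose buffers
/// total 50 bytes, split with a 20-byte target, gives
/// `n_batches = 50 / 20 + 1 = 3` and `rows_per_batch = 6 / 3 = 2`, producing
/// three 2-row slices.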
fn split_batch_for_grpc_response(
    batch: RecordBatch,
    max_flight_data_size: usize,
) -> Vec<RecordBatch> {
    let size = batch
        .columns()
        .iter()
        .map(|col| col.get_buffer_memory_size())
        .sum::<usize>();

    let n_batches =
        (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
    let rows_per_batch = (batch.num_rows() / n_batches).max(1);
    let mut out = Vec::with_capacity(n_batches + 1);

    let mut offset = 0;
    while offset < batch.num_rows() {
        let length = (rows_per_batch).min(batch.num_rows() - offset);
        out.push(batch.slice(offset, length));

        offset += length;
    }

    out
}

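/// Wraps the lower-level Arrow IPC machinery (an [`IpcDataGenerator`] plus a
/// [`DictionaryTracker`]) used to turn schemas and record batches into
/// `FlightData` messages.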
struct FlightIpcEncoder {
    options: IpcWriteOptions,
    data_gen: IpcDataGenerator,
    dictionary_tracker: DictionaryTracker,
}

impl FlightIpcEncoder {
    fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
        let preserve_dict_id = options.preserve_dict_id();
        Self {
            options,
            data_gen: IpcDataGenerator::default(),
            dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
                error_on_replacement,
                preserve_dict_id,
            ),
        }
    }

    fn encode_schema(&self, schema: &Schema) -> FlightData {
        SchemaAsIpc::new(schema, &self.options).into()
    }

    fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
        let (encoded_dictionaries, encoded_batch) =
            self.data_gen
                .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?;

        let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
        let flight_batch = encoded_batch.into();

        Ok((flight_dictionaries, flight_batch))
    }
}

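/// Hydrates any dictionary-encoded columns in `batch` by casting each column
/// to the corresponding (already hydrated) field type in `schema`.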
fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
    let columns = schema
        .fields()
        .iter()
        .zip(batch.columns())
        .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
        .collect::<Result<Vec<_>>>()?;

    let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));

    Ok(RecordBatch::try_new_with_options(
        schema, columns, &options,
    )?)
}

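/// Casts a single array to its hydrated target type. Sparse unions are
/// handled specially: rather than casting the union as a whole, each child
/// is cast to its target type individually and the union is rebuilt.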
fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
    let arr = match (array.data_type(), data_type) {
        (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
            let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();

            Arc::new(UnionArray::try_new(
                fields.clone(),
                union_arr.type_ids().clone(),
                None,
                fields
                    .iter()
                    .map(|(type_id, field)| {
                        Ok(arrow_cast::cast(
                            union_arr.child(type_id),
                            field.data_type(),
                        )?)
                    })
                    .collect::<Result<Vec<_>>>()?,
            )?)
        }
        (_, data_type) => arrow_cast::cast(array, data_type)?,
    };
    Ok(arr)
}

#[cfg(test)]
mod tests {
    use crate::decode::{DecodedPayload, FlightDataDecoder};
    use arrow_array::builder::{
        GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder,
    };
    use arrow_array::*;
    use arrow_array::{cast::downcast_array, types::*};
    use arrow_buffer::ScalarBuffer;
    use arrow_cast::pretty::pretty_format_batches;
    use arrow_ipc::MetadataVersion;
    use arrow_schema::{UnionFields, UnionMode};
    use builder::{GenericStringBuilder, MapBuilder};
    use std::collections::HashMap;

    use super::*;

    #[test]
    fn test_encode_flight_data() {
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
        let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);

        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
            .expect("cannot create record batch");
        let schema = batch.schema_ref();

        let (_, baseline_flight_batch) = make_flight_data(&batch, &options);

        let big_batch = batch.slice(0, batch.num_rows() - 1);
        let optimized_big_batch =
            hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
        let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);

        assert_eq!(
            baseline_flight_batch.data_body.len(),
            optimized_big_flight_batch.data_body.len()
        );

        let small_batch = batch.slice(0, 1);
        let optimized_small_batch =
            hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
        let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);

        assert!(
            baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len()
        );
    }

    #[tokio::test]
    async fn test_dictionary_hydration() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);
        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
        let expected_schema = Arc::new(expected_schema);
        let mut expected_arrays = vec![
            StringArray::from(vec!["a", "a", "b"]),
            StringArray::from(vec!["c", "c", "d"]),
        ]
        .into_iter();
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let actual_array = b.column_by_name("dict").unwrap();
                    let actual_array = downcast_array::<StringArray>(actual_array);

                    assert_eq!(actual_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_resend() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_hydration_known_schema() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default()
            .with_schema(schema)
            .build(stream);
        let expected_schema =
            Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
        assert_eq!(Some(expected_schema), encoder.known_schema())
    }

    #[tokio::test]
    async fn test_dictionary_resend_known_schema() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default()
            .with_dictionary_handling(DictionaryHandling::Resend)
            .with_schema(schema.clone())
            .build(stream);
        assert_eq!(Some(schema), encoder.known_schema())
    }

    #[tokio::test]
    async fn test_multiple_dictionaries_resend() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
        ]));

        let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["a", "a", "b"].into_iter().collect());
        let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["c", "c", "d"].into_iter().collect());
        let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["b", "a", "c"].into_iter().collect());
        let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["k", "d", "e"].into_iter().collect());
        let batch1 =
            RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
                .unwrap();
        let batch2 =
            RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
                .unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_list_hydration() {
        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new("item", DataType::Utf8, true),
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let list_array =
                        downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_list_resend() {
        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_struct_hydration() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type, GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);

        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type, GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_struct(
            "struct",
            vec![Field::new_list(
                "dict_list",
                Field::new("item", DataType::Utf8, true),
                true,
            )],
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let struct_array =
                        downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
                    let list_array = downcast_array::<ListArray>(struct_array.column(0));

                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_struct_resend() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type, GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);
        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type, GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_union_hydration() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1) as Arc<dyn Array>,
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                Arc::new(arr2),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);

        let hydrated_struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new("item", DataType::Utf8, true),
            true,
        )];

        let hydrated_union_fields = vec![
            Field::new_list("dict_list", Field::new("item", DataType::Utf8, true), true),
            Field::new_struct("struct", hydrated_struct_fields.clone(), true),
            Field::new("string", DataType::Utf8, true),
        ];

        let expected_schema = Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            hydrated_union_fields,
            UnionMode::Sparse,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
            StringArray::from(vec!["e"]),
        ]
        .into_iter();

        let mut batch = 0;
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let union_arr =
                        downcast_array::<UnionArray>(b.column_by_name("union").unwrap());

                    let elem_array = match batch {
                        0 => {
                            let list_array = downcast_array::<ListArray>(union_arr.child(0));
                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        1 => {
                            let struct_array = downcast_array::<StructArray>(union_arr.child(1));
                            let list_array = downcast_array::<ListArray>(struct_array.column(0));

                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        _ => downcast_array::<StringArray>(union_arr.child(2)),
                    };

                    batch += 1;

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_union_resend() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        let mut field_types = union_fields.iter().map(|(_, field)| field.data_type());
        let dict_list_ty = field_types.next().unwrap();
        let struct_ty = field_types.next().unwrap();
        let string_ty = field_types.next().unwrap();

        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1),
                new_null_array(struct_ty, 1),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                Arc::new(arr2),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                new_null_array(struct_ty, 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
    }

    #[tokio::test]
    async fn test_dictionary_map_hydration() {
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new("keys", DataType::Utf8, false),
            Field::new("values", DataType::Utf8, true),
            false,
            false,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut builder = MapBuilder::new(
            None,
            GenericStringBuilder::<i32>::new(),
            GenericStringBuilder::<i32>::new(),
        );

        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let mut expected_arrays = vec![arr1, arr2].into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let map_array =
                        downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());

                    assert_eq!(map_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_map_resend() {
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    /// Encodes `batches` with [`DictionaryHandling::Resend`], decodes the
    /// resulting stream, and asserts the decoded schema and batches match
    /// the input exactly.
    async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
        let expected_schema = batches.first().unwrap().schema();

        let encoder = FlightDataEncoderBuilder::default()
            .with_options(IpcWriteOptions::default().with_preserve_dict_id(false))
            .with_dictionary_handling(DictionaryHandling::Resend)
            .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));

        let mut expected_batches = batches.drain(..);

        let mut decoder = FlightDataDecoder::new(encoder);
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    let expected_batch = expected_batches.next().unwrap();
                    assert_eq!(b, expected_batch);
                }
            }
        }
    }

    #[test]
    fn test_schema_metadata_encoded() {
        let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata(
            HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
        );

        let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);

        let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
        assert!(got.metadata().contains_key("some_key"));
    }

    #[test]
    fn test_encode_no_column_batch() {
        let batch = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions::new().with_row_count(Some(10)),
        )
        .expect("cannot create record batch");

        hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize");
    }

    pub fn make_flight_data(
        batch: &RecordBatch,
        options: &IpcWriteOptions,
    ) -> (Vec<FlightData>, FlightData) {
        #[allow(deprecated)]
        crate::utils::flight_data_from_arrow_batch(batch, options)
    }

    #[test]
    fn test_split_batch_for_grpc_response() {
        let max_flight_data_size = 1024;

        // Batch is small enough that it should not be split
        let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
            .expect("cannot create record batch");
        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
        assert_eq!(split.len(), 1);
        assert_eq!(batch, split[0]);

        // Batch is larger than the target size and should be split
        let n_rows = max_flight_data_size + 1;
        assert!(n_rows % 2 == 1, "should be an odd number");
        let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
            .expect("cannot create record batch");
        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
        assert_eq!(split.len(), 3);
        assert_eq!(
            split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
            n_rows
        );
        let a = pretty_format_batches(&split).unwrap().to_string();
        let b = pretty_format_batches(&[batch]).unwrap().to_string();
        assert_eq!(a, b);
    }

    #[test]
    fn test_split_batch_for_grpc_response_sizes() {
        // 2000 8-byte rows (16000 bytes) at a 2 KiB target: split into 8
        verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]);

        // 2000 8-byte rows at a 4 KiB target: split into 4
        verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]);

        // 2023 8-byte rows at a 3 KiB target: six equal chunks plus a remainder
        verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]);

        // 10 rows at a 1-byte target: each row in its own batch
        verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);

        // 10 rows at a 1 KiB target: all rows in one batch
        verify_split(10, 1024, vec![10]);
    }

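    /// Builds a single-column batch of `num_input_rows` sequential `u64`
    /// values, splits it with the given size limit, and asserts both the
    /// per-batch row counts and that no rows are lost.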
    fn verify_split(
        num_input_rows: u64,
        max_flight_data_size_bytes: usize,
        expected_sizes: Vec<usize>,
    ) {
        let array: UInt64Array = (0..num_input_rows).collect();

        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
            .expect("cannot create record batch");

        let input_rows = batch.num_rows();

        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
        let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
        let output_rows: usize = sizes.iter().sum();

        assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
        assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
    }

    #[tokio::test]
    async fn flight_data_size_even() {
        let s1 = StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024));
        let i1 = Int16Array::from_iter_values(0..1024);
        let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024));
        let i2 = Int64Array::from_iter_values(0..1024);

        let batch = RecordBatch::try_from_iter(vec![
            ("s1", Arc::new(s1) as _),
            ("i1", Arc::new(i1) as _),
            ("s2", Arc::new(s2) as _),
            ("i2", Arc::new(i2) as _),
        ])
        .unwrap();

        verify_encoded_split(batch, 112).await;
    }

    #[tokio::test]
    async fn flight_data_size_uneven_variable_lengths() {
        let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
        let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 4304).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_row() {
        let array1 = StringArray::from_iter_values(vec![
            "*".repeat(500),
            "*".repeat(500),
            "*".repeat(500),
            "*".repeat(500),
        ]);
        let array2 = StringArray::from_iter_values(vec![
            "*".to_string(),
            "*".repeat(1000),
            "*".repeat(2000),
            "*".repeat(4000),
        ]);

        let array3 = StringArray::from_iter_values(vec![
            "*".to_string(),
            "*".to_string(),
            "*".repeat(1000),
            "*".repeat(2000),
        ]);

        let batch = RecordBatch::try_from_iter(vec![
            ("a1", Arc::new(array1) as _),
            ("a2", Arc::new(array2) as _),
            ("a3", Arc::new(array3) as _),
        ])
        .unwrap();

        verify_encoded_split(batch, 5800).await;
    }

    #[tokio::test]
    async fn flight_data_size_string_dictionary() {
        let array: DictionaryArray<Int32Type> = (1..1024)
            .map(|i| match i % 3 {
                0 => Some("value0"),
                1 => Some("value1"),
                _ => None,
            })
            .collect();

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 160).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_dictionary() {
        let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();

        let array: DictionaryArray<Int32Type> = values.iter().map(|s| Some(s.as_str())).collect();

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 3328).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_dictionary_repeated_non_uniform() {
        let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
        let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
        let array = DictionaryArray::new(keys, Arc::new(values));

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 5280).await;
    }

    #[tokio::test]
    async fn flight_data_size_multiple_dictionaries() {
        let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
        let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
        let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();

        let array1: DictionaryArray<Int32Type> = values1.iter().map(|s| Some(s.as_str())).collect();
        let array2: DictionaryArray<Int32Type> = values2.iter().map(|s| Some(s.as_str())).collect();
        let array3: DictionaryArray<Int32Type> = values3.iter().map(|s| Some(s.as_str())).collect();

        let batch = RecordBatch::try_from_iter(vec![
            ("a1", Arc::new(array1) as _),
            ("a2", Arc::new(array2) as _),
            ("a3", Arc::new(array3) as _),
        ])
        .unwrap();

        verify_encoded_split(batch, 4128).await;
    }

    /// Estimates the size in bytes a `FlightData` message occupies on the
    /// wire: descriptor, application metadata, data header, and data body.
    #[allow(clippy::needless_as_bytes)]
    fn flight_data_size(d: &FlightData) -> usize {
        let flight_descriptor_size = d
            .flight_descriptor
            .as_ref()
            .map(|descriptor| {
                let path_len: usize = descriptor.path.iter().map(|p| p.as_bytes().len()).sum();

                std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
            })
            .unwrap_or(0);

        flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len()
    }

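    /// Encodes `batch` under several `max_flight_data_size` settings and
    /// checks that no encoded message exceeds the limit by more than
    /// `allowed_overage` bytes. Some overage is expected because the
    /// splitting heuristic works on in-memory rather than encoded sizes; the
    /// final assertion pins `allowed_overage` to the largest overage actually
    /// observed so the declared value stays tight.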
    async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
        let num_rows = batch.num_rows();

        let mut max_overage_seen = 0;

        for max_flight_data_size in [1024, 2021, 5000] {
            println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");

            let mut stream = FlightDataEncoderBuilder::new()
                .with_max_flight_data_size(max_flight_data_size)
                .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
                .build(futures::stream::iter([Ok(batch.clone())]));

            let mut i = 0;
            while let Some(data) = stream.next().await.transpose().unwrap() {
                let actual_data_size = flight_data_size(&data);

                let actual_overage = actual_data_size.saturating_sub(max_flight_data_size);

                assert!(
                    actual_overage <= allowed_overage,
                    "encoded data[{i}]: actual size {actual_data_size}, \
                     actual_overage: {actual_overage} \
                     allowed_overage: {allowed_overage}"
                );

                i += 1;

                max_overage_seen = max_overage_seen.max(actual_overage)
            }
        }

        assert_eq!(
            allowed_overage, max_overage_seen,
            "Specified overage was too high"
        );
    }
}