1use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
19
20use crate::{FlightData, FlightDescriptor, SchemaAsIpc, error::Result};
21
22use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
23use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
24
25use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
26use bytes::Bytes;
27use futures::{Stream, StreamExt, ready, stream::BoxStream};
28
/// Builder that configures and creates a [`FlightDataEncoder`], which turns
/// a stream of [`RecordBatch`]es into a stream of [`FlightData`] messages.
#[derive(Debug)]
pub struct FlightDataEncoderBuilder {
    /// Approximate upper bound, in bytes, for each encoded message; used to
    /// split large batches before encoding.
    max_flight_data_size: usize,
    /// Arrow IPC write options forwarded to the underlying encoder.
    options: IpcWriteOptions,
    /// Application metadata; attached to the first (schema) message only.
    app_metadata: Bytes,
    /// Schema known up front, if any; otherwise it is taken from the first
    /// batch seen by the encoder.
    schema: Option<SchemaRef>,
    /// Optional descriptor attached to the first message of the stream.
    descriptor: Option<FlightDescriptor>,
    /// How dictionary-encoded columns are transmitted (hydrate vs. resend).
    dictionary_handling: DictionaryHandling,
}
161
/// Default target size (2 MiB) for encoded `FlightData` messages.
/// NOTE(review): presumably chosen to stay comfortably below common gRPC
/// message-size limits — confirm against the transport's configured maximum.
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152;
167
168impl Default for FlightDataEncoderBuilder {
169 fn default() -> Self {
170 Self {
171 max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
172 options: IpcWriteOptions::default(),
173 app_metadata: Bytes::new(),
174 schema: None,
175 descriptor: None,
176 dictionary_handling: DictionaryHandling::Hydrate,
177 }
178 }
179}
180
181impl FlightDataEncoderBuilder {
182 pub fn new() -> Self {
184 Self::default()
185 }
186
187 pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
198 self.max_flight_data_size = max_flight_data_size;
199 self
200 }
201
202 pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
204 self.dictionary_handling = dictionary_handling;
205 self
206 }
207
208 pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
212 self.app_metadata = app_metadata;
213 self
214 }
215
216 pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
218 self.options = options;
219 self
220 }
221
222 pub fn with_schema(mut self, schema: SchemaRef) -> Self {
227 self.schema = Some(schema);
228 self
229 }
230
231 pub fn with_flight_descriptor(mut self, descriptor: Option<FlightDescriptor>) -> Self {
233 self.descriptor = descriptor;
234 self
235 }
236
237 pub fn build<S>(self, input: S) -> FlightDataEncoder
242 where
243 S: Stream<Item = Result<RecordBatch>> + Send + 'static,
244 {
245 let Self {
246 max_flight_data_size,
247 options,
248 app_metadata,
249 schema,
250 descriptor,
251 dictionary_handling,
252 } = self;
253
254 FlightDataEncoder::new(
255 input.boxed(),
256 schema,
257 max_flight_data_size,
258 options,
259 app_metadata,
260 descriptor,
261 dictionary_handling,
262 )
263 }
264}
265
/// Stream adapter that encodes [`RecordBatch`]es from an inner stream into
/// [`FlightData`] messages: schema first, then dictionary/batch messages.
pub struct FlightDataEncoder {
    /// Input stream of batches (or errors) to encode.
    inner: BoxStream<'static, Result<RecordBatch>>,
    /// Transport schema once known — either supplied up front or derived
    /// from the first batch.
    schema: Option<SchemaRef>,
    /// Approximate upper bound, in bytes, for each encoded message.
    max_flight_data_size: usize,
    /// Underlying IPC encoder holding dictionary and compression state.
    encoder: FlightIpcEncoder,
    /// Application metadata; attached to the schema message, then cleared.
    app_metadata: Option<Bytes>,
    /// Encoded messages waiting to be yielded by `poll_next`.
    queue: VecDeque<FlightData>,
    /// True once the input stream is exhausted or has errored.
    done: bool,
    /// Descriptor attached to the first queued message, then cleared.
    descriptor: Option<FlightDescriptor>,
    /// Strategy for transmitting dictionary-encoded columns.
    dictionary_handling: DictionaryHandling,
}
291
impl FlightDataEncoder {
    /// Builds a new encoder over `inner`. If `schema` is supplied, the schema
    /// message is encoded and queued immediately; otherwise it is derived
    /// lazily from the first batch.
    fn new(
        inner: BoxStream<'static, Result<RecordBatch>>,
        schema: Option<SchemaRef>,
        max_flight_data_size: usize,
        options: IpcWriteOptions,
        app_metadata: Bytes,
        descriptor: Option<FlightDescriptor>,
        dictionary_handling: DictionaryHandling,
    ) -> Self {
        let mut encoder = Self {
            inner,
            schema: None,
            max_flight_data_size,
            // Dictionary replacement is an error unless dictionaries are
            // explicitly re-sent (`Resend`).
            encoder: FlightIpcEncoder::new(
                options,
                dictionary_handling != DictionaryHandling::Resend,
            ),
            app_metadata: Some(app_metadata),
            queue: VecDeque::new(),
            done: false,
            descriptor,
            dictionary_handling,
        };

        // Encode (and queue) the schema up front when it is already known.
        if let Some(schema) = schema {
            encoder.encode_schema(&schema);
        }

        encoder
    }

    /// Returns the transport schema this stream emits, or `None` if it has
    /// not been provided/derived yet.
    pub fn known_schema(&self) -> Option<SchemaRef> {
        self.schema.clone()
    }

    /// Queues one message; the flight descriptor (if any) is attached to the
    /// first message only — `take()` clears it afterwards.
    fn queue_message(&mut self, mut data: FlightData) {
        if let Some(descriptor) = self.descriptor.take() {
            data.flight_descriptor = Some(descriptor);
        }
        self.queue.push_back(data);
    }

    /// Queues a sequence of messages in order.
    fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
        for data in datas {
            self.queue_message(data)
        }
    }

    /// Rewrites `schema` for transport (dictionary fields become their value
    /// type when hydrating), encodes and queues the schema message, records
    /// it in `self.schema`, and returns it. App metadata, if present, rides
    /// on this single message.
    fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
        let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
        let schema = Arc::new(prepare_schema_for_flight(
            schema,
            &mut self.encoder.dictionary_tracker,
            send_dictionaries,
        ));
        let mut schema_flight_data = self.encoder.encode_schema(&schema);

        // `take()` ensures the metadata is attached exactly once.
        if let Some(app_metadata) = self.app_metadata.take() {
            schema_flight_data.app_metadata = app_metadata;
        }
        self.queue_message(schema_flight_data);
        self.schema = Some(schema.clone());
        schema
    }

    /// Encodes one input batch into queued messages, splitting it so the
    /// encoded messages approximately respect `max_flight_data_size`.
    fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
        let schema = match &self.schema {
            Some(schema) => schema.clone(),
            // No schema yet: derive it from this first batch and queue it.
            None => self.encode_schema(batch.schema_ref()),
        };

        // Hydrate dictionary columns into plain values unless resending.
        let batch = match self.dictionary_handling {
            DictionaryHandling::Resend => batch,
            DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
        };

        for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
            let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;

            // Dictionary messages must precede the batch that uses them.
            self.queue_messages(flight_dictionaries);
            self.queue_message(flight_batch);
        }

        Ok(())
    }
}
392
impl Stream for FlightDataEncoder {
    type Item = Result<FlightData>;

    /// Drains any already-encoded messages first; only when the queue is
    /// empty does it poll the inner stream for another batch to encode.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            // Finished and fully drained: end of stream.
            if self.done && self.queue.is_empty() {
                return Poll::Ready(None);
            }

            // Emit pending output before pulling more input.
            if let Some(data) = self.queue.pop_front() {
                return Poll::Ready(Some(Ok(data)));
            }

            let batch = ready!(self.inner.poll_next_unpin(cx));

            match batch {
                None => {
                    // Input exhausted. The queue was drained just above, so
                    // it must be empty here.
                    self.done = true;
                    assert!(self.queue.is_empty());
                    return Poll::Ready(None);
                }
                Some(Err(e)) => {
                    // Propagate the error and terminate, discarding any
                    // pending (now meaningless) output.
                    self.done = true;
                    self.queue.clear();
                    return Poll::Ready(Some(Err(e)));
                }
                Some(Ok(batch)) => {
                    // Encode into the queue; the next loop iteration emits.
                    if let Err(e) = self.encode_batch(batch) {
                        self.done = true;
                        self.queue.clear();
                        return Poll::Ready(Some(Err(e)));
                    }
                }
            }
        }
    }
}
439
/// Strategy for transmitting dictionary-encoded columns over Flight.
#[derive(Debug, PartialEq)]
pub enum DictionaryHandling {
    /// Expand dictionary arrays into their value type before sending, so
    /// receivers never see dictionary batches (the default).
    Hydrate,
    /// Keep dictionary encoding and send dictionaries as separate IPC
    /// messages; the encoder permits dictionary replacement in this mode.
    Resend,
}
489
490fn prepare_field_for_flight(
491 field: &FieldRef,
492 dictionary_tracker: &mut DictionaryTracker,
493 send_dictionaries: bool,
494) -> Field {
495 match field.data_type() {
496 DataType::List(inner) => Field::new_list(
497 field.name(),
498 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
499 field.is_nullable(),
500 )
501 .with_metadata(field.metadata().clone()),
502 DataType::LargeList(inner) => Field::new_list(
503 field.name(),
504 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
505 field.is_nullable(),
506 )
507 .with_metadata(field.metadata().clone()),
508 DataType::Struct(fields) => {
509 let new_fields: Vec<Field> = fields
510 .iter()
511 .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
512 .collect();
513 Field::new_struct(field.name(), new_fields, field.is_nullable())
514 .with_metadata(field.metadata().clone())
515 }
516 DataType::Union(fields, mode) => {
517 let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
518 .iter()
519 .map(|(type_id, f)| {
520 (
521 type_id,
522 prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
523 )
524 })
525 .unzip();
526
527 Field::new_union(field.name(), type_ids, new_fields, *mode)
528 }
529 DataType::Dictionary(_, value_type) => {
530 if !send_dictionaries {
531 let value_field = Field::new(
533 field.name(),
534 value_type.as_ref().clone(),
535 field.is_nullable(),
536 );
537 prepare_field_for_flight(
538 &Arc::new(value_field),
539 dictionary_tracker,
540 send_dictionaries,
541 )
542 .with_metadata(field.metadata().clone())
543 } else {
544 let value_field = Field::new("values", value_type.as_ref().clone(), true);
548 prepare_field_for_flight(
549 &Arc::new(value_field),
550 dictionary_tracker,
551 send_dictionaries,
552 );
553 dictionary_tracker.next_dict_id();
554 #[allow(deprecated)]
555 Field::new_dict(
556 field.name(),
557 field.data_type().clone(),
558 field.is_nullable(),
559 0,
560 field.dict_is_ordered().unwrap_or_default(),
561 )
562 .with_metadata(field.metadata().clone())
563 }
564 }
565 DataType::ListView(inner) | DataType::LargeListView(inner) => {
566 let prepared = prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries);
567 Field::new(
568 field.name(),
569 match field.data_type() {
570 DataType::ListView(_) => DataType::ListView(Arc::new(prepared)),
571 _ => DataType::LargeListView(Arc::new(prepared)),
572 },
573 field.is_nullable(),
574 )
575 .with_metadata(field.metadata().clone())
576 }
577 DataType::FixedSizeList(inner, size) => Field::new(
578 field.name(),
579 DataType::FixedSizeList(
580 Arc::new(prepare_field_for_flight(
581 inner,
582 dictionary_tracker,
583 send_dictionaries,
584 )),
585 *size,
586 ),
587 field.is_nullable(),
588 )
589 .with_metadata(field.metadata().clone()),
590 DataType::RunEndEncoded(run_ends, values) => Field::new(
591 field.name(),
592 DataType::RunEndEncoded(
593 run_ends.clone(),
594 Arc::new(prepare_field_for_flight(
595 values,
596 dictionary_tracker,
597 send_dictionaries,
598 )),
599 ),
600 field.is_nullable(),
601 )
602 .with_metadata(field.metadata().clone()),
603 DataType::Map(inner, sorted) => Field::new(
604 field.name(),
605 DataType::Map(
606 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
607 *sorted,
608 ),
609 field.is_nullable(),
610 )
611 .with_metadata(field.metadata().clone()),
612 DataType::Null
613 | DataType::Boolean
614 | DataType::Int8
615 | DataType::Int16
616 | DataType::Int32
617 | DataType::Int64
618 | DataType::UInt8
619 | DataType::UInt16
620 | DataType::UInt32
621 | DataType::UInt64
622 | DataType::Float16
623 | DataType::Float32
624 | DataType::Float64
625 | DataType::Timestamp(_, _)
626 | DataType::Date32
627 | DataType::Date64
628 | DataType::Time32(_)
629 | DataType::Time64(_)
630 | DataType::Duration(_)
631 | DataType::Interval(_)
632 | DataType::Binary
633 | DataType::FixedSizeBinary(_)
634 | DataType::LargeBinary
635 | DataType::BinaryView
636 | DataType::Utf8
637 | DataType::LargeUtf8
638 | DataType::Utf8View
639 | DataType::Decimal32(_, _)
640 | DataType::Decimal64(_, _)
641 | DataType::Decimal128(_, _)
642 | DataType::Decimal256(_, _) => field.as_ref().clone(),
643 }
644}
645
646fn prepare_schema_for_flight(
652 schema: &Schema,
653 dictionary_tracker: &mut DictionaryTracker,
654 send_dictionaries: bool,
655) -> Schema {
656 let fields: Fields = schema
657 .fields()
658 .iter()
659 .map(|field| prepare_field_for_flight(field, dictionary_tracker, send_dictionaries))
660 .collect();
661
662 Schema::new(fields).with_metadata(schema.metadata().clone())
663}
664
665fn split_batch_for_grpc_response(
672 batch: RecordBatch,
673 max_flight_data_size: usize,
674) -> Vec<RecordBatch> {
675 let size = batch
676 .columns()
677 .iter()
678 .map(|col| col.get_buffer_memory_size())
679 .sum::<usize>();
680
681 let n_batches =
682 (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
683 let rows_per_batch = (batch.num_rows() / n_batches).max(1);
684 let mut out = Vec::with_capacity(n_batches + 1);
685
686 let mut offset = 0;
687 while offset < batch.num_rows() {
688 let length = (rows_per_batch).min(batch.num_rows() - offset);
689 out.push(batch.slice(offset, length));
690
691 offset += length;
692 }
693
694 out
695}
696
/// Thin wrapper around the Arrow IPC machinery that encodes schemas and
/// batches into `FlightData`, carrying dictionary and compression state
/// across batches.
struct FlightIpcEncoder {
    /// IPC write options used for every encoded message.
    options: IpcWriteOptions,
    /// Generates the IPC-encoded bytes for schemas and batches.
    data_gen: IpcDataGenerator,
    /// Tracks emitted dictionaries; configured to error (or not) on
    /// dictionary replacement.
    dictionary_tracker: DictionaryTracker,
    /// Reusable compression state shared across encoded batches.
    compression_context: CompressionContext,
}
709
710impl FlightIpcEncoder {
711 fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
712 Self {
713 options,
714 data_gen: IpcDataGenerator::default(),
715 dictionary_tracker: DictionaryTracker::new(error_on_replacement),
716 compression_context: CompressionContext::default(),
717 }
718 }
719
720 fn encode_schema(&self, schema: &Schema) -> FlightData {
722 SchemaAsIpc::new(schema, &self.options).into()
723 }
724
725 fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
728 let (encoded_dictionaries, encoded_batch) = self.data_gen.encode(
729 batch,
730 &mut self.dictionary_tracker,
731 &self.options,
732 &mut self.compression_context,
733 )?;
734
735 let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
736 let flight_batch = encoded_batch.into();
737
738 Ok((flight_dictionaries, flight_batch))
739 }
740}
741
742fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
745 let columns = schema
746 .fields()
747 .iter()
748 .zip(batch.columns())
749 .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
750 .collect::<Result<Vec<_>>>()?;
751
752 let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
753
754 Ok(RecordBatch::try_new_with_options(
755 schema, columns, &options,
756 )?)
757}
758
759fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
761 let arr = match (array.data_type(), data_type) {
762 (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
763 let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();
764
765 Arc::new(UnionArray::try_new(
766 fields.clone(),
767 union_arr.type_ids().clone(),
768 None,
769 fields
770 .iter()
771 .map(|(type_id, field)| {
772 Ok(arrow_cast::cast(
773 union_arr.child(type_id),
774 field.data_type(),
775 )?)
776 })
777 .collect::<Result<Vec<_>>>()?,
778 )?)
779 }
780 (_, data_type) => arrow_cast::cast(array, data_type)?,
781 };
782 Ok(arr)
783}
784
785#[cfg(test)]
786mod tests {
787 use crate::decode::{DecodedPayload, FlightDataDecoder};
788 use arrow_array::builder::{
789 FixedSizeListBuilder, GenericByteDictionaryBuilder, GenericListViewBuilder, ListBuilder,
790 StringDictionaryBuilder, StructBuilder,
791 };
792 use arrow_array::*;
793 use arrow_array::{cast::downcast_array, types::*};
794 use arrow_buffer::ScalarBuffer;
795 use arrow_cast::pretty::pretty_format_batches;
796 use arrow_ipc::MetadataVersion;
797 use arrow_schema::{UnionFields, UnionMode};
798 use builder::{GenericStringBuilder, MapBuilder};
799 use std::collections::HashMap;
800
801 use super::*;
802
    #[test]
    fn test_encode_flight_data() {
        // Hydrating a batch with no dictionary columns must not change its
        // encoded size relative to a directly-encoded batch of the same
        // width, while a 1-row slice must encode strictly smaller.
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
        let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);

        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
            .expect("cannot create record batch");
        let schema = batch.schema_ref();

        let (_, baseline_flight_batch) = make_flight_data(&batch, &options);

        // NOTE(review): the 5-row slice encodes to the same body length as
        // the 6-row baseline — presumably due to 8-byte IPC alignment
        // padding; confirm if the alignment option changes.
        let big_batch = batch.slice(0, batch.num_rows() - 1);
        let optimized_big_batch =
            hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
        let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);

        assert_eq!(
            baseline_flight_batch.data_body.len(),
            optimized_big_flight_batch.data_body.len()
        );

        let small_batch = batch.slice(0, 1);
        let optimized_small_batch =
            hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
        let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);

        assert!(
            baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len()
        );
    }
836
    #[tokio::test]
    async fn test_dictionary_hydration() {
        // Default (Hydrate) handling: dictionary columns must arrive as
        // plain Utf8 arrays and the schema must lose its dictionary type.
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);
        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
        let expected_schema = Arc::new(expected_schema);
        let mut expected_arrays = vec![
            StringArray::from(vec!["a", "a", "b"]),
            StringArray::from(vec!["c", "c", "d"]),
        ]
        .into_iter();
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let actual_array = b.column_by_name("dict").unwrap();
                    let actual_array = downcast_array::<StringArray>(actual_array);

                    assert_eq!(actual_array, expected_array);
                }
            }
        }
    }
878
    #[tokio::test]
    async fn test_dictionary_resend() {
        // Round-trips dictionary batches via `verify_flight_round_trip`
        // (helper not visible in this chunk; presumably uses Resend).
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
895
    #[tokio::test]
    async fn test_dictionary_hydration_known_schema() {
        // With an up-front schema and hydration enabled, known_schema()
        // must immediately report the hydrated (Utf8) transport schema.
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default()
            .with_schema(schema)
            .build(stream);
        let expected_schema =
            Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
        assert_eq!(Some(expected_schema), encoder.known_schema())
    }
919
    #[tokio::test]
    async fn test_dictionary_resend_known_schema() {
        // With Resend handling, the transport schema keeps its dictionary
        // type, so known_schema() must equal the input schema.
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default()
            .with_dictionary_handling(DictionaryHandling::Resend)
            .with_schema(schema.clone())
            .build(stream);
        assert_eq!(Some(schema), encoder.known_schema())
    }
942
    #[tokio::test]
    async fn test_multiple_dictionaries_resend() {
        // Two dictionary columns with different values must round-trip
        // (helper `verify_flight_round_trip` not visible in this chunk).
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
        ]));

        let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["a", "a", "b"].into_iter().collect());
        let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["c", "c", "d"].into_iter().collect());
        let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["b", "a", "c"].into_iter().collect());
        let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["k", "d", "e"].into_iter().collect());
        let batch1 =
            RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
                .unwrap();
        let batch2 =
            RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
                .unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
968
    #[tokio::test]
    async fn test_dictionary_list_hydration() {
        // Dictionary values nested inside a List must also be hydrated:
        // the decoded schema/arrays use plain Utf8 list items.
        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_list_field(DataType::Utf8, true),
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let list_array =
                        downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }
1026
    #[tokio::test]
    async fn test_dictionary_list_resend() {
        // Dictionary-in-List batches must round-trip unchanged (helper
        // `verify_flight_round_trip` not visible in this chunk).
        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
1050
    #[tokio::test]
    async fn test_dictionary_struct_hydration() {
        // Dictionary values nested two levels deep (Struct -> List ->
        // Dictionary) must be hydrated to plain Utf8 on the wire.
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);

        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_struct(
            "struct",
            vec![Field::new_list(
                "dict_list",
                Field::new_list_field(DataType::Utf8, true),
                true,
            )],
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let struct_array =
                        downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
                    let list_array = downcast_array::<ListArray>(struct_array.column(0));

                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }
1134
    #[tokio::test]
    async fn test_dictionary_struct_resend() {
        // Struct -> List -> Dictionary batches must round-trip unchanged
        // (helper `verify_flight_round_trip` not visible in this chunk).
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);
        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
1175
    #[tokio::test]
    async fn test_dictionary_union_hydration() {
        // Sparse union whose variants contain nested dictionaries: all
        // dictionary fields (direct and nested in a struct) must hydrate
        // to Utf8 in the transport schema and decoded arrays.
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        // One batch per union variant: arr1 uses variant 0 (list), the other
        // children are null placeholders of the right type.
        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1) as Arc<dyn Array>,
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        // arr2 uses variant 1 (struct).
        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                Arc::new(arr2),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        // arr3 uses variant 2 (plain string).
        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);

        let hydrated_struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_list_field(DataType::Utf8, true),
            true,
        )];

        let hydrated_union_fields = vec![
            Field::new_list(
                "dict_list",
                Field::new_list_field(DataType::Utf8, true),
                true,
            ),
            Field::new_struct("struct", hydrated_struct_fields.clone(), true),
            Field::new("string", DataType::Utf8, true),
        ];

        let expected_schema = Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            hydrated_union_fields,
            UnionMode::Sparse,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
            StringArray::from(vec!["e"]),
        ]
        .into_iter();

        // `batch` tracks which union variant the current batch exercises.
        let mut batch = 0;
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let union_arr =
                        downcast_array::<UnionArray>(b.column_by_name("union").unwrap());

                    let elem_array = match batch {
                        0 => {
                            let list_array = downcast_array::<ListArray>(union_arr.child(0));
                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        1 => {
                            let struct_array = downcast_array::<StructArray>(union_arr.child(1));
                            let list_array = downcast_array::<ListArray>(struct_array.column(0));

                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        _ => downcast_array::<StringArray>(union_arr.child(2)),
                    };

                    batch += 1;

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }
1344
    /// Round-trips a sparse union whose children contain dictionary-encoded
    /// data (a list of dictionaries, a struct wrapping such a list, and a
    /// plain string) using `DictionaryHandling::Resend`, so dictionaries are
    /// re-sent per batch instead of being hydrated into their value types.
    #[tokio::test]
    async fn test_dictionary_union_resend() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        // Union children: dict-list (type id 0), struct-of-dict-list (1),
        // plain string (2).
        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        // Grab each child's data type so all-null placeholder children can be
        // built for the batches below (sparse unions need all children present).
        let mut field_types = union_fields.iter().map(|(_, field)| field.data_type());
        let dict_list_ty = field_types.next().unwrap();
        let struct_ty = field_types.next().unwrap();
        let string_ty = field_types.next().unwrap();

        // Shadowed: re-declared because the earlier vec was moved into the union.
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        // Batch 1 selects the dict-list child; other children are all-null.
        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1),
                new_null_array(struct_ty, 1),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        // Batch 2 selects the struct child.
        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                Arc::new(arr2),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        // Batch 3 selects the plain-string child.
        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                new_null_array(struct_ty, 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        // Rebuild (type_ids, fields) pairs to declare the union in the schema.
        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
    }
1449
    /// Encodes map arrays whose keys and values are dictionary-encoded using
    /// the default `DictionaryHandling::Hydrate`, then asserts the decoded
    /// schema and batches come back with plain (non-dictionary) Utf8
    /// keys/values equal to the hand-built hydrated expectation.
    #[tokio::test]
    async fn test_dictionary_map_hydration() {
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        // First map: {k1: a, k2: NULL, k3: b}.
        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        // Second map: {k1: c, k2: NULL, k3: d}.
        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        // Input schema: dictionary-encoded keys and values.
        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        // After hydration the dictionary wrappers disappear: plain Utf8.
        let expected_schema = Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new("keys", DataType::Utf8, false),
            Field::new("values", DataType::Utf8, true),
            false,
            false,
        )]);

        let expected_schema = Arc::new(expected_schema);

        // Rebuild the same logical maps with plain string builders to get the
        // expected hydrated arrays.
        let mut builder = MapBuilder::new(
            None,
            GenericStringBuilder::<i32>::new(),
            GenericStringBuilder::<i32>::new(),
        );

        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let mut expected_arrays = vec![arr1, arr2].into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let map_array =
                        downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());

                    assert_eq!(map_array, expected_array);
                }
            }
        }
    }
1555
    /// Round-trips map arrays with dictionary-encoded keys and values using
    /// `DictionaryHandling::Resend` (via `verify_flight_round_trip`), so the
    /// decoded batches must equal the originals including their dictionaries.
    #[tokio::test]
    async fn test_dictionary_map_resend() {
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        // First map: {k1: a, k2: NULL, k3: b}.
        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        // Second map introduces new values, forcing a dictionary resend.
        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
1600
1601 #[tokio::test]
1602 async fn test_dictionary_ree_resend() {
1603 let dict_values1 = vec![Some("a"), None, Some("b")]
1604 .into_iter()
1605 .collect::<DictionaryArray<Int32Type>>();
1606 let run_ends1 = Int32Array::from(vec![1, 2, 3]);
1607 let arr1 = RunArray::try_new(&run_ends1, &dict_values1).unwrap();
1608
1609 let dict_values2 = vec![Some("c"), Some("a")]
1610 .into_iter()
1611 .collect::<DictionaryArray<Int32Type>>();
1612 let run_ends2 = Int32Array::from(vec![1, 2]);
1613 let arr2 = RunArray::try_new(&run_ends2, &dict_values2).unwrap();
1614
1615 let schema = Arc::new(Schema::new(vec![Field::new(
1616 "ree",
1617 arr1.data_type().clone(),
1618 true,
1619 )]));
1620
1621 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1622 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1623
1624 verify_flight_round_trip(vec![batch1, batch2]).await;
1625 }
1626
    /// Round-trips a dictionary whose values are structs that themselves
    /// contain a dictionary-encoded column, with `DictionaryHandling::Resend`.
    #[tokio::test]
    async fn test_dictionary_of_struct_of_dict_resend() {
        let struct_fields: Vec<Field> = vec![
            Field::new_dictionary("dict", DataType::Int32, DataType::Utf8, true),
            Field::new("int", DataType::Int32, false),
        ];

        // Inner dictionary column of the struct values.
        let inner_values =
            StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let inner_keys = Int32Array::from_iter_values([0, 1, 2, 3, 0]);
        let inner_dict = DictionaryArray::new(inner_keys, Arc::new(inner_values));
        let int_array = Int32Array::from(vec![10, 20, 30, 40, 50]);

        let struct_array = StructArray::from(vec![
            (
                Arc::new(struct_fields[0].clone()),
                Arc::new(inner_dict) as ArrayRef,
            ),
            (
                Arc::new(struct_fields[1].clone()),
                Arc::new(int_array) as ArrayRef,
            ),
        ]);

        // Outer dictionary: keys select rows of the struct values.
        let outer_keys = Int8Array::from_iter_values([0, 0, 1, 2]);
        let arr1 = DictionaryArray::new(outer_keys, Arc::new(struct_array));

        // Second batch carries entirely new dictionary contents, forcing both
        // the inner and outer dictionaries to be resent.
        let inner_values2 = StringArray::from(vec![Some("x"), Some("y")]);
        let inner_keys2 = Int32Array::from_iter_values([0, 1, 0]);
        let inner_dict2 = DictionaryArray::new(inner_keys2, Arc::new(inner_values2));
        let int_array2 = Int32Array::from(vec![100, 200, 300]);

        let struct_array2 = StructArray::from(vec![
            (
                Arc::new(struct_fields[0].clone()),
                Arc::new(inner_dict2) as ArrayRef,
            ),
            (
                Arc::new(struct_fields[1].clone()),
                Arc::new(int_array2) as ArrayRef,
            ),
        ]);

        let outer_keys2 = Int8Array::from_iter_values([0, 1]);
        let arr2 = DictionaryArray::new(outer_keys2, Arc::new(struct_array2));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_struct",
            arr1.data_type().clone(),
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
1687
1688 async fn verify_dictionary_list_view_resend<O: OffsetSizeTrait>() {
1689 let mut builder =
1690 GenericListViewBuilder::<O, _>::new(StringDictionaryBuilder::<UInt16Type>::new());
1691
1692 builder.append_value(vec![Some("a"), None, Some("b")]);
1693 let arr1 = builder.finish();
1694
1695 builder.append_value(vec![Some("c"), None, Some("d")]);
1696 let arr2 = builder.finish();
1697
1698 let inner = Arc::new(Field::new_dictionary(
1699 "item",
1700 DataType::UInt16,
1701 DataType::Utf8,
1702 true,
1703 ));
1704 let dt = if O::IS_LARGE {
1705 DataType::LargeListView(inner)
1706 } else {
1707 DataType::ListView(inner)
1708 };
1709 let schema = Arc::new(Schema::new(vec![Field::new("dict_list_view", dt, true)]));
1710
1711 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1712 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1713
1714 verify_flight_round_trip(vec![batch1, batch2]).await;
1715 }
1716
    /// `ListView` (32-bit offsets) variant of the list-view resend test.
    #[tokio::test]
    async fn test_dictionary_list_view_resend() {
        verify_dictionary_list_view_resend::<i32>().await;
    }
1721
    /// `LargeListView` (64-bit offsets) variant of the list-view resend test.
    #[tokio::test]
    async fn test_dictionary_large_list_view_resend() {
        verify_dictionary_list_view_resend::<i64>().await;
    }
1726
1727 #[tokio::test]
1728 async fn test_dictionary_fixed_size_list_resend() {
1729 let mut builder =
1730 FixedSizeListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new(), 2);
1731
1732 builder.values().append_value("a");
1733 builder.values().append_value("b");
1734 builder.append(true);
1735 let arr1 = builder.finish();
1736
1737 builder.values().append_value("c");
1738 builder.values().append_value("d");
1739 builder.append(true);
1740 let arr2 = builder.finish();
1741
1742 let schema = Arc::new(Schema::new(vec![Field::new_fixed_size_list(
1743 "dict_fsl",
1744 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1745 2,
1746 true,
1747 )]));
1748
1749 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1750 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1751
1752 verify_flight_round_trip(vec![batch1, batch2]).await;
1753 }
1754
1755 async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
1756 let expected_schema = batches.first().unwrap().schema();
1757
1758 let encoder = FlightDataEncoderBuilder::default()
1759 .with_options(IpcWriteOptions::default())
1760 .with_dictionary_handling(DictionaryHandling::Resend)
1761 .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));
1762
1763 let mut expected_batches = batches.drain(..);
1764
1765 let mut decoder = FlightDataDecoder::new(encoder);
1766 while let Some(decoded) = decoder.next().await {
1767 let decoded = decoded.unwrap();
1768 match decoded.payload {
1769 DecodedPayload::None => {}
1770 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1771 DecodedPayload::RecordBatch(b) => {
1772 let expected_batch = expected_batches.next().unwrap();
1773 assert_eq!(b, expected_batch);
1774 }
1775 }
1776 }
1777 }
1778
1779 #[test]
1780 fn test_schema_metadata_encoded() {
1781 let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata(
1782 HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
1783 );
1784
1785 let mut dictionary_tracker = DictionaryTracker::new(false);
1786
1787 let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
1788 assert!(got.metadata().contains_key("some_key"));
1789 }
1790
1791 #[test]
1792 fn test_encode_no_column_batch() {
1793 let batch = RecordBatch::try_new_with_options(
1794 Arc::new(Schema::empty()),
1795 vec![],
1796 &RecordBatchOptions::new().with_row_count(Some(10)),
1797 )
1798 .expect("cannot create record batch");
1799
1800 hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize");
1801 }
1802
    /// Test shim kept under its historical name; delegates directly to
    /// [`flight_data_from_arrow_batch`].
    fn make_flight_data(
        batch: &RecordBatch,
        options: &IpcWriteOptions,
    ) -> (Vec<FlightData>, FlightData) {
        flight_data_from_arrow_batch(batch, options)
    }
1809
1810 fn flight_data_from_arrow_batch(
1811 batch: &RecordBatch,
1812 options: &IpcWriteOptions,
1813 ) -> (Vec<FlightData>, FlightData) {
1814 let data_gen = IpcDataGenerator::default();
1815 let mut dictionary_tracker = DictionaryTracker::new(false);
1816 let mut compression_context = CompressionContext::default();
1817
1818 let (encoded_dictionaries, encoded_batch) = data_gen
1819 .encode(
1820 batch,
1821 &mut dictionary_tracker,
1822 options,
1823 &mut compression_context,
1824 )
1825 .expect("DictionaryTracker configured above to not error on replacement");
1826
1827 let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
1828 let flight_batch = encoded_batch.into();
1829
1830 (flight_dictionaries, flight_batch)
1831 }
1832
1833 #[test]
1834 fn test_split_batch_for_grpc_response() {
1835 let max_flight_data_size = 1024;
1836
1837 let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
1839 let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
1840 .expect("cannot create record batch");
1841 let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
1842 assert_eq!(split.len(), 1);
1843 assert_eq!(batch, split[0]);
1844
1845 let n_rows = max_flight_data_size + 1;
1847 assert!(n_rows % 2 == 1, "should be an odd number");
1848 let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
1849 let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
1850 .expect("cannot create record batch");
1851 let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
1852 assert_eq!(split.len(), 3);
1853 assert_eq!(
1854 split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
1855 n_rows
1856 );
1857 let a = pretty_format_batches(&split).unwrap().to_string();
1858 let b = pretty_format_batches(&[batch]).unwrap().to_string();
1859 assert_eq!(a, b);
1860 }
1861
1862 #[test]
1863 fn test_split_batch_for_grpc_response_sizes() {
1864 verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]);
1866
1867 verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]);
1869
1870 verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]);
1872
1873 verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
1875
1876 verify_split(10, 1024, vec![10]);
1878 }
1879
1880 fn verify_split(
1884 num_input_rows: u64,
1885 max_flight_data_size_bytes: usize,
1886 expected_sizes: Vec<usize>,
1887 ) {
1888 let array: UInt64Array = (0..num_input_rows).collect();
1889
1890 let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
1891 .expect("cannot create record batch");
1892
1893 let input_rows = batch.num_rows();
1894
1895 let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
1896 let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
1897 let output_rows: usize = sizes.iter().sum();
1898
1899 assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
1900 assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
1901 }
1902
1903 #[tokio::test]
1907 async fn flight_data_size_even() {
1908 let s1 = StringArray::from_iter_values(std::iter::repeat_n(".10 bytes.", 1024));
1909 let i1 = Int16Array::from_iter_values(0..1024);
1910 let s2 = StringArray::from_iter_values(std::iter::repeat_n("6bytes", 1024));
1911 let i2 = Int64Array::from_iter_values(0..1024);
1912
1913 let batch = RecordBatch::try_from_iter(vec![
1914 ("s1", Arc::new(s1) as _),
1915 ("i1", Arc::new(i1) as _),
1916 ("s2", Arc::new(s2) as _),
1917 ("i2", Arc::new(i2) as _),
1918 ])
1919 .unwrap();
1920
1921 verify_encoded_split(batch, 120).await;
1922 }
1923
1924 #[tokio::test]
1925 async fn flight_data_size_uneven_variable_lengths() {
1926 let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
1928 let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();
1929
1930 verify_encoded_split(batch, 4312).await;
1933 }
1934
1935 #[tokio::test]
1936 async fn flight_data_size_large_row() {
1937 let array1 = StringArray::from_iter_values(vec![
1939 "*".repeat(500),
1940 "*".repeat(500),
1941 "*".repeat(500),
1942 "*".repeat(500),
1943 ]);
1944 let array2 = StringArray::from_iter_values(vec![
1945 "*".to_string(),
1946 "*".repeat(1000),
1947 "*".repeat(2000),
1948 "*".repeat(4000),
1949 ]);
1950
1951 let array3 = StringArray::from_iter_values(vec![
1952 "*".to_string(),
1953 "*".to_string(),
1954 "*".repeat(1000),
1955 "*".repeat(2000),
1956 ]);
1957
1958 let batch = RecordBatch::try_from_iter(vec![
1959 ("a1", Arc::new(array1) as _),
1960 ("a2", Arc::new(array2) as _),
1961 ("a3", Arc::new(array3) as _),
1962 ])
1963 .unwrap();
1964
1965 verify_encoded_split(batch, 5808).await;
1969 }
1970
1971 #[tokio::test]
1972 async fn flight_data_size_string_dictionary() {
1973 let array: DictionaryArray<Int32Type> = (1..1024)
1975 .map(|i| match i % 3 {
1976 0 => Some("value0"),
1977 1 => Some("value1"),
1978 _ => None,
1979 })
1980 .collect();
1981
1982 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1983
1984 verify_encoded_split(batch, 56).await;
1985 }
1986
1987 #[tokio::test]
1988 async fn flight_data_size_large_dictionary() {
1989 let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
1991
1992 let array: DictionaryArray<Int32Type> = values.iter().map(|s| Some(s.as_str())).collect();
1993
1994 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1995
1996 verify_encoded_split(batch, 3336).await;
1999 }
2000
2001 #[tokio::test]
2002 async fn flight_data_size_large_dictionary_repeated_non_uniform() {
2003 let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
2005 let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
2006 let array = DictionaryArray::new(keys, Arc::new(values));
2007
2008 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
2009
2010 verify_encoded_split(batch, 5288).await;
2013 }
2014
2015 #[tokio::test]
2016 async fn flight_data_size_multiple_dictionaries() {
2017 let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
2019 let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
2021 let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();
2023
2024 let array1: DictionaryArray<Int32Type> = values1.iter().map(|s| Some(s.as_str())).collect();
2025 let array2: DictionaryArray<Int32Type> = values2.iter().map(|s| Some(s.as_str())).collect();
2026 let array3: DictionaryArray<Int32Type> = values3.iter().map(|s| Some(s.as_str())).collect();
2027
2028 let batch = RecordBatch::try_from_iter(vec![
2029 ("a1", Arc::new(array1) as _),
2030 ("a2", Arc::new(array2) as _),
2031 ("a3", Arc::new(array3) as _),
2032 ])
2033 .unwrap();
2034
2035 verify_encoded_split(batch, 4136).await;
2038 }
2039
2040 fn flight_data_size(d: &FlightData) -> usize {
2042 let flight_descriptor_size = d
2043 .flight_descriptor
2044 .as_ref()
2045 .map(|descriptor| {
2046 let path_len: usize = descriptor.path.iter().map(|p| p.len()).sum();
2047
2048 std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
2049 })
2050 .unwrap_or(0);
2051
2052 flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len()
2053 }
2054
    /// Encodes `batch` at several maximum message sizes and asserts every
    /// produced `FlightData` exceeds its limit by at most `allowed_overage`
    /// bytes — and that the largest observed overage equals `allowed_overage`
    /// exactly, so the allowance in each caller stays tight.
    async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
        let num_rows = batch.num_rows();

        // Largest per-message overage seen across all tested sizes.
        let mut max_overage_seen = 0;

        for max_flight_data_size in [1024, 2021, 5000] {
            println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");

            let mut stream = FlightDataEncoderBuilder::new()
                .with_max_flight_data_size(max_flight_data_size)
                .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
                .build(futures::stream::iter([Ok(batch.clone())]));

            let mut i = 0;
            while let Some(data) = stream.next().await.transpose().unwrap() {
                let actual_data_size = flight_data_size(&data);

                // How far this message exceeds the limit (0 if within it).
                let actual_overage = actual_data_size.saturating_sub(max_flight_data_size);

                assert!(
                    actual_overage <= allowed_overage,
                    "encoded data[{i}]: actual size {actual_data_size}, \
                    actual_overage: {actual_overage} \
                    allowed_overage: {allowed_overage}"
                );

                i += 1;

                max_overage_seen = max_overage_seen.max(actual_overage)
            }
        }

        // Fails if a caller's allowed_overage was set higher than necessary.
        assert_eq!(
            allowed_overage, max_overage_seen,
            "Specified overage was too high"
        );
    }
2112}