// arrow_flight/src/encode.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
19
20use crate::{FlightData, FlightDescriptor, SchemaAsIpc, error::Result};
21
22use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
23use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
24
25use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
26use bytes::Bytes;
27use futures::{Stream, StreamExt, ready, stream::BoxStream};
28
29/// Creates a [`Stream`] of [`FlightData`]s from a
30/// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>.
31///
32/// This can be used to implement [`FlightService::do_get`] in an
33/// Arrow Flight implementation;
34///
35/// This structure encodes a stream of `Result`s rather than `RecordBatch`es  to
36/// propagate errors from streaming execution, where the generation of the
37/// `RecordBatch`es is incremental, and an error may occur even after
38/// several have already been successfully produced.
39///
40/// # Caveats
41/// 1. When [`DictionaryHandling`] is [`DictionaryHandling::Hydrate`],
42///    [`DictionaryArray`]s are converted to their underlying types prior to
43///    transport.
44///    When [`DictionaryHandling`] is [`DictionaryHandling::Resend`], Dictionary [`FlightData`] is sent with every
45///    [`RecordBatch`] that contains a [`DictionaryArray`](arrow_array::array::DictionaryArray).
46///    See <https://github.com/apache/arrow-rs/issues/3389>.
47///
48/// [`DictionaryArray`]: arrow_array::array::DictionaryArray
49///
50/// # Example
51/// ```no_run
52/// # use std::sync::Arc;
53/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
54/// # async fn f() {
55/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
56/// # let batch = RecordBatch::try_from_iter(vec![
57/// #      ("a", Arc::new(c1) as ArrayRef)
58/// #   ])
59/// #   .expect("cannot create record batch");
60/// use arrow_flight::encode::FlightDataEncoderBuilder;
61///
62/// // Get an input stream of Result<RecordBatch, FlightError>
63/// let input_stream = futures::stream::iter(vec![Ok(batch)]);
64///
65/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
66/// let flight_data_stream = FlightDataEncoderBuilder::new()
67///  .build(input_stream);
68///
69/// // Create a tonic `Response` that can be returned from a Flight server
70/// let response = tonic::Response::new(flight_data_stream);
71/// # }
72/// ```
73///
74/// # Example: Sending `Vec<RecordBatch>`
75///
76/// You can create a [`Stream`] to pass to [`Self::build`] from an existing
77/// `Vec` of `RecordBatch`es like this:
78///
79/// ```
80/// # use std::sync::Arc;
81/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
82/// # async fn f() {
83/// # fn make_batches() -> Vec<RecordBatch> {
84/// #   let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
85/// #   let batch = RecordBatch::try_from_iter(vec![
86/// #      ("a", Arc::new(c1) as ArrayRef)
87/// #   ])
88/// #   .expect("cannot create record batch");
89/// #   vec![batch.clone(), batch.clone()]
90/// # }
91/// use arrow_flight::encode::FlightDataEncoderBuilder;
92///
93/// // Get batches that you want to send via Flight
94/// let batches: Vec<RecordBatch> = make_batches();
95///
96/// // Create an input stream of Result<RecordBatch, FlightError>
97/// let input_stream = futures::stream::iter(
98///   batches.into_iter().map(Ok)
99/// );
100///
101/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
102/// let flight_data_stream = FlightDataEncoderBuilder::new()
103///  .build(input_stream);
104/// # }
105/// ```
106///
107/// # Example: Determining schema of encoded data
108///
109/// Encoding flight data may hydrate dictionaries, see [`DictionaryHandling`] for more information,
110/// which changes the schema of the encoded data compared to the input record batches.
111/// The fully hydrated schema can be accessed using the [`FlightDataEncoder::known_schema`] method
112/// and explicitly informing the builder of the schema using [`FlightDataEncoderBuilder::with_schema`].
113///
114/// ```
115/// # use std::sync::Arc;
116/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
117/// # async fn f() {
118/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
119/// # let batch = RecordBatch::try_from_iter(vec![
120/// #      ("a", Arc::new(c1) as ArrayRef)
121/// #   ])
122/// #   .expect("cannot create record batch");
123/// use arrow_flight::encode::FlightDataEncoderBuilder;
124///
125/// // Get the schema of the input stream
126/// let schema = batch.schema();
127///
128/// // Get an input stream of Result<RecordBatch, FlightError>
129/// let input_stream = futures::stream::iter(vec![Ok(batch)]);
130///
131/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
132/// let flight_data_stream = FlightDataEncoderBuilder::new()
133///  // Inform the builder of the input stream schema
134///  .with_schema(schema)
135///  .build(input_stream);
136///
137/// // Retrieve the schema of the encoded data
138/// let encoded_schema = flight_data_stream.known_schema();
139/// # }
140/// ```
141///
142/// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get
143/// [`FlightError`]: crate::error::FlightError
#[derive(Debug)]
pub struct FlightDataEncoderBuilder {
    /// The maximum approximate target message size in bytes
    /// (see details on [`Self::with_max_flight_data_size`]).
    max_flight_data_size: usize,
    /// IPC writer options used to encode schema and batch messages
    options: IpcWriteOptions,
    /// Metadata to add to the schema message
    app_metadata: Bytes,
    /// Optional schema, if known before data.
    schema: Option<SchemaRef>,
    /// Optional flight descriptor, if known before data.
    descriptor: Option<FlightDescriptor>,
    /// Determines how `DictionaryArray`s are encoded for transport.
    /// See [`DictionaryHandling`] for more information.
    dictionary_handling: DictionaryHandling,
}
161
/// Default target size for encoded [`FlightData`]: 2MB (2 * 1024 * 1024 bytes).
///
/// Note this value would normally be 4MB, but the size calculation is
/// somewhat inexact, so we set it to 2MB.
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152;
167
impl Default for FlightDataEncoderBuilder {
    /// Defaults: 2MB target message size, default IPC options, no app
    /// metadata, no pre-declared schema, no flight descriptor, and
    /// dictionary hydration ([`DictionaryHandling::Hydrate`]).
    fn default() -> Self {
        Self {
            max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
            options: IpcWriteOptions::default(),
            app_metadata: Bytes::new(),
            schema: None,
            descriptor: None,
            dictionary_handling: DictionaryHandling::Hydrate,
        }
    }
}
180
impl FlightDataEncoderBuilder {
    /// Create a new [`FlightDataEncoderBuilder`] with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the (approximate) maximum size, in bytes, of the
    /// [`FlightData`] produced by this encoder. Defaults to 2MB.
    ///
    /// Since there is often a maximum message size for gRPC messages
    /// (typically around 4MB), this encoder splits up [`RecordBatch`]s
    /// (preserving order) into multiple [`FlightData`] objects to
    /// limit the size of individual messages sent via gRPC.
    ///
    /// The size is approximate because of the additional encoding
    /// overhead on top of the underlying data buffers themselves.
    pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
        self.max_flight_data_size = max_flight_data_size;
        self
    }

    /// Set [`DictionaryHandling`] for the encoder. See
    /// [`DictionaryHandling`] for the tradeoffs between hydrating and
    /// resending dictionaries.
    pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
        self.dictionary_handling = dictionary_handling;
        self
    }

    /// Specify application specific metadata included in the
    /// [`FlightData::app_metadata`] field of the first Schema
    /// message.
    pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
        self.app_metadata = app_metadata;
        self
    }

    /// Set the [`IpcWriteOptions`] used to encode the [`RecordBatch`]es for transport.
    pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
        self.options = options;
        self
    }

    /// Specify a schema for the RecordBatches being sent. If a schema
    /// is not specified, an encoded Schema message will be sent when
    /// the first [`RecordBatch`], if any, is encoded. Some clients
    /// expect a Schema message even if there is no data sent.
    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Specify a flight descriptor to attach to the first FlightData message.
    pub fn with_flight_descriptor(mut self, descriptor: Option<FlightDescriptor>) -> Self {
        self.descriptor = descriptor;
        self
    }

    /// Takes a [`Stream`] of [`Result<RecordBatch>`] and returns a [`Stream`]
    /// of [`FlightData`], consuming self.
    ///
    /// See example on [`Self`] and [`FlightDataEncoder`] for more details
    pub fn build<S>(self, input: S) -> FlightDataEncoder
    where
        S: Stream<Item = Result<RecordBatch>> + Send + 'static,
    {
        // Destructure so every configured field moves into the encoder
        // without cloning
        let Self {
            max_flight_data_size,
            options,
            app_metadata,
            schema,
            descriptor,
            dictionary_handling,
        } = self;

        FlightDataEncoder::new(
            input.boxed(),
            schema,
            max_flight_data_size,
            options,
            app_metadata,
            descriptor,
            dictionary_handling,
        )
    }
}
265
/// Stream that encodes a stream of record batches to flight data.
///
/// See [`FlightDataEncoderBuilder`] for details and example.
pub struct FlightDataEncoder {
    /// Input stream
    inner: BoxStream<'static, Result<RecordBatch>>,
    /// Schema of the *encoded* data; set either up front (via the builder)
    /// or after the first batch is seen
    schema: Option<SchemaRef>,
    /// Target maximum size of flight data
    /// (see details on [`FlightDataEncoderBuilder::with_max_flight_data_size`]).
    max_flight_data_size: usize,
    /// do the encoding / tracking of dictionaries
    encoder: FlightIpcEncoder,
    /// optional metadata to add to schema FlightData; `take`n when the
    /// schema message is sent so it is attached only once
    app_metadata: Option<Bytes>,
    /// data queued up to send but not yet sent
    queue: VecDeque<FlightData>,
    /// Is this stream done (inner is empty or errored)
    done: bool,
    /// cleared after the first FlightData message is sent
    descriptor: Option<FlightDescriptor>,
    /// Determines how `DictionaryArray`s are encoded for transport.
    /// See [`DictionaryHandling`] for more information.
    dictionary_handling: DictionaryHandling,
}
291
impl FlightDataEncoder {
    /// Build encoder state; if `schema` is provided it is encoded and
    /// queued immediately so it is the first message produced.
    fn new(
        inner: BoxStream<'static, Result<RecordBatch>>,
        schema: Option<SchemaRef>,
        max_flight_data_size: usize,
        options: IpcWriteOptions,
        app_metadata: Bytes,
        descriptor: Option<FlightDescriptor>,
        dictionary_handling: DictionaryHandling,
    ) -> Self {
        let mut encoder = Self {
            inner,
            schema: None,
            max_flight_data_size,
            // Replacing a previously-tracked dictionary is only an error
            // when dictionaries are NOT resent with each batch
            encoder: FlightIpcEncoder::new(
                options,
                dictionary_handling != DictionaryHandling::Resend,
            ),
            app_metadata: Some(app_metadata),
            queue: VecDeque::new(),
            done: false,
            descriptor,
            dictionary_handling,
        };

        // If schema is known up front, enqueue it immediately
        if let Some(schema) = schema {
            encoder.encode_schema(&schema);
        }

        encoder
    }

    /// Report the schema of the encoded data when known.
    /// A schema is known when provided via the [`FlightDataEncoderBuilder::with_schema`] method.
    pub fn known_schema(&self) -> Option<SchemaRef> {
        self.schema.clone()
    }

    /// Place the `FlightData` in the queue to send.
    ///
    /// The flight descriptor, if any, is attached to the first queued
    /// message only (`take` clears it afterwards).
    fn queue_message(&mut self, mut data: FlightData) {
        if let Some(descriptor) = self.descriptor.take() {
            data.flight_descriptor = Some(descriptor);
        }
        self.queue.push_back(data);
    }

    /// Place each `FlightData` in the queue to send, preserving order
    fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
        for data in datas {
            self.queue_message(data)
        }
    }

    /// Encodes schema as a [`FlightData`] in self.queue.
    /// Updates `self.schema` and returns the new schema
    fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
        // The first message is the schema message, and all
        // batches have the same schema
        let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
        // The transported schema can differ from the input schema: when
        // hydrating, dictionary fields are replaced by their value types
        let schema = Arc::new(prepare_schema_for_flight(
            schema,
            &mut self.encoder.dictionary_tracker,
            send_dictionaries,
        ));
        let mut schema_flight_data = self.encoder.encode_schema(&schema);

        // attach any metadata requested (only to this first message)
        if let Some(app_metadata) = self.app_metadata.take() {
            schema_flight_data.app_metadata = app_metadata;
        }
        self.queue_message(schema_flight_data);
        // remember schema
        self.schema = Some(schema.clone());
        schema
    }

    /// Encodes batch into one or more `FlightData` messages in self.queue
    fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
        let schema = match &self.schema {
            Some(schema) => schema.clone(),
            // encode the schema if this is the first time we have seen it
            None => self.encode_schema(batch.schema_ref()),
        };

        let batch = match self.dictionary_handling {
            DictionaryHandling::Resend => batch,
            DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
        };

        // Split so each message stays near the target size; dictionary
        // messages (if any) must be queued before their batch message
        for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
            let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;

            self.queue_messages(flight_dictionaries);
            self.queue_message(flight_batch);
        }

        Ok(())
    }
}
392
impl Stream for FlightDataEncoder {
    type Item = Result<FlightData>;

    /// Drains the internal queue first; when empty, pulls the next batch
    /// from the input stream and encodes it (which may enqueue several
    /// messages). Any error terminates the stream and drops queued data.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            if self.done && self.queue.is_empty() {
                return Poll::Ready(None);
            }

            // Any messages queued to send?
            if let Some(data) = self.queue.pop_front() {
                return Poll::Ready(Some(Ok(data)));
            }

            // Get next batch from the input stream (returns Pending if
            // the input is not ready)
            let batch = ready!(self.inner.poll_next_unpin(cx));

            match batch {
                None => {
                    // inner is done
                    self.done = true;
                    // queue must also be empty so we are done
                    assert!(self.queue.is_empty());
                    return Poll::Ready(None);
                }
                Some(Err(e)) => {
                    // error from inner: discard queued messages and terminate
                    self.done = true;
                    self.queue.clear();
                    return Poll::Ready(Some(Err(e)));
                }
                Some(Ok(batch)) => {
                    // had data, encode into the queue; loop back to drain it
                    if let Err(e) = self.encode_batch(batch) {
                        // encoding errors also terminate the stream
                        self.done = true;
                        self.queue.clear();
                        return Poll::Ready(Some(Err(e)));
                    }
                }
            }
        }
    }
}
439
440/// Defines how a [`FlightDataEncoder`] encodes [`DictionaryArray`]s
441///
442/// [`DictionaryArray`]: arrow_array::DictionaryArray
443///
444/// In the arrow flight protocol dictionary values and keys are sent as two separate messages.
445/// When a sender is encoding a [`RecordBatch`] containing ['DictionaryArray'] columns, it will
446/// first send a dictionary batch (a batch with header `MessageHeader::DictionaryBatch`) containing
447/// the dictionary values. The receiver is responsible for reading this batch and maintaining state that associates
448/// those dictionary values with the corresponding array using the `dict_id` as a key.
449///
450/// After sending the dictionary batch the sender will send the array data in a batch with header `MessageHeader::RecordBatch`.
451/// For any dictionary array batches in this message, the encoded flight message will only contain the dictionary keys. The receiver
452/// is then responsible for rebuilding the `DictionaryArray` on the client side using the dictionary values from the DictionaryBatch message
453/// and the keys from the RecordBatch message.
454///
455/// For example, if we have a batch with a `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` (a dictionary array where they keys are `u32` and the
456/// values are `String`), then the DictionaryBatch will contain a `StringArray` and the RecordBatch will contain a `UInt32Array`.
457///
458/// Note that since `dict_id` defined in the `Schema` is used as a key to associate dictionary values to their arrays it is required that each
459/// `DictionaryArray` in a `RecordBatch` have a unique `dict_id`.
460///
461/// The current implementation does not support "delta" dictionaries so a new dictionary batch will be sent each time the encoder sees a
462/// dictionary which is not pointer-equal to the previously observed dictionary for a given `dict_id`.
463///
464/// For clients which may not support `DictionaryEncoding`, the `DictionaryHandling::Hydrate` method will bypass the process defined above
465/// and "hydrate" any `DictionaryArray` in the batch to their underlying value type (e.g. `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` will
466/// be sent as a `StringArray`). With this method all data will be sent in ``MessageHeader::RecordBatch` messages and the batch schema
467/// will be adjusted so that all dictionary encoded fields are changed to fields of the dictionary value type.
#[derive(Debug, PartialEq)]
pub enum DictionaryHandling {
    /// Expands to the underlying type (default). This likely sends more data
    /// over the network but requires less memory (dictionaries are not tracked)
    /// and is more compatible with other arrow flight client implementations
    /// that may not support `DictionaryEncoding`
    ///
    /// See also:
    /// * <https://github.com/apache/arrow-rs/issues/1206>
    Hydrate,
    /// Send dictionary FlightData with every RecordBatch that contains a
    /// [`DictionaryArray`]. See [`Self::Hydrate`] for more tradeoffs. No
    /// attempt is made to skip sending the same (logical) dictionary values
    /// twice.
    ///
    /// [`DictionaryArray`]: arrow_array::DictionaryArray
    ///
    /// This requires identifying the different dictionaries in use and assigning
    /// them unique IDs
    Resend,
}
489
490fn prepare_field_for_flight(
491    field: &FieldRef,
492    dictionary_tracker: &mut DictionaryTracker,
493    send_dictionaries: bool,
494) -> Field {
495    match field.data_type() {
496        DataType::List(inner) => Field::new_list(
497            field.name(),
498            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
499            field.is_nullable(),
500        )
501        .with_metadata(field.metadata().clone()),
502        DataType::LargeList(inner) => Field::new_list(
503            field.name(),
504            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
505            field.is_nullable(),
506        )
507        .with_metadata(field.metadata().clone()),
508        DataType::Struct(fields) => {
509            let new_fields: Vec<Field> = fields
510                .iter()
511                .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
512                .collect();
513            Field::new_struct(field.name(), new_fields, field.is_nullable())
514                .with_metadata(field.metadata().clone())
515        }
516        DataType::Union(fields, mode) => {
517            let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
518                .iter()
519                .map(|(type_id, f)| {
520                    (
521                        type_id,
522                        prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
523                    )
524                })
525                .unzip();
526
527            Field::new_union(field.name(), type_ids, new_fields, *mode)
528        }
529        DataType::Dictionary(_, value_type) => {
530            if !send_dictionaries {
531                // Recurse into value type to handle nested dicts being stripped
532                let value_field = Field::new(
533                    field.name(),
534                    value_type.as_ref().clone(),
535                    field.is_nullable(),
536                );
537                prepare_field_for_flight(
538                    &Arc::new(value_field),
539                    dictionary_tracker,
540                    send_dictionaries,
541                )
542                .with_metadata(field.metadata().clone())
543            } else {
544                // Recurse into value type BEFORE registering this dict's id,
545                // matching the depth-first order of encode_dictionaries in the
546                // IPC writer which processes nested dicts before the parent.
547                let value_field = Field::new("values", value_type.as_ref().clone(), true);
548                prepare_field_for_flight(
549                    &Arc::new(value_field),
550                    dictionary_tracker,
551                    send_dictionaries,
552                );
553                dictionary_tracker.next_dict_id();
554                #[allow(deprecated)]
555                Field::new_dict(
556                    field.name(),
557                    field.data_type().clone(),
558                    field.is_nullable(),
559                    0,
560                    field.dict_is_ordered().unwrap_or_default(),
561                )
562                .with_metadata(field.metadata().clone())
563            }
564        }
565        DataType::ListView(inner) | DataType::LargeListView(inner) => {
566            let prepared = prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries);
567            Field::new(
568                field.name(),
569                match field.data_type() {
570                    DataType::ListView(_) => DataType::ListView(Arc::new(prepared)),
571                    _ => DataType::LargeListView(Arc::new(prepared)),
572                },
573                field.is_nullable(),
574            )
575            .with_metadata(field.metadata().clone())
576        }
577        DataType::FixedSizeList(inner, size) => Field::new(
578            field.name(),
579            DataType::FixedSizeList(
580                Arc::new(prepare_field_for_flight(
581                    inner,
582                    dictionary_tracker,
583                    send_dictionaries,
584                )),
585                *size,
586            ),
587            field.is_nullable(),
588        )
589        .with_metadata(field.metadata().clone()),
590        DataType::RunEndEncoded(run_ends, values) => Field::new(
591            field.name(),
592            DataType::RunEndEncoded(
593                run_ends.clone(),
594                Arc::new(prepare_field_for_flight(
595                    values,
596                    dictionary_tracker,
597                    send_dictionaries,
598                )),
599            ),
600            field.is_nullable(),
601        )
602        .with_metadata(field.metadata().clone()),
603        DataType::Map(inner, sorted) => Field::new(
604            field.name(),
605            DataType::Map(
606                prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
607                *sorted,
608            ),
609            field.is_nullable(),
610        )
611        .with_metadata(field.metadata().clone()),
612        DataType::Null
613        | DataType::Boolean
614        | DataType::Int8
615        | DataType::Int16
616        | DataType::Int32
617        | DataType::Int64
618        | DataType::UInt8
619        | DataType::UInt16
620        | DataType::UInt32
621        | DataType::UInt64
622        | DataType::Float16
623        | DataType::Float32
624        | DataType::Float64
625        | DataType::Timestamp(_, _)
626        | DataType::Date32
627        | DataType::Date64
628        | DataType::Time32(_)
629        | DataType::Time64(_)
630        | DataType::Duration(_)
631        | DataType::Interval(_)
632        | DataType::Binary
633        | DataType::FixedSizeBinary(_)
634        | DataType::LargeBinary
635        | DataType::BinaryView
636        | DataType::Utf8
637        | DataType::LargeUtf8
638        | DataType::Utf8View
639        | DataType::Decimal32(_, _)
640        | DataType::Decimal64(_, _)
641        | DataType::Decimal128(_, _)
642        | DataType::Decimal256(_, _) => field.as_ref().clone(),
643    }
644}
645
646/// Prepare an arrow Schema for transport over the Arrow Flight protocol
647///
648/// Convert dictionary types to underlying types
649///
650/// See hydrate_dictionary for more information
651fn prepare_schema_for_flight(
652    schema: &Schema,
653    dictionary_tracker: &mut DictionaryTracker,
654    send_dictionaries: bool,
655) -> Schema {
656    let fields: Fields = schema
657        .fields()
658        .iter()
659        .map(|field| prepare_field_for_flight(field, dictionary_tracker, send_dictionaries))
660        .collect();
661
662    Schema::new(fields).with_metadata(schema.metadata().clone())
663}
664
665/// Split [`RecordBatch`] so it hopefully fits into a gRPC response.
666///
667/// Data is zero-copy sliced into batches.
668///
669/// Note: this method does not take into account already sliced
670/// arrays: <https://github.com/apache/arrow-rs/issues/3407>
671fn split_batch_for_grpc_response(
672    batch: RecordBatch,
673    max_flight_data_size: usize,
674) -> Vec<RecordBatch> {
675    let size = batch
676        .columns()
677        .iter()
678        .map(|col| col.get_buffer_memory_size())
679        .sum::<usize>();
680
681    let n_batches =
682        (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
683    let rows_per_batch = (batch.num_rows() / n_batches).max(1);
684    let mut out = Vec::with_capacity(n_batches + 1);
685
686    let mut offset = 0;
687    while offset < batch.num_rows() {
688        let length = (rows_per_batch).min(batch.num_rows() - offset);
689        out.push(batch.slice(offset, length));
690
691        offset += length;
692    }
693
694    out
695}
696
/// The data needed to encode a stream of flight data, holding on to
/// shared Dictionaries.
///
/// TODO: allow dictionaries to be flushed / avoid building them
///
/// TODO: limit on the number of dictionaries???
struct FlightIpcEncoder {
    /// IPC write options shared by schema and batch encoding
    options: IpcWriteOptions,
    /// Generates the encoded IPC bytes for dictionaries and batches
    data_gen: IpcDataGenerator,
    /// Tracks dictionaries already seen (and whether replacement is an error)
    dictionary_tracker: DictionaryTracker,
    /// Compression state passed to the IPC data generator on each encode
    compression_context: CompressionContext,
}
709
710impl FlightIpcEncoder {
711    fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
712        Self {
713            options,
714            data_gen: IpcDataGenerator::default(),
715            dictionary_tracker: DictionaryTracker::new(error_on_replacement),
716            compression_context: CompressionContext::default(),
717        }
718    }
719
720    /// Encode a schema as a FlightData
721    fn encode_schema(&self, schema: &Schema) -> FlightData {
722        SchemaAsIpc::new(schema, &self.options).into()
723    }
724
725    /// Convert a `RecordBatch` to a Vec of `FlightData` representing
726    /// dictionaries and a `FlightData` representing the batch
727    fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
728        let (encoded_dictionaries, encoded_batch) = self.data_gen.encode(
729            batch,
730            &mut self.dictionary_tracker,
731            &self.options,
732            &mut self.compression_context,
733        )?;
734
735        let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
736        let flight_batch = encoded_batch.into();
737
738        Ok((flight_dictionaries, flight_batch))
739    }
740}
741
742/// Hydrates any dictionaries arrays in `batch` to its underlying type. See
743/// hydrate_dictionary for more information.
744fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
745    let columns = schema
746        .fields()
747        .iter()
748        .zip(batch.columns())
749        .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
750        .collect::<Result<Vec<_>>>()?;
751
752    let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
753
754    Ok(RecordBatch::try_new_with_options(
755        schema, columns, &options,
756    )?)
757}
758
/// Hydrates a dictionary to its underlying type.
///
/// `data_type` is the target type from the prepared schema; any array
/// whose type differs (including dictionary arrays) is converted via
/// `arrow_cast::cast`.
fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
    let arr = match (array.data_type(), data_type) {
        // Sparse unions are rebuilt child-by-child instead of going through
        // `cast`.
        // NOTE(review): presumably `cast` does not handle union-to-union
        // conversions, hence the manual rebuild — confirm against arrow_cast.
        (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
            let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();

            Arc::new(UnionArray::try_new(
                fields.clone(),
                union_arr.type_ids().clone(),
                // Sparse unions carry no offsets buffer
                None,
                // Cast each child to the corresponding target field type
                fields
                    .iter()
                    .map(|(type_id, field)| {
                        Ok(arrow_cast::cast(
                            union_arr.child(type_id),
                            field.data_type(),
                        )?)
                    })
                    .collect::<Result<Vec<_>>>()?,
            )?)
        }
        // Everything else, including dictionary arrays, is handled by cast
        (_, data_type) => arrow_cast::cast(array, data_type)?,
    };
    Ok(arr)
}
784
785#[cfg(test)]
786mod tests {
787    use crate::decode::{DecodedPayload, FlightDataDecoder};
788    use arrow_array::builder::{
789        FixedSizeListBuilder, GenericByteDictionaryBuilder, GenericListViewBuilder, ListBuilder,
790        StringDictionaryBuilder, StructBuilder,
791    };
792    use arrow_array::*;
793    use arrow_array::{cast::downcast_array, types::*};
794    use arrow_buffer::ScalarBuffer;
795    use arrow_cast::pretty::pretty_format_batches;
796    use arrow_ipc::MetadataVersion;
797    use arrow_schema::{UnionFields, UnionMode};
798    use builder::{GenericStringBuilder, MapBuilder};
799    use std::collections::HashMap;
800
801    use super::*;
802
803    #[test]
804    /// ensure only the batch's used data (not the allocated data) is sent
805    /// <https://github.com/apache/arrow-rs/issues/208>
806    fn test_encode_flight_data() {
807        // use 8-byte alignment - default alignment is 64 which produces bigger ipc data
808        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
809        let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
810
811        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
812            .expect("cannot create record batch");
813        let schema = batch.schema_ref();
814
815        let (_, baseline_flight_batch) = make_flight_data(&batch, &options);
816
817        let big_batch = batch.slice(0, batch.num_rows() - 1);
818        let optimized_big_batch =
819            hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
820        let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);
821
822        assert_eq!(
823            baseline_flight_batch.data_body.len(),
824            optimized_big_flight_batch.data_body.len()
825        );
826
827        let small_batch = batch.slice(0, 1);
828        let optimized_small_batch =
829            hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
830        let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);
831
832        assert!(
833            baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len()
834        );
835    }
836
837    #[tokio::test]
838    async fn test_dictionary_hydration() {
839        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
840        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
841
842        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
843            "dict",
844            DataType::UInt16,
845            DataType::Utf8,
846            false,
847        )]));
848        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
849        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
850
851        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
852
853        let encoder = FlightDataEncoderBuilder::default().build(stream);
854        let mut decoder = FlightDataDecoder::new(encoder);
855        let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
856        let expected_schema = Arc::new(expected_schema);
857        let mut expected_arrays = vec![
858            StringArray::from(vec!["a", "a", "b"]),
859            StringArray::from(vec!["c", "c", "d"]),
860        ]
861        .into_iter();
862        while let Some(decoded) = decoder.next().await {
863            let decoded = decoded.unwrap();
864            match decoded.payload {
865                DecodedPayload::None => {}
866                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
867                DecodedPayload::RecordBatch(b) => {
868                    assert_eq!(b.schema(), expected_schema);
869                    let expected_array = expected_arrays.next().unwrap();
870                    let actual_array = b.column_by_name("dict").unwrap();
871                    let actual_array = downcast_array::<StringArray>(actual_array);
872
873                    assert_eq!(actual_array, expected_array);
874                }
875            }
876        }
877    }
878
879    #[tokio::test]
880    async fn test_dictionary_resend() {
881        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
882        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
883
884        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
885            "dict",
886            DataType::UInt16,
887            DataType::Utf8,
888            false,
889        )]));
890        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
891        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
892
893        verify_flight_round_trip(vec![batch1, batch2]).await;
894    }
895
896    #[tokio::test]
897    async fn test_dictionary_hydration_known_schema() {
898        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
899        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
900
901        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
902            "dict",
903            DataType::UInt16,
904            DataType::Utf8,
905            false,
906        )]));
907        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
908        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
909
910        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
911
912        let encoder = FlightDataEncoderBuilder::default()
913            .with_schema(schema)
914            .build(stream);
915        let expected_schema =
916            Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
917        assert_eq!(Some(expected_schema), encoder.known_schema())
918    }
919
920    #[tokio::test]
921    async fn test_dictionary_resend_known_schema() {
922        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
923        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
924
925        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
926            "dict",
927            DataType::UInt16,
928            DataType::Utf8,
929            false,
930        )]));
931        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
932        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
933
934        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
935
936        let encoder = FlightDataEncoderBuilder::default()
937            .with_dictionary_handling(DictionaryHandling::Resend)
938            .with_schema(schema.clone())
939            .build(stream);
940        assert_eq!(Some(schema), encoder.known_schema())
941    }
942
943    #[tokio::test]
944    async fn test_multiple_dictionaries_resend() {
945        // Create a schema with two dictionary fields that have the same dict ID
946        let schema = Arc::new(Schema::new(vec![
947            Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
948            Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
949        ]));
950
951        let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
952            Arc::new(vec!["a", "a", "b"].into_iter().collect());
953        let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
954            Arc::new(vec!["c", "c", "d"].into_iter().collect());
955        let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
956            Arc::new(vec!["b", "a", "c"].into_iter().collect());
957        let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
958            Arc::new(vec!["k", "d", "e"].into_iter().collect());
959        let batch1 =
960            RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
961                .unwrap();
962        let batch2 =
963            RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
964                .unwrap();
965
966        verify_flight_round_trip(vec![batch1, batch2]).await;
967    }
968
    #[tokio::test]
    /// Encoding `List<Dictionary<UInt16, Utf8>>` with the default (hydrating)
    /// encoder should yield batches whose schema is plain `List<Utf8>`.
    async fn test_dictionary_list_hydration() {
        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        // The builder is reused after `finish`; arr2 holds only ["c", null, "d"]
        // (see `expected_arrays` below).
        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        // After hydration the list's item field is the dictionary's value type
        let expected_schema = Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_list_field(DataType::Utf8, true),
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    // Each batch holds a single list row; compare its elements
                    let list_array =
                        downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }
1026
1027    #[tokio::test]
1028    async fn test_dictionary_list_resend() {
1029        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
1030
1031        builder.append_value(vec![Some("a"), None, Some("b")]);
1032
1033        let arr1 = builder.finish();
1034
1035        builder.append_value(vec![Some("c"), None, Some("d")]);
1036
1037        let arr2 = builder.finish();
1038
1039        let schema = Arc::new(Schema::new(vec![Field::new_list(
1040            "dict_list",
1041            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1042            true,
1043        )]));
1044
1045        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1046        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1047
1048        verify_flight_round_trip(vec![batch1, batch2]).await;
1049    }
1050
    #[tokio::test]
    /// Hydration should recurse into struct children: a struct containing a
    /// `List<Dictionary<UInt16, Utf8>>` decodes as a struct of `List<Utf8>`.
    async fn test_dictionary_struct_hydration() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        // First batch: one struct row { dict_list: ["a", null, "b"] }
        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);

        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        // Second batch: one struct row { dict_list: ["c", null, "d"] }
        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        // Dictionary fields are replaced by their value type at every nesting level
        let expected_schema = Schema::new(vec![Field::new_struct(
            "struct",
            vec![Field::new_list(
                "dict_list",
                Field::new_list_field(DataType::Utf8, true),
                true,
            )],
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    // Drill down: struct column -> list child -> string elements
                    let struct_array =
                        downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
                    let list_array = downcast_array::<ListArray>(struct_array.column(0));

                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }
1134
    #[tokio::test]
    /// Round-trips a struct containing a `List<Dictionary<UInt16, Utf8>>`
    /// through `verify_flight_round_trip` (helper defined elsewhere in this
    /// module) to exercise the dictionary-resend path on nested data.
    async fn test_dictionary_struct_resend() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        // First batch: one struct row { dict_list: ["a", null, "b"] }
        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);
        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        // Second batch: one struct row { dict_list: ["c", null, "d"] }
        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
1175
    #[tokio::test]
    /// Hydration should recurse into sparse union children: each batch carries
    /// a different union variant (dictionary list, struct-of-list, plain
    /// string), and all decode with the dictionary fields hydrated to Utf8.
    async fn test_dictionary_union_hydration() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        // Sparse union with three variants:
        //   0 -> List<Dictionary<UInt16, Utf8>>
        //   1 -> Struct { dict_list: List<Dictionary<UInt16, Utf8>> }
        //   2 -> Utf8
        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        // arr1: union row of variant 0; the other children are null placeholders
        // (sparse unions require all children to have the same length)
        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1) as Arc<dyn Array>,
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        // arr2: union row of variant 1 (the struct)
        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                Arc::new(arr2),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        // arr3: union row of variant 2 (the plain string)
        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        // Split the UnionFields into parallel type-id / field vectors for
        // Field::new_union
        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);

        // Expected decoded schema: same union shape, but every dictionary
        // field replaced by its value type
        let hydrated_struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_list_field(DataType::Utf8, true),
            true,
        )];

        let hydrated_union_fields = vec![
            Field::new_list(
                "dict_list",
                Field::new_list_field(DataType::Utf8, true),
                true,
            ),
            Field::new_struct("struct", hydrated_struct_fields.clone(), true),
            Field::new("string", DataType::Utf8, true),
        ];

        let expected_schema = Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            hydrated_union_fields,
            UnionMode::Sparse,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
            StringArray::from(vec!["e"]),
        ]
        .into_iter();

        // Each decoded batch carries a different variant, so the drill-down
        // path to the string values depends on the batch index
        let mut batch = 0;
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let union_arr =
                        downcast_array::<UnionArray>(b.column_by_name("union").unwrap());

                    let elem_array = match batch {
                        0 => {
                            let list_array = downcast_array::<ListArray>(union_arr.child(0));
                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        1 => {
                            let struct_array = downcast_array::<StructArray>(union_arr.child(1));
                            let list_array = downcast_array::<ListArray>(struct_array.column(0));

                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        _ => downcast_array::<StringArray>(union_arr.child(2)),
                    };

                    batch += 1;

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }
1344
    #[tokio::test]
    /// Round-trips a sparse union whose variants contain dictionary-encoded
    /// data through `verify_flight_round_trip` (helper defined elsewhere in
    /// this module).
    async fn test_dictionary_union_resend() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        // Sparse union with three variants:
        //   0 -> List<Dictionary<UInt16, Utf8>>
        //   1 -> Struct { dict_list: List<Dictionary<UInt16, Utf8>> }
        //   2 -> Utf8
        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        // Grab the three child data types up front for the null placeholders below
        let mut field_types = union_fields.iter().map(|(_, field)| field.data_type());
        let dict_list_ty = field_types.next().unwrap();
        let struct_ty = field_types.next().unwrap();
        let string_ty = field_types.next().unwrap();

        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        // arr1: union row of variant 0; other children are null placeholders
        // (sparse unions require all children to have the same length)
        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1),
                new_null_array(struct_ty, 1),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        // arr2: union row of variant 1 (the struct)
        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                Arc::new(arr2),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        // arr3: union row of variant 2 (the plain string)
        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                new_null_array(struct_ty, 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        // Split the UnionFields into parallel type-id / field vectors for
        // Field::new_union
        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
    }
1449
    #[tokio::test]
    /// Hydration should recurse into map keys and values: a map with
    /// dictionary-encoded keys/values decodes as `Map<Utf8, Utf8>`, equal to
    /// the same data built with plain string builders.
    async fn test_dictionary_map_hydration() {
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        // {"k1":"a","k2":null,"k3":"b"}
        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        // {"k1":"c","k2":null,"k3":"d"}
        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        // After hydration both keys and values are plain Utf8
        let expected_schema = Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new("keys", DataType::Utf8, false),
            Field::new("values", DataType::Utf8, true),
            false,
            false,
        )]);

        let expected_schema = Arc::new(expected_schema);

        // Builder without dictionary fields
        let mut builder = MapBuilder::new(
            None,
            GenericStringBuilder::<i32>::new(),
            GenericStringBuilder::<i32>::new(),
        );

        // {"k1":"a","k2":null,"k3":"b"}
        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        // {"k1":"c","k2":null,"k3":"d"}
        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        // Expect the decoded maps to equal the plain-string versions exactly
        let mut expected_arrays = vec![arr1, arr2].into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let map_array =
                        downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());

                    assert_eq!(map_array, expected_array);
                }
            }
        }
    }
1555
1556    #[tokio::test]
1557    async fn test_dictionary_map_resend() {
1558        let mut builder = MapBuilder::new(
1559            None,
1560            StringDictionaryBuilder::<UInt16Type>::new(),
1561            StringDictionaryBuilder::<UInt16Type>::new(),
1562        );
1563
1564        // {"k1":"a","k2":null,"k3":"b"}
1565        builder.keys().append_value("k1");
1566        builder.values().append_value("a");
1567        builder.keys().append_value("k2");
1568        builder.values().append_null();
1569        builder.keys().append_value("k3");
1570        builder.values().append_value("b");
1571        builder.append(true).unwrap();
1572
1573        let arr1 = builder.finish();
1574
1575        // {"k1":"c","k2":null,"k3":"d"}
1576        builder.keys().append_value("k1");
1577        builder.values().append_value("c");
1578        builder.keys().append_value("k2");
1579        builder.values().append_null();
1580        builder.keys().append_value("k3");
1581        builder.values().append_value("d");
1582        builder.append(true).unwrap();
1583
1584        let arr2 = builder.finish();
1585
1586        let schema = Arc::new(Schema::new(vec![Field::new_map(
1587            "dict_map",
1588            "entries",
1589            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
1590            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
1591            false,
1592            false,
1593        )]));
1594
1595        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1596        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1597
1598        verify_flight_round_trip(vec![batch1, batch2]).await;
1599    }
1600
1601    #[tokio::test]
1602    async fn test_dictionary_ree_resend() {
1603        let dict_values1 = vec![Some("a"), None, Some("b")]
1604            .into_iter()
1605            .collect::<DictionaryArray<Int32Type>>();
1606        let run_ends1 = Int32Array::from(vec![1, 2, 3]);
1607        let arr1 = RunArray::try_new(&run_ends1, &dict_values1).unwrap();
1608
1609        let dict_values2 = vec![Some("c"), Some("a")]
1610            .into_iter()
1611            .collect::<DictionaryArray<Int32Type>>();
1612        let run_ends2 = Int32Array::from(vec![1, 2]);
1613        let arr2 = RunArray::try_new(&run_ends2, &dict_values2).unwrap();
1614
1615        let schema = Arc::new(Schema::new(vec![Field::new(
1616            "ree",
1617            arr1.data_type().clone(),
1618            true,
1619        )]));
1620
1621        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1622        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1623
1624        verify_flight_round_trip(vec![batch1, batch2]).await;
1625    }
1626
    #[tokio::test]
    async fn test_dictionary_of_struct_of_dict_resend() {
        // Dict(Int8, Struct { dict: Dict(Int32, Utf8), int: Int32 })
        // This exercises the Dictionary branch recursing into its value type
        // before assigning its own dict_id (depth-first ordering).
        let struct_fields: Vec<Field> = vec![
            Field::new_dictionary("dict", DataType::Int32, DataType::Utf8, true),
            Field::new("int", DataType::Int32, false),
        ];

        // First batch: inner dictionary with four string values (one null
        // slot via the None entry) alongside a plain Int32 column.
        let inner_values =
            StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let inner_keys = Int32Array::from_iter_values([0, 1, 2, 3, 0]);
        let inner_dict = DictionaryArray::new(inner_keys, Arc::new(inner_values));
        let int_array = Int32Array::from(vec![10, 20, 30, 40, 50]);

        let struct_array = StructArray::from(vec![
            (
                Arc::new(struct_fields[0].clone()),
                Arc::new(inner_dict) as ArrayRef,
            ),
            (
                Arc::new(struct_fields[1].clone()),
                Arc::new(int_array) as ArrayRef,
            ),
        ]);

        // Outer dictionary keys index into the struct values above.
        let outer_keys = Int8Array::from_iter_values([0, 0, 1, 2]);
        let arr1 = DictionaryArray::new(outer_keys, Arc::new(struct_array));

        // Second batch: entirely different dictionary contents so the
        // encoder must resend both the inner and the outer dictionaries.
        let inner_values2 = StringArray::from(vec![Some("x"), Some("y")]);
        let inner_keys2 = Int32Array::from_iter_values([0, 1, 0]);
        let inner_dict2 = DictionaryArray::new(inner_keys2, Arc::new(inner_values2));
        let int_array2 = Int32Array::from(vec![100, 200, 300]);

        let struct_array2 = StructArray::from(vec![
            (
                Arc::new(struct_fields[0].clone()),
                Arc::new(inner_dict2) as ArrayRef,
            ),
            (
                Arc::new(struct_fields[1].clone()),
                Arc::new(int_array2) as ArrayRef,
            ),
        ]);

        let outer_keys2 = Int8Array::from_iter_values([0, 1]);
        let arr2 = DictionaryArray::new(outer_keys2, Arc::new(struct_array2));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "dict_struct",
            arr1.data_type().clone(),
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }
1687
1688    async fn verify_dictionary_list_view_resend<O: OffsetSizeTrait>() {
1689        let mut builder =
1690            GenericListViewBuilder::<O, _>::new(StringDictionaryBuilder::<UInt16Type>::new());
1691
1692        builder.append_value(vec![Some("a"), None, Some("b")]);
1693        let arr1 = builder.finish();
1694
1695        builder.append_value(vec![Some("c"), None, Some("d")]);
1696        let arr2 = builder.finish();
1697
1698        let inner = Arc::new(Field::new_dictionary(
1699            "item",
1700            DataType::UInt16,
1701            DataType::Utf8,
1702            true,
1703        ));
1704        let dt = if O::IS_LARGE {
1705            DataType::LargeListView(inner)
1706        } else {
1707            DataType::ListView(inner)
1708        };
1709        let schema = Arc::new(Schema::new(vec![Field::new("dict_list_view", dt, true)]));
1710
1711        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1712        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1713
1714        verify_flight_round_trip(vec![batch1, batch2]).await;
1715    }
1716
    #[tokio::test]
    async fn test_dictionary_list_view_resend() {
        // `ListView` (i32 offsets) variant of the dictionary resend round trip
        verify_dictionary_list_view_resend::<i32>().await;
    }
1721
    #[tokio::test]
    async fn test_dictionary_large_list_view_resend() {
        // `LargeListView` (i64 offsets) variant of the dictionary resend round trip
        verify_dictionary_list_view_resend::<i64>().await;
    }
1726
1727    #[tokio::test]
1728    async fn test_dictionary_fixed_size_list_resend() {
1729        let mut builder =
1730            FixedSizeListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new(), 2);
1731
1732        builder.values().append_value("a");
1733        builder.values().append_value("b");
1734        builder.append(true);
1735        let arr1 = builder.finish();
1736
1737        builder.values().append_value("c");
1738        builder.values().append_value("d");
1739        builder.append(true);
1740        let arr2 = builder.finish();
1741
1742        let schema = Arc::new(Schema::new(vec![Field::new_fixed_size_list(
1743            "dict_fsl",
1744            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1745            2,
1746            true,
1747        )]));
1748
1749        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1750        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1751
1752        verify_flight_round_trip(vec![batch1, batch2]).await;
1753    }
1754
1755    async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
1756        let expected_schema = batches.first().unwrap().schema();
1757
1758        let encoder = FlightDataEncoderBuilder::default()
1759            .with_options(IpcWriteOptions::default())
1760            .with_dictionary_handling(DictionaryHandling::Resend)
1761            .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));
1762
1763        let mut expected_batches = batches.drain(..);
1764
1765        let mut decoder = FlightDataDecoder::new(encoder);
1766        while let Some(decoded) = decoder.next().await {
1767            let decoded = decoded.unwrap();
1768            match decoded.payload {
1769                DecodedPayload::None => {}
1770                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1771                DecodedPayload::RecordBatch(b) => {
1772                    let expected_batch = expected_batches.next().unwrap();
1773                    assert_eq!(b, expected_batch);
1774                }
1775            }
1776        }
1777    }
1778
1779    #[test]
1780    fn test_schema_metadata_encoded() {
1781        let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata(
1782            HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
1783        );
1784
1785        let mut dictionary_tracker = DictionaryTracker::new(false);
1786
1787        let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
1788        assert!(got.metadata().contains_key("some_key"));
1789    }
1790
1791    #[test]
1792    fn test_encode_no_column_batch() {
1793        let batch = RecordBatch::try_new_with_options(
1794            Arc::new(Schema::empty()),
1795            vec![],
1796            &RecordBatchOptions::new().with_row_count(Some(10)),
1797        )
1798        .expect("cannot create record batch");
1799
1800        hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize");
1801    }
1802
    /// Convenience alias that delegates to [`flight_data_from_arrow_batch`],
    /// returning the encoded dictionary messages and the batch message.
    fn make_flight_data(
        batch: &RecordBatch,
        options: &IpcWriteOptions,
    ) -> (Vec<FlightData>, FlightData) {
        flight_data_from_arrow_batch(batch, options)
    }
1809
1810    fn flight_data_from_arrow_batch(
1811        batch: &RecordBatch,
1812        options: &IpcWriteOptions,
1813    ) -> (Vec<FlightData>, FlightData) {
1814        let data_gen = IpcDataGenerator::default();
1815        let mut dictionary_tracker = DictionaryTracker::new(false);
1816        let mut compression_context = CompressionContext::default();
1817
1818        let (encoded_dictionaries, encoded_batch) = data_gen
1819            .encode(
1820                batch,
1821                &mut dictionary_tracker,
1822                options,
1823                &mut compression_context,
1824            )
1825            .expect("DictionaryTracker configured above to not error on replacement");
1826
1827        let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
1828        let flight_batch = encoded_batch.into();
1829
1830        (flight_dictionaries, flight_batch)
1831    }
1832
1833    #[test]
1834    fn test_split_batch_for_grpc_response() {
1835        let max_flight_data_size = 1024;
1836
1837        // no split
1838        let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
1839        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
1840            .expect("cannot create record batch");
1841        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
1842        assert_eq!(split.len(), 1);
1843        assert_eq!(batch, split[0]);
1844
1845        // split once
1846        let n_rows = max_flight_data_size + 1;
1847        assert!(n_rows % 2 == 1, "should be an odd number");
1848        let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
1849        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
1850            .expect("cannot create record batch");
1851        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
1852        assert_eq!(split.len(), 3);
1853        assert_eq!(
1854            split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
1855            n_rows
1856        );
1857        let a = pretty_format_batches(&split).unwrap().to_string();
1858        let b = pretty_format_batches(&[batch]).unwrap().to_string();
1859        assert_eq!(a, b);
1860    }
1861
1862    #[test]
1863    fn test_split_batch_for_grpc_response_sizes() {
1864        // 2000 8 byte entries into 2k pieces: 8 chunks of 250 rows
1865        verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]);
1866
1867        // 2000 8 byte entries into 4k pieces: 4 chunks of 500 rows
1868        verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]);
1869
1870        // 2023 8 byte entries into 3k pieces does not divide evenly
1871        verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]);
1872
1873        // 10 8 byte entries into 1 byte pieces means each rows gets its own
1874        verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
1875
1876        // 10 8 byte entries into 1k byte pieces means one piece
1877        verify_split(10, 1024, vec![10]);
1878    }
1879
    /// Creates a UInt64Array of 8 byte integers with `num_input_rows` rows,
    /// splits it into pieces of at most `max_flight_data_size_bytes` bytes,
    /// and verifies the row counts in those pieces match `expected_sizes`
    fn verify_split(
        num_input_rows: u64,
        max_flight_data_size_bytes: usize,
        expected_sizes: Vec<usize>,
    ) {
        let array: UInt64Array = (0..num_input_rows).collect();

        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
            .expect("cannot create record batch");

        let input_rows = batch.num_rows();

        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
        let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
        let output_rows: usize = sizes.iter().sum();

        // The split must match the expected chunking and conserve total rows
        assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
        assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
    }
1902
    // TODO: test sending record batches
    // TODO: test sending record batches with multiple different dictionaries
1905
1906    #[tokio::test]
1907    async fn flight_data_size_even() {
1908        let s1 = StringArray::from_iter_values(std::iter::repeat_n(".10 bytes.", 1024));
1909        let i1 = Int16Array::from_iter_values(0..1024);
1910        let s2 = StringArray::from_iter_values(std::iter::repeat_n("6bytes", 1024));
1911        let i2 = Int64Array::from_iter_values(0..1024);
1912
1913        let batch = RecordBatch::try_from_iter(vec![
1914            ("s1", Arc::new(s1) as _),
1915            ("i1", Arc::new(i1) as _),
1916            ("s2", Arc::new(s2) as _),
1917            ("i2", Arc::new(i2) as _),
1918        ])
1919        .unwrap();
1920
1921        verify_encoded_split(batch, 120).await;
1922    }
1923
1924    #[tokio::test]
1925    async fn flight_data_size_uneven_variable_lengths() {
1926        // each row has a longer string than the last with increasing lengths 0 --> 1024
1927        let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
1928        let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();
1929
1930        // overage is much higher than ideal
1931        // https://github.com/apache/arrow-rs/issues/3478
1932        verify_encoded_split(batch, 4312).await;
1933    }
1934
1935    #[tokio::test]
1936    async fn flight_data_size_large_row() {
1937        // batch with individual that can each exceed the batch size
1938        let array1 = StringArray::from_iter_values(vec![
1939            "*".repeat(500),
1940            "*".repeat(500),
1941            "*".repeat(500),
1942            "*".repeat(500),
1943        ]);
1944        let array2 = StringArray::from_iter_values(vec![
1945            "*".to_string(),
1946            "*".repeat(1000),
1947            "*".repeat(2000),
1948            "*".repeat(4000),
1949        ]);
1950
1951        let array3 = StringArray::from_iter_values(vec![
1952            "*".to_string(),
1953            "*".to_string(),
1954            "*".repeat(1000),
1955            "*".repeat(2000),
1956        ]);
1957
1958        let batch = RecordBatch::try_from_iter(vec![
1959            ("a1", Arc::new(array1) as _),
1960            ("a2", Arc::new(array2) as _),
1961            ("a3", Arc::new(array3) as _),
1962        ])
1963        .unwrap();
1964
1965        // 5k over limit (which is 2x larger than limit of 5k)
1966        // overage is much higher than ideal
1967        // https://github.com/apache/arrow-rs/issues/3478
1968        verify_encoded_split(batch, 5808).await;
1969    }
1970
1971    #[tokio::test]
1972    async fn flight_data_size_string_dictionary() {
1973        // Small dictionary (only 2 distinct values ==> 2 entries in dictionary)
1974        let array: DictionaryArray<Int32Type> = (1..1024)
1975            .map(|i| match i % 3 {
1976                0 => Some("value0"),
1977                1 => Some("value1"),
1978                _ => None,
1979            })
1980            .collect();
1981
1982        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1983
1984        verify_encoded_split(batch, 56).await;
1985    }
1986
1987    #[tokio::test]
1988    async fn flight_data_size_large_dictionary() {
1989        // large dictionary (all distinct values ==> 1024 entries in dictionary)
1990        let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
1991
1992        let array: DictionaryArray<Int32Type> = values.iter().map(|s| Some(s.as_str())).collect();
1993
1994        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1995
1996        // overage is much higher than ideal
1997        // https://github.com/apache/arrow-rs/issues/3478
1998        verify_encoded_split(batch, 3336).await;
1999    }
2000
2001    #[tokio::test]
2002    async fn flight_data_size_large_dictionary_repeated_non_uniform() {
2003        // large dictionary (1024 distinct values) that are used throughout the array
2004        let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
2005        let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
2006        let array = DictionaryArray::new(keys, Arc::new(values));
2007
2008        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
2009
2010        // overage is much higher than ideal
2011        // https://github.com/apache/arrow-rs/issues/3478
2012        verify_encoded_split(batch, 5288).await;
2013    }
2014
2015    #[tokio::test]
2016    async fn flight_data_size_multiple_dictionaries() {
2017        // high cardinality
2018        let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
2019        // highish cardinality
2020        let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
2021        // medium cardinality
2022        let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();
2023
2024        let array1: DictionaryArray<Int32Type> = values1.iter().map(|s| Some(s.as_str())).collect();
2025        let array2: DictionaryArray<Int32Type> = values2.iter().map(|s| Some(s.as_str())).collect();
2026        let array3: DictionaryArray<Int32Type> = values3.iter().map(|s| Some(s.as_str())).collect();
2027
2028        let batch = RecordBatch::try_from_iter(vec![
2029            ("a1", Arc::new(array1) as _),
2030            ("a2", Arc::new(array2) as _),
2031            ("a3", Arc::new(array3) as _),
2032        ])
2033        .unwrap();
2034
2035        // overage is much higher than ideal
2036        // https://github.com/apache/arrow-rs/issues/3478
2037        verify_encoded_split(batch, 4136).await;
2038    }
2039
2040    /// Return size, in memory of flight data
2041    fn flight_data_size(d: &FlightData) -> usize {
2042        let flight_descriptor_size = d
2043            .flight_descriptor
2044            .as_ref()
2045            .map(|descriptor| {
2046                let path_len: usize = descriptor.path.iter().map(|p| p.len()).sum();
2047
2048                std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
2049            })
2050            .unwrap_or(0);
2051
2052        flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len()
2053    }
2054
    /// Coverage for <https://github.com/apache/arrow-rs/issues/3478>
    ///
    /// Encodes the specified batch using several values of
    /// `max_flight_data_size` between 1K and 5K and ensures that the
    /// resulting size of the flight data stays within the limit
    /// + `allowed_overage`
    ///
    /// `allowed_overage` is how far off the actual data encoding is
    /// from the target limit that was set. It is an improvement when
    /// the allowed_overage decreases.
    ///
    /// Note this overhead will likely always be greater than zero to
    /// account for encoding overhead such as IPC headers and padding.
    async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
        let num_rows = batch.num_rows();

        // Track the overall required maximum overage
        let mut max_overage_seen = 0;

        for max_flight_data_size in [1024, 2021, 5000] {
            println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");

            let mut stream = FlightDataEncoderBuilder::new()
                .with_max_flight_data_size(max_flight_data_size)
                // use 8-byte alignment - default alignment is 64 which produces bigger ipc data
                .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
                .build(futures::stream::iter([Ok(batch.clone())]));

            let mut i = 0;
            while let Some(data) = stream.next().await.transpose().unwrap() {
                let actual_data_size = flight_data_size(&data);

                let actual_overage = actual_data_size.saturating_sub(max_flight_data_size);

                assert!(
                    actual_overage <= allowed_overage,
                    "encoded data[{i}]: actual size {actual_data_size}, \
                         actual_overage: {actual_overage} \
                         allowed_overage: {allowed_overage}"
                );

                i += 1;

                max_overage_seen = max_overage_seen.max(actual_overage)
            }
        }

        // ensure that the specified overage is exactly the maximum so
        // that when the splitting logic improves, the tests must be
        // updated to reflect the better logic
        assert_eq!(
            allowed_overage, max_overage_seen,
            "Specified overage was too high"
        );
    }
2112}