// lance_arrow/lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Extend Arrow Functionality
5//!
6//! To improve Arrow-RS ergonomic
7
8use std::sync::Arc;
9use std::{collections::HashMap, ptr::NonNull};
10
11use arrow_array::{
12    cast::AsArray, Array, ArrayRef, ArrowNumericType, FixedSizeBinaryArray, FixedSizeListArray,
13    GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt32Array,
14    UInt8Array,
15};
16use arrow_array::{
17    new_null_array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
18};
19use arrow_buffer::MutableBuffer;
20use arrow_data::ArrayDataBuilder;
21use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema};
22use arrow_select::{interleave::interleave, take::take};
23use rand::prelude::*;
24
25pub mod deepcopy;
26pub mod schema;
27pub use schema::*;
28pub mod bfloat16;
29pub mod floats;
30pub use floats::*;
31pub mod cast;
32pub mod list;
33pub mod memory;
34
35type Result<T> = std::result::Result<T, ArrowError>;
36
pub trait DataTypeExt {
    /// Returns true if the data type is binary-like, such as (Large)Utf8 and (Large)Binary.
    ///
    /// ```
    /// use lance_arrow::*;
    /// use arrow_schema::DataType;
    ///
    /// assert!(DataType::Utf8.is_binary_like());
    /// assert!(DataType::Binary.is_binary_like());
    /// assert!(DataType::LargeUtf8.is_binary_like());
    /// assert!(DataType::LargeBinary.is_binary_like());
    /// assert!(!DataType::Int32.is_binary_like());
    /// ```
    fn is_binary_like(&self) -> bool;

    /// Returns true if the data type is a struct.
    fn is_struct(&self) -> bool;

    /// Check whether the given Arrow DataType is fixed stride.
    ///
    /// A fixed stride type has the same byte width for all array elements.
    /// This includes the primitive numeric types, Boolean, FixedSizeList,
    /// FixedSizeBinary, decimals, and the date/time/timestamp/duration types.
    fn is_fixed_stride(&self) -> bool;

    /// Returns true if the [DataType] is a dictionary type.
    fn is_dictionary(&self) -> bool;

    /// Returns the byte width of the data type.
    ///
    /// Panics if the data type is not fixed stride. Note that Boolean is
    /// bit-packed and has no per-element byte width, so it panics here even
    /// though [`Self::is_fixed_stride`] returns true for it.
    fn byte_width(&self) -> usize;

    /// Returns the byte width of the data type, if it is fixed stride.
    /// Returns None if the data type is not fixed stride.
    fn byte_width_opt(&self) -> Option<usize>;
}
72
73impl DataTypeExt for DataType {
74    fn is_binary_like(&self) -> bool {
75        use DataType::*;
76        matches!(self, Utf8 | Binary | LargeUtf8 | LargeBinary)
77    }
78
79    fn is_struct(&self) -> bool {
80        matches!(self, Self::Struct(_))
81    }
82
83    fn is_fixed_stride(&self) -> bool {
84        use DataType::*;
85        matches!(
86            self,
87            Boolean
88                | UInt8
89                | UInt16
90                | UInt32
91                | UInt64
92                | Int8
93                | Int16
94                | Int32
95                | Int64
96                | Float16
97                | Float32
98                | Float64
99                | Decimal128(_, _)
100                | Decimal256(_, _)
101                | FixedSizeList(_, _)
102                | FixedSizeBinary(_)
103                | Duration(_)
104                | Timestamp(_, _)
105                | Date32
106                | Date64
107                | Time32(_)
108                | Time64(_)
109        )
110    }
111
112    fn is_dictionary(&self) -> bool {
113        matches!(self, Self::Dictionary(_, _))
114    }
115
116    fn byte_width_opt(&self) -> Option<usize> {
117        match self {
118            Self::Int8 => Some(1),
119            Self::Int16 => Some(2),
120            Self::Int32 => Some(4),
121            Self::Int64 => Some(8),
122            Self::UInt8 => Some(1),
123            Self::UInt16 => Some(2),
124            Self::UInt32 => Some(4),
125            Self::UInt64 => Some(8),
126            Self::Float16 => Some(2),
127            Self::Float32 => Some(4),
128            Self::Float64 => Some(8),
129            Self::Date32 => Some(4),
130            Self::Date64 => Some(8),
131            Self::Time32(_) => Some(4),
132            Self::Time64(_) => Some(8),
133            Self::Timestamp(_, _) => Some(8),
134            Self::Duration(_) => Some(8),
135            Self::Decimal128(_, _) => Some(16),
136            Self::Decimal256(_, _) => Some(32),
137            Self::Interval(unit) => match unit {
138                IntervalUnit::YearMonth => Some(4),
139                IntervalUnit::DayTime => Some(8),
140                IntervalUnit::MonthDayNano => Some(16),
141            },
142            Self::FixedSizeBinary(s) => Some(*s as usize),
143            Self::FixedSizeList(dt, s) => dt
144                .data_type()
145                .byte_width_opt()
146                .map(|width| width * *s as usize),
147            _ => None,
148        }
149    }
150
151    fn byte_width(&self) -> usize {
152        self.byte_width_opt()
153            .unwrap_or_else(|| panic!("Expecting fixed stride data type, found {:?}", self))
154    }
155}
156
157/// Create an [`GenericListArray`] from values and offsets.
158///
159/// ```
160/// use arrow_array::{Int32Array, Int64Array, ListArray};
161/// use arrow_array::types::Int64Type;
162/// use lance_arrow::try_new_generic_list_array;
163///
164/// let offsets = Int32Array::from_iter([0, 2, 7, 10]);
165/// let int_values = Int64Array::from_iter(0..10);
166/// let list_arr = try_new_generic_list_array(int_values, &offsets).unwrap();
167/// assert_eq!(list_arr,
168///     ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
169///         Some(vec![Some(0), Some(1)]),
170///         Some(vec![Some(2), Some(3), Some(4), Some(5), Some(6)]),
171///         Some(vec![Some(7), Some(8), Some(9)]),
172/// ]))
173/// ```
174pub fn try_new_generic_list_array<T: Array, Offset: ArrowNumericType>(
175    values: T,
176    offsets: &PrimitiveArray<Offset>,
177) -> Result<GenericListArray<Offset::Native>>
178where
179    Offset::Native: OffsetSizeTrait,
180{
181    let data_type = if Offset::Native::IS_LARGE {
182        DataType::LargeList(Arc::new(Field::new(
183            "item",
184            values.data_type().clone(),
185            true,
186        )))
187    } else {
188        DataType::List(Arc::new(Field::new(
189            "item",
190            values.data_type().clone(),
191            true,
192        )))
193    };
194    let data = ArrayDataBuilder::new(data_type)
195        .len(offsets.len() - 1)
196        .add_buffer(offsets.into_data().buffers()[0].clone())
197        .add_child_data(values.into_data())
198        .build()?;
199
200    Ok(GenericListArray::from(data))
201}
202
203pub fn fixed_size_list_type(list_width: i32, inner_type: DataType) -> DataType {
204    DataType::FixedSizeList(Arc::new(Field::new("item", inner_type, true)), list_width)
205}
206
pub trait FixedSizeListArrayExt {
    /// Create an [`FixedSizeListArray`] from values and list size.
    ///
    /// ```
    /// use arrow_array::{Int64Array, FixedSizeListArray};
    /// use arrow_array::types::Int64Type;
    /// use lance_arrow::FixedSizeListArrayExt;
    ///
    /// let int_values = Int64Array::from_iter(0..10);
    /// let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(int_values, 2).unwrap();
    /// assert_eq!(fixed_size_list_arr,
    ///     FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(vec![
    ///         Some(vec![Some(0), Some(1)]),
    ///         Some(vec![Some(2), Some(3)]),
    ///         Some(vec![Some(4), Some(5)]),
    ///         Some(vec![Some(6), Some(7)]),
    ///         Some(vec![Some(8), Some(9)])
    /// ], 2))
    /// ```
    fn try_new_from_values<T: Array + 'static>(
        values: T,
        list_size: i32,
    ) -> Result<FixedSizeListArray>;

    /// Sample `n` rows from the [FixedSizeListArray]
    ///
    /// ```
    /// use arrow_array::{Int64Array, FixedSizeListArray, Array};
    /// use lance_arrow::FixedSizeListArrayExt;
    ///
    /// let int_values = Int64Array::from_iter(0..256);
    /// let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(int_values, 16).unwrap();
    /// let sampled = fixed_size_list_arr.sample(10).unwrap();
    /// assert_eq!(sampled.len(), 10);
    /// assert_eq!(sampled.value_length(), 16);
    /// assert_eq!(sampled.values().len(), 160);
    /// ```
    fn sample(&self, n: usize) -> Result<FixedSizeListArray>;

    /// Convert a [FixedSizeListArray] of integer values to one with the
    /// closest floating point value type: Int8/Int16/Int32 become Float32,
    /// and Int64/UInt8/UInt32 become Float64. Arrays that are already
    /// Float16/Float32/Float64 are returned unchanged; any other value type
    /// is an error.
    fn convert_to_floating_point(&self) -> Result<FixedSizeListArray>;
}
250
251impl FixedSizeListArrayExt for FixedSizeListArray {
252    fn try_new_from_values<T: Array + 'static>(values: T, list_size: i32) -> Result<Self> {
253        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
254        let values = Arc::new(values);
255
256        Self::try_new(field, list_size, values, None)
257    }
258
259    fn sample(&self, n: usize) -> Result<FixedSizeListArray> {
260        if n >= self.len() {
261            return Ok(self.clone());
262        }
263        let mut rng = SmallRng::from_entropy();
264        let chosen = (0..self.len() as u32).choose_multiple(&mut rng, n);
265        take(self, &UInt32Array::from(chosen), None).map(|arr| arr.as_fixed_size_list().clone())
266    }
267
268    fn convert_to_floating_point(&self) -> Result<FixedSizeListArray> {
269        match self.data_type() {
270            DataType::FixedSizeList(field, size) => match field.data_type() {
271                DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(self.clone()),
272                DataType::Int8 => Ok(Self::new(
273                    Arc::new(arrow_schema::Field::new(
274                        field.name(),
275                        DataType::Float32,
276                        field.is_nullable(),
277                    )),
278                    *size,
279                    Arc::new(Float32Array::from_iter_values(
280                        self.values()
281                            .as_any()
282                            .downcast_ref::<Int8Array>()
283                            .ok_or(ArrowError::ParseError(
284                                "Fail to cast primitive array to Int8Type".to_string(),
285                            ))?
286                            .into_iter()
287                            .filter_map(|x| x.map(|y| y as f32)),
288                    )),
289                    self.nulls().cloned(),
290                )),
291                DataType::Int16 => Ok(Self::new(
292                    Arc::new(arrow_schema::Field::new(
293                        field.name(),
294                        DataType::Float32,
295                        field.is_nullable(),
296                    )),
297                    *size,
298                    Arc::new(Float32Array::from_iter_values(
299                        self.values()
300                            .as_any()
301                            .downcast_ref::<Int16Array>()
302                            .ok_or(ArrowError::ParseError(
303                                "Fail to cast primitive array to Int8Type".to_string(),
304                            ))?
305                            .into_iter()
306                            .filter_map(|x| x.map(|y| y as f32)),
307                    )),
308                    self.nulls().cloned(),
309                )),
310                DataType::Int32 => Ok(Self::new(
311                    Arc::new(arrow_schema::Field::new(
312                        field.name(),
313                        DataType::Float32,
314                        field.is_nullable(),
315                    )),
316                    *size,
317                    Arc::new(Float32Array::from_iter_values(
318                        self.values()
319                            .as_any()
320                            .downcast_ref::<Int32Array>()
321                            .ok_or(ArrowError::ParseError(
322                                "Fail to cast primitive array to Int8Type".to_string(),
323                            ))?
324                            .into_iter()
325                            .filter_map(|x| x.map(|y| y as f32)),
326                    )),
327                    self.nulls().cloned(),
328                )),
329                DataType::Int64 => Ok(Self::new(
330                    Arc::new(arrow_schema::Field::new(
331                        field.name(),
332                        DataType::Float64,
333                        field.is_nullable(),
334                    )),
335                    *size,
336                    Arc::new(Float64Array::from_iter_values(
337                        self.values()
338                            .as_any()
339                            .downcast_ref::<Int64Array>()
340                            .ok_or(ArrowError::ParseError(
341                                "Fail to cast primitive array to Int8Type".to_string(),
342                            ))?
343                            .into_iter()
344                            .filter_map(|x| x.map(|y| y as f64)),
345                    )),
346                    self.nulls().cloned(),
347                )),
348                DataType::UInt8 => Ok(Self::new(
349                    Arc::new(arrow_schema::Field::new(
350                        field.name(),
351                        DataType::Float64,
352                        field.is_nullable(),
353                    )),
354                    *size,
355                    Arc::new(Float64Array::from_iter_values(
356                        self.values()
357                            .as_any()
358                            .downcast_ref::<UInt8Array>()
359                            .ok_or(ArrowError::ParseError(
360                                "Fail to cast primitive array to Int8Type".to_string(),
361                            ))?
362                            .into_iter()
363                            .filter_map(|x| x.map(|y| y as f64)),
364                    )),
365                    self.nulls().cloned(),
366                )),
367                DataType::UInt32 => Ok(Self::new(
368                    Arc::new(arrow_schema::Field::new(
369                        field.name(),
370                        DataType::Float64,
371                        field.is_nullable(),
372                    )),
373                    *size,
374                    Arc::new(Float64Array::from_iter_values(
375                        self.values()
376                            .as_any()
377                            .downcast_ref::<UInt32Array>()
378                            .ok_or(ArrowError::ParseError(
379                                "Fail to cast primitive array to Int8Type".to_string(),
380                            ))?
381                            .into_iter()
382                            .filter_map(|x| x.map(|y| y as f64)),
383                    )),
384                    self.nulls().cloned(),
385                )),
386                data_type => Err(ArrowError::ParseError(format!(
387                    "Expect either floating type or integer got {:?}",
388                    data_type
389                ))),
390            },
391            data_type => Err(ArrowError::ParseError(format!(
392                "Expect either FixedSizeList got {:?}",
393                data_type
394            ))),
395        }
396    }
397}
398
399/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
400/// [`FixedSizeListArray`], panic'ing on failure.
401pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
402    arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
403}
404
pub trait FixedSizeBinaryArrayExt {
    /// Create an [`FixedSizeBinaryArray`] from a flat byte array and a stride
    /// (the fixed length, in bytes, of each resulting value).
    ///
    /// ```
    /// use arrow_array::{UInt8Array, FixedSizeBinaryArray};
    /// use arrow_array::types::UInt8Type;
    /// use lance_arrow::FixedSizeBinaryArrayExt;
    ///
    /// let int_values = UInt8Array::from_iter(0..10);
    /// let fixed_size_list_arr = FixedSizeBinaryArray::try_new_from_values(&int_values, 2).unwrap();
    /// assert_eq!(fixed_size_list_arr,
    ///     FixedSizeBinaryArray::from(vec![
    ///         Some(vec![0, 1].as_slice()),
    ///         Some(vec![2, 3].as_slice()),
    ///         Some(vec![4, 5].as_slice()),
    ///         Some(vec![6, 7].as_slice()),
    ///         Some(vec![8, 9].as_slice())
    /// ]))
    /// ```
    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<FixedSizeBinaryArray>;
}
426
427impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray {
428    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<Self> {
429        let data_type = DataType::FixedSizeBinary(stride);
430        let data = ArrayDataBuilder::new(data_type)
431            .len(values.len() / stride as usize)
432            .add_buffer(values.into_data().buffers()[0].clone())
433            .build()?;
434        Ok(Self::from(data))
435    }
436}
437
438pub fn as_fixed_size_binary_array(arr: &dyn Array) -> &FixedSizeBinaryArray {
439    arr.as_any().downcast_ref::<FixedSizeBinaryArray>().unwrap()
440}
441
442pub fn iter_str_array(arr: &dyn Array) -> Box<dyn Iterator<Item = Option<&str>> + Send + '_> {
443    match arr.data_type() {
444        DataType::Utf8 => Box::new(arr.as_string::<i32>().iter()),
445        DataType::LargeUtf8 => Box::new(arr.as_string::<i64>().iter()),
446        _ => panic!("Expecting Utf8 or LargeUtf8, found {:?}", arr.data_type()),
447    }
448}
449
/// Extends Arrow's [RecordBatch].
pub trait RecordBatchExt {
    /// Append a new column to this [`RecordBatch`] and return a new RecordBatch.
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{RecordBatch, Int32Array, StringArray};
    /// use arrow_schema::{Schema, Field, DataType};
    /// use lance_arrow::*;
    ///
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    /// let int_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    /// let record_batch = RecordBatch::try_new(schema, vec![int_arr.clone()]).unwrap();
    ///
    /// let new_field = Field::new("s", DataType::Utf8, true);
    /// let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
    /// let new_record_batch = record_batch.try_with_column(new_field, str_arr.clone()).unwrap();
    ///
    /// assert_eq!(
    ///     new_record_batch,
    ///     RecordBatch::try_new(
    ///         Arc::new(Schema::new(
    ///             vec![
    ///                 Field::new("a", DataType::Int32, true),
    ///                 Field::new("s", DataType::Utf8, true)
    ///             ])
    ///         ),
    ///         vec![int_arr, str_arr],
    ///     ).unwrap()
    /// )
    /// ```
    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

    /// Create a new [`RecordBatch`] with the column inserted at `index`.
    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

    /// Creates a new [`RecordBatch`] from the provided [`StructArray`].
    ///
    /// The fields on the [`StructArray`] need to match this [`RecordBatch`] schema.
    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<RecordBatch>;

    /// Merge with another [`RecordBatch`] and returns a new one.
    ///
    /// Fields are merged based on name.  First we iterate the left columns.  If a matching
    /// name is found in the right then we merge the two columns.  If there is no match then
    /// we add the left column to the output.
    ///
    /// To merge two columns we consider the type.  If both arrays are struct arrays we recurse.
    /// Otherwise we use the left array.
    ///
    /// Afterwards we add all non-matching right columns to the output.
    ///
    /// Note: This method likely does not handle nested fields correctly and you may want to consider
    /// using [`merge_with_schema`] instead.
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::*;
    /// use arrow_schema::{Schema, Field, DataType};
    /// use lance_arrow::*;
    ///
    /// let left_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    /// let int_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    /// let left = RecordBatch::try_new(left_schema, vec![int_arr.clone()]).unwrap();
    ///
    /// let right_schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
    /// let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
    /// let right = RecordBatch::try_new(right_schema, vec![str_arr.clone()]).unwrap();
    ///
    /// let new_record_batch = left.merge(&right).unwrap();
    ///
    /// assert_eq!(
    ///     new_record_batch,
    ///     RecordBatch::try_new(
    ///         Arc::new(Schema::new(
    ///             vec![
    ///                 Field::new("a", DataType::Int32, true),
    ///                 Field::new("s", DataType::Utf8, true)
    ///             ])
    ///         ),
    ///         vec![int_arr, str_arr],
    ///     ).unwrap()
    /// )
    /// ```
    ///
    /// TODO: add merge nested fields support.
    fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;

    /// Create a batch by merging columns between two batches with a given schema.
    ///
    /// A reference schema is used to determine the proper ordering of nested fields.
    ///
    /// For each field in the reference schema we look for corresponding fields in
    /// the left and right batches.  If a field is found in both batches we recursively merge
    /// it.
    ///
    /// If a field is only in the left or right batch we take it as it is.
    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch>;

    /// Drop one column specified with the name and return the new [`RecordBatch`].
    ///
    /// If the named column does not exist, it returns a copy of this [`RecordBatch`].
    fn drop_column(&self, name: &str) -> Result<RecordBatch>;

    /// Replace a column (specified by name) and return the new [`RecordBatch`].
    ///
    /// The schema is left untouched, so the new array must have the same type
    /// as the column it replaces.
    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch>;

    /// Replace a column and its schema data type (specified by name) and
    /// return the new [`RecordBatch`].
    fn replace_column_schema_by_name(
        &self,
        name: &str,
        new_data_type: DataType,
        column: Arc<dyn Array>,
    ) -> Result<RecordBatch>;

    /// Get (potentially nested) column by qualified name, e.g. `"a.b.c"`.
    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef>;

    /// Project the schema over the [RecordBatch].
    fn project_by_schema(&self, schema: &Schema) -> Result<RecordBatch>;

    /// Metadata of the schema.
    fn metadata(&self) -> &HashMap<String, String>;

    /// Add a single key/value pair to the schema metadata, returning a new
    /// batch (the existing metadata is copied and extended).
    fn add_metadata(&self, key: String, value: String) -> Result<RecordBatch> {
        let mut metadata = self.metadata().clone();
        metadata.insert(key, value);
        self.with_metadata(metadata)
    }

    /// Replace the schema metadata with the provided one.
    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch>;

    /// Take selected rows from the [RecordBatch].
    fn take(&self, indices: &UInt32Array) -> Result<RecordBatch>;
}
586
impl RecordBatchExt for RecordBatch {
    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<Self> {
        // Extend the schema (delegated to `schema::try_with_column`), then
        // append the array at the end of the column list.
        let new_schema = Arc::new(self.schema().as_ref().try_with_column(field)?);
        let mut new_columns = self.columns().to_vec();
        new_columns.push(arr);
        Self::try_new(new_schema, new_columns)
    }

    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<Self> {
        // Same as `try_with_column` but the field/column are inserted at
        // `index` instead of appended.
        let new_schema = Arc::new(self.schema().as_ref().try_with_column_at(index, field)?);
        let mut new_columns = self.columns().to_vec();
        new_columns.insert(index, arr);
        Self::try_new(new_schema, new_columns)
    }

    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<Self> {
        // The fields come from the struct array; this batch only contributes
        // its schema metadata.
        let schema = Arc::new(Schema::new_with_metadata(
            arr.fields().to_vec(),
            self.schema().metadata.clone(),
        ));
        let batch = Self::from(arr);
        batch.with_schema(schema)
    }

    fn merge(&self, other: &Self) -> Result<Self> {
        // Column-wise merge requires both batches to have the same row count.
        if self.num_rows() != other.num_rows() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Attempt to merge two RecordBatch with different sizes: {} != {}",
                self.num_rows(),
                other.num_rows()
            )));
        }
        // Delegate to the struct-array `merge` helper defined later in this
        // file, treating each batch as a single StructArray.
        let left_struct_array: StructArray = self.clone().into();
        let right_struct_array: StructArray = other.clone().into();
        self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array))
    }

    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch> {
        if self.num_rows() != other.num_rows() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Attempt to merge two RecordBatch with different sizes: {} != {}",
                self.num_rows(),
                other.num_rows()
            )));
        }
        // Like `merge`, but the reference schema's field order drives the
        // output layout (see `merge_with_schema` helper later in this file).
        let left_struct_array: StructArray = self.clone().into();
        let right_struct_array: StructArray = other.clone().into();
        self.try_new_from_struct_array(merge_with_schema(
            &left_struct_array,
            &right_struct_array,
            schema.fields(),
        ))
    }

    fn drop_column(&self, name: &str) -> Result<Self> {
        // Rebuild field/column lists, skipping every top-level field whose
        // name matches. If nothing matches this is effectively a copy.
        let mut fields = vec![];
        let mut columns = vec![];
        for i in 0..self.schema().fields.len() {
            if self.schema().field(i).name() != name {
                fields.push(self.schema().field(i).clone());
                columns.push(self.column(i).clone());
            }
        }
        Self::try_new(
            Arc::new(Schema::new_with_metadata(
                fields,
                self.schema().metadata().clone(),
            )),
            columns,
        )
    }

    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch> {
        // Swap the array in place; the schema is reused unchanged, so
        // `try_new` will reject a column whose type does not match.
        let mut columns = self.columns().to_vec();
        let field_i = self
            .schema()
            .fields()
            .iter()
            .position(|f| f.name() == name)
            .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
        columns[field_i] = column;
        Self::try_new(self.schema(), columns)
    }

    fn replace_column_schema_by_name(
        &self,
        name: &str,
        new_data_type: DataType,
        column: Arc<dyn Array>,
    ) -> Result<RecordBatch> {
        // Rebuild the schema with the named field's data type replaced.
        // NOTE(review): the replacement field is built with only name, type
        // and nullability, so any field-level metadata on the original field
        // is not carried over — confirm that is acceptable to callers.
        let fields = self
            .schema()
            .fields()
            .iter()
            .map(|x| {
                if x.name() != name {
                    x.clone()
                } else {
                    let new_field = Field::new(name, new_data_type.clone(), x.is_nullable());
                    Arc::new(new_field)
                }
            })
            .collect::<Vec<_>>();
        let schema = Schema::new_with_metadata(fields, self.schema().metadata.clone());
        // Locate the column to replace; errors if the name is absent (the
        // schema rebuild above would have silently left it unchanged).
        let mut columns = self.columns().to_vec();
        let field_i = self
            .schema()
            .fields()
            .iter()
            .position(|f| f.name() == name)
            .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
        columns[field_i] = column;
        Self::try_new(Arc::new(schema), columns)
    }

    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef> {
        // "a.b.c" -> top-level column "a", then descend via `get_sub_array`.
        // NOTE(review): `str::split` always yields at least one item, so the
        // `is_empty` guard below can never trigger — appears to be dead code.
        let split = name.split('.').collect::<Vec<_>>();
        if split.is_empty() {
            return None;
        }

        self.column_by_name(split[0])
            .and_then(|arr| get_sub_array(arr, &split[1..]))
    }

    fn project_by_schema(&self, schema: &Schema) -> Result<Self> {
        // Treat the batch as a StructArray and reuse the recursive `project`
        // helper defined below.
        let struct_array: StructArray = self.clone().into();
        self.try_new_from_struct_array(project(&struct_array, schema.fields())?)
    }

    fn metadata(&self) -> &HashMap<String, String> {
        self.schema_ref().metadata()
    }

    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch> {
        // Clone the schema, overwrite its metadata, and rebuild the batch on
        // top of the same columns.
        let mut schema = self.schema_ref().as_ref().clone();
        schema.metadata = metadata;
        Self::try_new(schema.into(), self.columns().into())
    }

    fn take(&self, indices: &UInt32Array) -> Result<Self> {
        // Row gather via arrow's `take` kernel on the batch-as-struct view.
        let struct_array: StructArray = self.clone().into();
        let taken = take(&struct_array, indices, None)?;
        self.try_new_from_struct_array(taken.as_struct().clone())
    }
}
733
734fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
735    if fields.is_empty() {
736        return Ok(StructArray::new_empty_fields(
737            struct_array.len(),
738            struct_array.nulls().cloned(),
739        ));
740    }
741    let mut columns: Vec<ArrayRef> = vec![];
742    for field in fields.iter() {
743        if let Some(col) = struct_array.column_by_name(field.name()) {
744            match field.data_type() {
745                // TODO handle list-of-struct
746                DataType::Struct(subfields) => {
747                    let projected = project(col.as_struct(), subfields)?;
748                    columns.push(Arc::new(projected));
749                }
750                _ => {
751                    columns.push(col.clone());
752                }
753            }
754        } else {
755            return Err(ArrowError::SchemaError(format!(
756                "field {} does not exist in the RecordBatch",
757                field.name()
758            )));
759        }
760    }
761    StructArray::try_new(fields.clone(), columns, None)
762}
763
764fn lists_have_same_offsets_helper<T: OffsetSizeTrait>(left: &dyn Array, right: &dyn Array) -> bool {
765    let left_list: &GenericListArray<T> = left.as_list();
766    let right_list: &GenericListArray<T> = right.as_list();
767    left_list.offsets().inner() == right_list.offsets().inner()
768}
769
770fn merge_list_structs_helper<T: OffsetSizeTrait>(
771    left: &dyn Array,
772    right: &dyn Array,
773    items_field_name: impl Into<String>,
774    items_nullable: bool,
775) -> Arc<dyn Array> {
776    let left_list: &GenericListArray<T> = left.as_list();
777    let right_list: &GenericListArray<T> = right.as_list();
778    let left_struct = left_list.values();
779    let right_struct = right_list.values();
780    let left_struct_arr = left_struct.as_struct();
781    let right_struct_arr = right_struct.as_struct();
782    let merged_items = Arc::new(merge(left_struct_arr, right_struct_arr));
783    let items_field = Arc::new(Field::new(
784        items_field_name,
785        merged_items.data_type().clone(),
786        items_nullable,
787    ));
788    Arc::new(GenericListArray::<T>::new(
789        items_field,
790        left_list.offsets().clone(),
791        merged_items,
792        left_list.nulls().cloned(),
793    ))
794}
795
796fn merge_list_struct_null_helper<T: OffsetSizeTrait>(
797    left: &dyn Array,
798    right: &dyn Array,
799    not_null: &dyn Array,
800    items_field_name: impl Into<String>,
801) -> Arc<dyn Array> {
802    let left_list: &GenericListArray<T> = left.as_list::<T>();
803    let not_null_list = not_null.as_list::<T>();
804    let right_list = right.as_list::<T>();
805
806    let left_struct = left_list.values().as_struct();
807    let not_null_struct: &StructArray = not_null_list.values().as_struct();
808    let right_struct = right_list.values().as_struct();
809
810    let values_len = not_null_list.values().len();
811    let mut merged_fields =
812        Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());
813    let mut merged_columns =
814        Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());
815
816    for (_, field) in left_struct.columns().iter().zip(left_struct.fields()) {
817        merged_fields.push(field.clone());
818        if let Some(val) = not_null_struct.column_by_name(field.name()) {
819            merged_columns.push(val.clone());
820        } else {
821            merged_columns.push(new_null_array(field.data_type(), values_len))
822        }
823    }
824    for (_, field) in right_struct
825        .columns()
826        .iter()
827        .zip(right_struct.fields())
828        .filter(|(_, field)| left_struct.column_by_name(field.name()).is_none())
829    {
830        merged_fields.push(field.clone());
831        if let Some(val) = not_null_struct.column_by_name(field.name()) {
832            merged_columns.push(val.clone());
833        } else {
834            merged_columns.push(new_null_array(field.data_type(), values_len));
835        }
836    }
837
838    let merged_struct = Arc::new(StructArray::new(
839        Fields::from(merged_fields),
840        merged_columns,
841        not_null_struct.nulls().cloned(),
842    ));
843    let items_field = Arc::new(Field::new(
844        items_field_name,
845        merged_struct.data_type().clone(),
846        true,
847    ));
848    Arc::new(GenericListArray::<T>::new(
849        items_field,
850        not_null_list.offsets().clone(),
851        merged_struct,
852        not_null_list.nulls().cloned(),
853    ))
854}
855
856fn merge_list_struct_null(
857    left: &dyn Array,
858    right: &dyn Array,
859    not_null: &dyn Array,
860) -> Arc<dyn Array> {
861    match left.data_type() {
862        DataType::List(left_field) => {
863            merge_list_struct_null_helper::<i32>(left, right, not_null, left_field.name())
864        }
865        DataType::LargeList(left_field) => {
866            merge_list_struct_null_helper::<i64>(left, right, not_null, left_field.name())
867        }
868        _ => unreachable!(),
869    }
870}
871
872fn merge_list_struct(left: &dyn Array, right: &dyn Array) -> Arc<dyn Array> {
873    // Merging fields into a list<struct<...>> is tricky and can only succeed
874    // in two ways.  First, if both lists have the same offsets.  Second, if
875    // one of the lists is all-null
876    if left.null_count() == left.len() {
877        return merge_list_struct_null(left, right, right);
878    } else if right.null_count() == right.len() {
879        return merge_list_struct_null(left, right, left);
880    }
881    match (left.data_type(), right.data_type()) {
882        (DataType::List(left_field), DataType::List(_)) => {
883            if !lists_have_same_offsets_helper::<i32>(left, right) {
884                panic!("Attempt to merge list struct arrays which do not have same offsets");
885            }
886            merge_list_structs_helper::<i32>(
887                left,
888                right,
889                left_field.name(),
890                left_field.is_nullable(),
891            )
892        }
893        (DataType::LargeList(left_field), DataType::LargeList(_)) => {
894            if !lists_have_same_offsets_helper::<i64>(left, right) {
895                panic!("Attempt to merge list struct arrays which do not have same offsets");
896            }
897            merge_list_structs_helper::<i64>(
898                left,
899                right,
900                left_field.name(),
901                left_field.is_nullable(),
902            )
903        }
904        _ => unreachable!(),
905    }
906}
907
908fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
909    let mut fields: Vec<Field> = vec![];
910    let mut columns: Vec<ArrayRef> = vec![];
911    let right_fields = right_struct_array.fields();
912    let right_columns = right_struct_array.columns();
913
914    // iterate through the fields on the left hand side
915    for (left_field, left_column) in left_struct_array
916        .fields()
917        .iter()
918        .zip(left_struct_array.columns().iter())
919    {
920        match right_fields
921            .iter()
922            .position(|f| f.name() == left_field.name())
923        {
924            // if the field exists on the right hand side, merge them recursively if appropriate
925            Some(right_index) => {
926                let right_field = right_fields.get(right_index).unwrap();
927                let right_column = right_columns.get(right_index).unwrap();
928                // if both fields are struct, merge them recursively
929                match (left_field.data_type(), right_field.data_type()) {
930                    (DataType::Struct(_), DataType::Struct(_)) => {
931                        let left_sub_array = left_column.as_struct();
932                        let right_sub_array = right_column.as_struct();
933                        let merged_sub_array = merge(left_sub_array, right_sub_array);
934                        fields.push(Field::new(
935                            left_field.name(),
936                            merged_sub_array.data_type().clone(),
937                            left_field.is_nullable(),
938                        ));
939                        columns.push(Arc::new(merged_sub_array) as ArrayRef);
940                    }
941                    (DataType::List(left_list), DataType::List(right_list))
942                        if left_list.data_type().is_struct()
943                            && right_list.data_type().is_struct() =>
944                    {
945                        // If there is nothing to merge just use the left field
946                        if left_list.data_type() == right_list.data_type() {
947                            fields.push(left_field.as_ref().clone());
948                            columns.push(left_column.clone());
949                        }
950                        // If we have two List<Struct> and they have different sets of fields then
951                        // we can merge them if the offsets arrays are the same.  Otherwise, we
952                        // have to consider it an error.
953                        let merged_sub_array = merge_list_struct(&left_column, &right_column);
954
955                        fields.push(Field::new(
956                            left_field.name(),
957                            merged_sub_array.data_type().clone(),
958                            left_field.is_nullable(),
959                        ));
960                        columns.push(merged_sub_array);
961                    }
962                    // otherwise, just use the field on the left hand side
963                    _ => {
964                        // TODO handle list-of-struct and other types
965                        fields.push(left_field.as_ref().clone());
966                        columns.push(left_column.clone());
967                    }
968                }
969            }
970            None => {
971                fields.push(left_field.as_ref().clone());
972                columns.push(left_column.clone());
973            }
974        }
975    }
976
977    // now iterate through the fields on the right hand side
978    right_fields
979        .iter()
980        .zip(right_columns.iter())
981        .for_each(|(field, column)| {
982            // add new columns on the right
983            if !left_struct_array
984                .fields()
985                .iter()
986                .any(|f| f.name() == field.name())
987            {
988                fields.push(field.as_ref().clone());
989                columns.push(column.clone() as ArrayRef);
990            }
991        });
992
993    let zipped: Vec<(FieldRef, ArrayRef)> = fields
994        .iter()
995        .cloned()
996        .map(Arc::new)
997        .zip(columns.iter().cloned())
998        .collect::<Vec<_>>();
999    StructArray::from(zipped)
1000}
1001
1002fn merge_with_schema(
1003    left_struct_array: &StructArray,
1004    right_struct_array: &StructArray,
1005    fields: &Fields,
1006) -> StructArray {
1007    // Helper function that returns true if both types are struct or both are non-struct
1008    fn same_type_kind(left: &DataType, right: &DataType) -> bool {
1009        match (left, right) {
1010            (DataType::Struct(_), DataType::Struct(_)) => true,
1011            (DataType::Struct(_), _) => false,
1012            (_, DataType::Struct(_)) => false,
1013            _ => true,
1014        }
1015    }
1016
1017    let mut output_fields: Vec<Field> = Vec::with_capacity(fields.len());
1018    let mut columns: Vec<ArrayRef> = Vec::with_capacity(fields.len());
1019
1020    let left_fields = left_struct_array.fields();
1021    let left_columns = left_struct_array.columns();
1022    let right_fields = right_struct_array.fields();
1023    let right_columns = right_struct_array.columns();
1024
1025    for field in fields {
1026        let left_match_idx = left_fields.iter().position(|f| {
1027            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
1028        });
1029        let right_match_idx = right_fields.iter().position(|f| {
1030            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
1031        });
1032
1033        match (left_match_idx, right_match_idx) {
1034            (None, Some(right_idx)) => {
1035                output_fields.push(right_fields[right_idx].as_ref().clone());
1036                columns.push(right_columns[right_idx].clone());
1037            }
1038            (Some(left_idx), None) => {
1039                output_fields.push(left_fields[left_idx].as_ref().clone());
1040                columns.push(left_columns[left_idx].clone());
1041            }
1042            (Some(left_idx), Some(right_idx)) => {
1043                if let DataType::Struct(child_fields) = field.data_type() {
1044                    let left_sub_array = left_columns[left_idx].as_struct();
1045                    let right_sub_array = right_columns[right_idx].as_struct();
1046                    let merged_sub_array =
1047                        merge_with_schema(left_sub_array, right_sub_array, child_fields);
1048                    output_fields.push(Field::new(
1049                        field.name(),
1050                        merged_sub_array.data_type().clone(),
1051                        field.is_nullable(),
1052                    ));
1053                    columns.push(Arc::new(merged_sub_array) as ArrayRef);
1054                } else {
1055                    output_fields.push(left_fields[left_idx].as_ref().clone());
1056                    columns.push(left_columns[left_idx].clone());
1057                }
1058            }
1059            (None, None) => {
1060                // The field will not be included in the output
1061            }
1062        }
1063    }
1064
1065    let zipped: Vec<(FieldRef, ArrayRef)> = output_fields
1066        .into_iter()
1067        .map(Arc::new)
1068        .zip(columns)
1069        .collect::<Vec<_>>();
1070    StructArray::from(zipped)
1071}
1072
1073fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> {
1074    if components.is_empty() {
1075        return Some(array);
1076    }
1077    if !matches!(array.data_type(), DataType::Struct(_)) {
1078        return None;
1079    }
1080    let struct_arr = array.as_struct();
1081    struct_arr
1082        .column_by_name(components[0])
1083        .and_then(|arr| get_sub_array(arr, &components[1..]))
1084}
1085
1086/// Interleave multiple RecordBatches into a single RecordBatch.
1087///
1088/// Behaves like [`arrow::compute::interleave`], but for RecordBatches.
1089pub fn interleave_batches(
1090    batches: &[RecordBatch],
1091    indices: &[(usize, usize)],
1092) -> Result<RecordBatch> {
1093    let first_batch = batches.first().ok_or_else(|| {
1094        ArrowError::InvalidArgumentError("Cannot interleave zero RecordBatches".to_string())
1095    })?;
1096    let schema = first_batch.schema();
1097    let num_columns = first_batch.num_columns();
1098    let mut columns = Vec::with_capacity(num_columns);
1099    let mut chunks = Vec::with_capacity(batches.len());
1100
1101    for i in 0..num_columns {
1102        for batch in batches {
1103            chunks.push(batch.column(i).as_ref());
1104        }
1105        let new_column = interleave(&chunks, indices)?;
1106        columns.push(new_column);
1107        chunks.clear();
1108    }
1109
1110    RecordBatch::try_new(schema, columns)
1111}
1112
pub trait BufferExt {
    /// Create an `arrow_buffer::Buffer` from a `bytes::Bytes` object
    ///
    /// The alignment must be specified (as `bytes_per_value`) since we want to make
    /// sure we can safely reinterpret the buffer.
    ///
    /// If the buffer is properly aligned this will be zero-copy.  If not, a copy
    /// will be made and an owned buffer returned.
    ///
    /// If `bytes_per_value` is not a power of two, then we assume the buffer is
    /// never going to be reinterpreted into another type and we can safely
    /// ignore the alignment.
    ///
    /// Yes, the method name is odd.  It's because there is already a `from_bytes`
    /// which converts from `arrow_buffer::bytes::Bytes` (not `bytes::Bytes`)
    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self;

    /// Allocates a new properly aligned arrow buffer and copies `bytes` into it
    ///
    /// `size_bytes` can be larger than `bytes` and, if so, the trailing bytes will
    /// be zeroed out.
    ///
    /// # Panics
    ///
    /// Panics if `size_bytes` is less than `bytes.len()`
    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self;
}
1140
/// Returns true if `n` is a power of two.
///
/// Delegates to [`u64::is_power_of_two`] instead of the bit trick
/// `n & (n - 1) == 0`, which underflows for `n == 0` (panicking in debug
/// builds and wrongly returning `true` in release builds).  Zero is
/// correctly reported as *not* a power of two.
fn is_pwr_two(n: u64) -> bool {
    n.is_power_of_two()
}
1144
impl BufferExt for arrow_buffer::Buffer {
    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self {
        // Only power-of-two widths can ever be reinterpreted as a wider
        // type, so only then does misalignment force a copy; other widths
        // fall through to the zero-copy branch regardless of alignment.
        if is_pwr_two(bytes_per_value) && bytes.as_ptr().align_offset(bytes_per_value as usize) != 0
        {
            // The original buffer is not aligned, cannot zero-copy
            let size_bytes = bytes.len();
            Self::copy_bytes_bytes(bytes, size_bytes)
        } else {
            // The original buffer is aligned, can zero-copy
            // SAFETY: the alignment is correct we can make this conversion
            // The `Arc<bytes::Bytes>` passed as the custom allocation keeps
            // the referenced memory alive as long as the Buffer lives; the
            // raw pointer taken here stays valid because moving `bytes` into
            // the Arc does not move the heap data it points at.
            unsafe {
                Self::from_custom_allocation(
                    NonNull::new(bytes.as_ptr() as _).expect("should be a valid pointer"),
                    bytes.len(),
                    Arc::new(bytes),
                )
            }
        }
    }

    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self {
        // Documented contract: panics when the target size cannot hold the data.
        assert!(size_bytes >= bytes.len());
        let mut buf = MutableBuffer::with_capacity(size_bytes);
        let to_fill = size_bytes - bytes.len();
        buf.extend(bytes);
        // Zero-fill the tail so the buffer is exactly `size_bytes` long.
        buf.extend(std::iter::repeat_n(0_u8, to_fill));
        Self::from(buf)
    }
}
1174
1175#[cfg(test)]
1176mod tests {
1177    use super::*;
1178    use arrow_array::{new_empty_array, new_null_array, Int32Array, ListArray, StringArray};
1179    use arrow_buffer::OffsetBuffer;
1180
    #[test]
    fn test_merge_recursive() {
        // `merge` should union top-level columns and recurse into the shared
        // struct column "b", combining its children "c" and "d".
        let a_array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        let e_array = Int32Array::from(vec![Some(4), Some(5), Some(6)]);
        let c_array = Int32Array::from(vec![Some(7), Some(8), Some(9)]);
        let d_array = StringArray::from(vec![Some("a"), Some("b"), Some("c")]);

        // Left batch: {a: int32, b: {c: int32}}
        let left_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("c", DataType::Int32, true)].into()),
                true,
            ),
        ]);
        let left_batch = RecordBatch::try_new(
            Arc::new(left_schema),
            vec![
                Arc::new(a_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("c", DataType::Int32, true)),
                    Arc::new(c_array.clone()) as ArrayRef,
                )])),
            ],
        )
        .unwrap();

        // Right batch: {e: int32, b: {d: utf8}}
        let right_schema = Schema::new(vec![
            Field::new("e", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("d", DataType::Utf8, true)].into()),
                true,
            ),
        ]);
        let right_batch = RecordBatch::try_new(
            Arc::new(right_schema),
            vec![
                Arc::new(e_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("d", DataType::Utf8, true)),
                    Arc::new(d_array.clone()) as ArrayRef,
                )])) as ArrayRef,
            ],
        )
        .unwrap();

        // Expected: left columns first (with "b" children unioned), then "e".
        let merged_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(
                    vec![
                        Field::new("c", DataType::Int32, true),
                        Field::new("d", DataType::Utf8, true),
                    ]
                    .into(),
                ),
                true,
            ),
            Field::new("e", DataType::Int32, true),
        ]);
        let merged_batch = RecordBatch::try_new(
            Arc::new(merged_schema),
            vec![
                Arc::new(a_array) as ArrayRef,
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("c", DataType::Int32, true)),
                        Arc::new(c_array) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("d", DataType::Utf8, true)),
                        Arc::new(d_array) as ArrayRef,
                    ),
                ])) as ArrayRef,
                Arc::new(e_array) as ArrayRef,
            ],
        )
        .unwrap();

        let result = left_batch.merge(&right_batch).unwrap();
        assert_eq!(result, merged_batch);
    }
1265
    #[test]
    fn test_merge_with_schema() {
        // Compares `merge_with_schema` (output order dictated by the given
        // schema) against plain `merge` (left-hand order wins).

        // Builds a one-column batch {"struct": {names[i]: types[i], ...}}
        // whose children are empty arrays.
        fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) {
            let fields: Fields = names
                .iter()
                .zip(types)
                .map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false))
                .collect();
            let schema = Schema::new(vec![Field::new(
                "struct",
                DataType::Struct(fields.clone()),
                false,
            )]);
            let children = types.iter().map(new_empty_array).collect::<Vec<_>>();
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef],
            );
            (schema, batch.unwrap())
        }

        let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]);
        let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]);
        let (output_schema, _) = test_batch(
            &["b", "a", "c"],
            &[DataType::Int64, DataType::Int32, DataType::Int32],
        );

        // If we use merge_with_schema the schema is respected
        let merged = left_batch
            .merge_with_schema(&right_batch, &output_schema)
            .unwrap();
        assert_eq!(merged.schema().as_ref(), &output_schema);

        // If we use merge we get first-come first-serve based on the left batch
        let (naive_schema, _) = test_batch(
            &["a", "b", "c"],
            &[DataType::Int32, DataType::Int64, DataType::Int32],
        );
        let merged = left_batch.merge(&right_batch).unwrap();
        assert_eq!(merged.schema().as_ref(), &naive_schema);
    }
1308
    #[test]
    fn test_merge_list_struct() {
        // Two List<Struct> columns with identical offsets merge into a list
        // of the unioned struct; merging with an all-null right column
        // null-fills the right-hand child instead.
        let x_field = Arc::new(Field::new("x", DataType::Int32, true));
        let y_field = Arc::new(Field::new("y", DataType::Int32, true));
        let x_struct_field = Arc::new(Field::new(
            "item",
            DataType::Struct(Fields::from(vec![x_field.clone()])),
            true,
        ));
        let y_struct_field = Arc::new(Field::new(
            "item",
            DataType::Struct(Fields::from(vec![y_field.clone()])),
            true,
        ));
        let both_struct_field = Arc::new(Field::new(
            "item",
            DataType::Struct(Fields::from(vec![x_field.clone(), y_field.clone()])),
            true,
        ));
        let left_schema = Schema::new(vec![Field::new(
            "list_struct",
            DataType::List(x_struct_field.clone()),
            true,
        )]);
        let right_schema = Schema::new(vec![Field::new(
            "list_struct",
            DataType::List(y_struct_field.clone()),
            true,
        )]);
        let both_schema = Schema::new(vec![Field::new(
            "list_struct",
            DataType::List(both_struct_field.clone()),
            true,
        )]);

        // Single-row lists: left = [[{x:1}]], right = [[{y:2}]].
        let x = Arc::new(Int32Array::from(vec![1]));
        let y = Arc::new(Int32Array::from(vec![2]));
        let x_struct = Arc::new(StructArray::new(
            Fields::from(vec![x_field.clone()]),
            vec![x.clone()],
            None,
        ));
        let y_struct = Arc::new(StructArray::new(
            Fields::from(vec![y_field.clone()]),
            vec![y.clone()],
            None,
        ));
        let both_struct = Arc::new(StructArray::new(
            Fields::from(vec![x_field.clone(), y_field.clone()]),
            vec![x.clone(), y],
            None,
        ));
        // Expected result of merging with an all-null right side: "y" is null.
        let both_null_struct = Arc::new(StructArray::new(
            Fields::from(vec![x_field, y_field]),
            vec![x, Arc::new(new_null_array(&DataType::Int32, 1))],
            None,
        ));
        let offsets = OffsetBuffer::from_lengths([1]);
        let x_s_list = ListArray::new(x_struct_field, offsets.clone(), x_struct, None);
        let y_s_list = ListArray::new(y_struct_field, offsets.clone(), y_struct, None);
        let both_list = ListArray::new(
            both_struct_field.clone(),
            offsets.clone(),
            both_struct,
            None,
        );
        let both_null_list = ListArray::new(both_struct_field, offsets, both_null_struct, None);
        let x_batch =
            RecordBatch::try_new(Arc::new(left_schema), vec![Arc::new(x_s_list)]).unwrap();
        let y_batch = RecordBatch::try_new(
            Arc::new(right_schema.clone()),
            vec![Arc::new(y_s_list.clone())],
        )
        .unwrap();
        // Case 1: both sides have data and identical offsets.
        let merged = x_batch.merge(&y_batch).unwrap();
        let expected =
            RecordBatch::try_new(Arc::new(both_schema.clone()), vec![Arc::new(both_list)]).unwrap();
        assert_eq!(merged, expected);

        // Case 2: the right column is entirely null.
        let y_null_list = new_null_array(y_s_list.data_type(), 1);
        let y_null_batch =
            RecordBatch::try_new(Arc::new(right_schema), vec![Arc::new(y_null_list.clone())])
                .unwrap();
        let expected =
            RecordBatch::try_new(Arc::new(both_schema), vec![Arc::new(both_null_list)]).unwrap();
        let merged = x_batch.merge(&y_null_batch).unwrap();
        assert_eq!(merged, expected);
    }
1397
1398    #[test]
1399    fn test_byte_width_opt() {
1400        assert_eq!(DataType::Int32.byte_width_opt(), Some(4));
1401        assert_eq!(DataType::Int64.byte_width_opt(), Some(8));
1402        assert_eq!(DataType::Float32.byte_width_opt(), Some(4));
1403        assert_eq!(DataType::Float64.byte_width_opt(), Some(8));
1404        assert_eq!(DataType::Utf8.byte_width_opt(), None);
1405        assert_eq!(DataType::Binary.byte_width_opt(), None);
1406        assert_eq!(
1407            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))).byte_width_opt(),
1408            None
1409        );
1410        assert_eq!(
1411            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3)
1412                .byte_width_opt(),
1413            Some(12)
1414        );
1415        assert_eq!(
1416            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4)
1417                .byte_width_opt(),
1418            Some(16)
1419        );
1420        assert_eq!(
1421            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Utf8, true)), 5)
1422                .byte_width_opt(),
1423            None
1424        );
1425    }
1426
    #[test]
    fn test_take_record_batch() {
        // `take` should select rows 1, 5 and 10 from every column.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();
        let taken = batch.take(&(vec![1_u32, 5_u32, 10_u32].into())).unwrap();
        assert_eq!(
            taken,
            RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(Int32Array::from(vec![1, 5, 10])),
                    Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
                ],
            )
            .unwrap()
        )
    }
1456
    #[test]
    fn test_schema_project_by_schema() {
        // `project_by_schema` should reorder / subset columns and preserve
        // the source batch's schema metadata in every case below.
        let metadata = [("key".to_string(), "value".to_string())];
        let schema = Arc::new(
            Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
            ])
            .with_metadata(metadata.clone().into()),
        );
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();

        // Empty schema
        let empty_schema = Schema::empty();
        let empty_projected = batch.project_by_schema(&empty_schema).unwrap();
        let expected_schema = empty_schema.with_metadata(metadata.clone().into());
        assert_eq!(
            empty_projected,
            RecordBatch::from(StructArray::new_empty_fields(batch.num_rows(), None))
                .with_schema(Arc::new(expected_schema))
                .unwrap()
        );

        // Re-ordered schema
        let reordered_schema = Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, true),
        ]);
        let reordered_projected = batch.project_by_schema(&reordered_schema).unwrap();
        let expected_schema = Arc::new(reordered_schema.with_metadata(metadata.clone().into()));
        assert_eq!(
            reordered_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![
                    Arc::new(StringArray::from_iter_values(
                        (0..20).map(|i| format!("str-{}", i)),
                    )),
                    Arc::new(Int32Array::from_iter_values(0..20)),
                ],
            )
            .unwrap()
        );

        // Sub schema
        let sub_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let sub_projected = batch.project_by_schema(&sub_schema).unwrap();
        let expected_schema = Arc::new(sub_schema.with_metadata(metadata.into()));
        assert_eq!(
            sub_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![Arc::new(Int32Array::from_iter_values(0..20))],
            )
            .unwrap()
        );
    }
1523}