lance_arrow/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Extend Arrow Functionality
5//!
//! Utilities to improve the ergonomics of arrow-rs
7
8use std::sync::Arc;
9use std::{collections::HashMap, ptr::NonNull};
10
11use arrow_array::{
12    cast::AsArray, Array, ArrayRef, ArrowNumericType, FixedSizeBinaryArray, FixedSizeListArray,
13    GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt32Array,
14    UInt8Array,
15};
16use arrow_array::{
17    new_null_array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
18};
19use arrow_buffer::MutableBuffer;
20use arrow_data::ArrayDataBuilder;
21use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema};
22use arrow_select::{interleave::interleave, take::take};
23use rand::prelude::*;
24
25pub mod deepcopy;
26pub mod schema;
27pub use schema::*;
28pub mod bfloat16;
29pub mod floats;
30pub use floats::*;
31pub mod cast;
32pub mod list;
33pub mod memory;
34
35type Result<T> = std::result::Result<T, ArrowError>;
36
37pub trait DataTypeExt {
38    /// Returns true if the data type is binary-like, such as (Large)Utf8 and (Large)Binary.
39    ///
40    /// ```
41    /// use lance_arrow::*;
42    /// use arrow_schema::DataType;
43    ///
44    /// assert!(DataType::Utf8.is_binary_like());
45    /// assert!(DataType::Binary.is_binary_like());
46    /// assert!(DataType::LargeUtf8.is_binary_like());
47    /// assert!(DataType::LargeBinary.is_binary_like());
48    /// assert!(!DataType::Int32.is_binary_like());
49    /// ```
50    fn is_binary_like(&self) -> bool;
51
52    /// Returns true if the data type is a struct.
53    fn is_struct(&self) -> bool;
54
55    /// Check whether the given Arrow DataType is fixed stride.
56    ///
    /// A fixed stride type has the same byte width for all array elements.
    /// This includes all primitive types, Boolean, FixedSizeList, FixedSizeBinary, and decimals.
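    ///
    /// A minimal example (illustrative):
    ///
    /// ```
    /// use lance_arrow::DataTypeExt;
    /// use arrow_schema::DataType;
    ///
    /// assert!(DataType::Int32.is_fixed_stride());
    /// assert!(DataType::Boolean.is_fixed_stride());
    /// assert!(!DataType::Utf8.is_fixed_stride());
    /// ```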
59    fn is_fixed_stride(&self) -> bool;
60
61    /// Returns true if the [DataType] is a dictionary type.
62    fn is_dictionary(&self) -> bool;
63
    /// Returns the byte width of the data type.
    ///
    /// Panics if the data type is not fixed stride.
66    fn byte_width(&self) -> usize;
67
68    /// Returns the byte width of the data type, if it is fixed stride.
69    /// Returns None if the data type is not fixed stride.
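    ///
    /// A minimal example (illustrative):
    ///
    /// ```
    /// use lance_arrow::DataTypeExt;
    /// use arrow_schema::DataType;
    ///
    /// assert_eq!(DataType::Int32.byte_width_opt(), Some(4));
    /// assert_eq!(DataType::Float64.byte_width_opt(), Some(8));
    /// assert_eq!(DataType::Utf8.byte_width_opt(), None);
    /// ```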
70    fn byte_width_opt(&self) -> Option<usize>;
71}
72
73impl DataTypeExt for DataType {
74    fn is_binary_like(&self) -> bool {
75        use DataType::*;
76        matches!(self, Utf8 | Binary | LargeUtf8 | LargeBinary)
77    }
78
79    fn is_struct(&self) -> bool {
80        matches!(self, Self::Struct(_))
81    }
82
83    fn is_fixed_stride(&self) -> bool {
84        use DataType::*;
85        matches!(
86            self,
87            Boolean
88                | UInt8
89                | UInt16
90                | UInt32
91                | UInt64
92                | Int8
93                | Int16
94                | Int32
95                | Int64
96                | Float16
97                | Float32
98                | Float64
99                | Decimal128(_, _)
100                | Decimal256(_, _)
101                | FixedSizeList(_, _)
102                | FixedSizeBinary(_)
103                | Duration(_)
104                | Timestamp(_, _)
105                | Date32
106                | Date64
107                | Time32(_)
108                | Time64(_)
109        )
110    }
111
112    fn is_dictionary(&self) -> bool {
113        matches!(self, Self::Dictionary(_, _))
114    }
115
116    fn byte_width_opt(&self) -> Option<usize> {
117        match self {
118            Self::Int8 => Some(1),
119            Self::Int16 => Some(2),
120            Self::Int32 => Some(4),
121            Self::Int64 => Some(8),
122            Self::UInt8 => Some(1),
123            Self::UInt16 => Some(2),
124            Self::UInt32 => Some(4),
125            Self::UInt64 => Some(8),
126            Self::Float16 => Some(2),
127            Self::Float32 => Some(4),
128            Self::Float64 => Some(8),
129            Self::Date32 => Some(4),
130            Self::Date64 => Some(8),
131            Self::Time32(_) => Some(4),
132            Self::Time64(_) => Some(8),
133            Self::Timestamp(_, _) => Some(8),
134            Self::Duration(_) => Some(8),
135            Self::Decimal128(_, _) => Some(16),
136            Self::Decimal256(_, _) => Some(32),
137            Self::Interval(unit) => match unit {
138                IntervalUnit::YearMonth => Some(4),
139                IntervalUnit::DayTime => Some(8),
140                IntervalUnit::MonthDayNano => Some(16),
141            },
142            Self::FixedSizeBinary(s) => Some(*s as usize),
            Self::FixedSizeList(dt, s) => {
                dt.data_type().byte_width_opt().map(|w| *s as usize * w)
            }
144            _ => None,
145        }
146    }
147
148    fn byte_width(&self) -> usize {
149        self.byte_width_opt()
150            .unwrap_or_else(|| panic!("Expecting fixed stride data type, found {:?}", self))
151    }
152}
153
/// Create a [`GenericListArray`] from values and offsets.
155///
156/// ```
157/// use arrow_array::{Int32Array, Int64Array, ListArray};
158/// use arrow_array::types::Int64Type;
159/// use lance_arrow::try_new_generic_list_array;
160///
161/// let offsets = Int32Array::from_iter([0, 2, 7, 10]);
162/// let int_values = Int64Array::from_iter(0..10);
163/// let list_arr = try_new_generic_list_array(int_values, &offsets).unwrap();
164/// assert_eq!(list_arr,
165///     ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
166///         Some(vec![Some(0), Some(1)]),
167///         Some(vec![Some(2), Some(3), Some(4), Some(5), Some(6)]),
168///         Some(vec![Some(7), Some(8), Some(9)]),
169/// ]))
170/// ```
171pub fn try_new_generic_list_array<T: Array, Offset: ArrowNumericType>(
172    values: T,
173    offsets: &PrimitiveArray<Offset>,
174) -> Result<GenericListArray<Offset::Native>>
175where
176    Offset::Native: OffsetSizeTrait,
177{
178    let data_type = if Offset::Native::IS_LARGE {
179        DataType::LargeList(Arc::new(Field::new(
180            "item",
181            values.data_type().clone(),
182            true,
183        )))
184    } else {
185        DataType::List(Arc::new(Field::new(
186            "item",
187            values.data_type().clone(),
188            true,
189        )))
190    };
191    let data = ArrayDataBuilder::new(data_type)
192        .len(offsets.len() - 1)
193        .add_buffer(offsets.into_data().buffers()[0].clone())
194        .add_child_data(values.into_data())
195        .build()?;
196
197    Ok(GenericListArray::from(data))
198}
199
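/// Create a [`DataType::FixedSizeList`] of `list_width` items of `inner_type`, with the
/// child field named "item" and nullable, following the convention used throughout this crate.
///
/// A minimal example (illustrative):
///
/// ```
/// use arrow_schema::DataType;
/// use lance_arrow::fixed_size_list_type;
///
/// let dt = fixed_size_list_type(2, DataType::Float32);
/// assert!(matches!(dt, DataType::FixedSizeList(_, 2)));
/// ```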
200pub fn fixed_size_list_type(list_width: i32, inner_type: DataType) -> DataType {
201    DataType::FixedSizeList(Arc::new(Field::new("item", inner_type, true)), list_width)
202}
203
204pub trait FixedSizeListArrayExt {
    /// Create a [`FixedSizeListArray`] from values and list size.
206    ///
207    /// ```
208    /// use arrow_array::{Int64Array, FixedSizeListArray};
209    /// use arrow_array::types::Int64Type;
210    /// use lance_arrow::FixedSizeListArrayExt;
211    ///
212    /// let int_values = Int64Array::from_iter(0..10);
213    /// let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(int_values, 2).unwrap();
214    /// assert_eq!(fixed_size_list_arr,
215    ///     FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(vec![
216    ///         Some(vec![Some(0), Some(1)]),
217    ///         Some(vec![Some(2), Some(3)]),
218    ///         Some(vec![Some(4), Some(5)]),
219    ///         Some(vec![Some(6), Some(7)]),
220    ///         Some(vec![Some(8), Some(9)])
221    /// ], 2))
222    /// ```
223    fn try_new_from_values<T: Array + 'static>(
224        values: T,
225        list_size: i32,
226    ) -> Result<FixedSizeListArray>;
227
228    /// Sample `n` rows from the [FixedSizeListArray]
229    ///
230    /// ```
231    /// use arrow_array::{Int64Array, FixedSizeListArray, Array};
232    /// use lance_arrow::FixedSizeListArrayExt;
233    ///
234    /// let int_values = Int64Array::from_iter(0..256);
235    /// let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(int_values, 16).unwrap();
236    /// let sampled = fixed_size_list_arr.sample(10).unwrap();
237    /// assert_eq!(sampled.len(), 10);
238    /// assert_eq!(sampled.value_length(), 16);
239    /// assert_eq!(sampled.values().len(), 160);
240    /// ```
241    fn sample(&self, n: usize) -> Result<FixedSizeListArray>;
242
    /// Convert a [FixedSizeListArray] of Float16, Float32, Float64,
    /// Int8, Int16, Int32, Int64, UInt8, or UInt32 values to its closest floating point type.
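    ///
    /// A minimal example (illustrative); Int32 values are converted to Float32 here:
    ///
    /// ```
    /// use arrow_array::{FixedSizeListArray, Int32Array};
    /// use arrow_schema::DataType;
    /// use lance_arrow::FixedSizeListArrayExt;
    ///
    /// let ints = Int32Array::from_iter(0..4);
    /// let fsl = FixedSizeListArray::try_new_from_values(ints, 2).unwrap();
    /// let floats = fsl.convert_to_floating_point().unwrap();
    /// assert_eq!(floats.values().data_type(), &DataType::Float32);
    /// ```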
245    fn convert_to_floating_point(&self) -> Result<FixedSizeListArray>;
246}
247
248impl FixedSizeListArrayExt for FixedSizeListArray {
249    fn try_new_from_values<T: Array + 'static>(values: T, list_size: i32) -> Result<Self> {
250        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
251        let values = Arc::new(values);
252
253        Self::try_new(field, list_size, values, None)
254    }
255
256    fn sample(&self, n: usize) -> Result<FixedSizeListArray> {
257        if n >= self.len() {
258            return Ok(self.clone());
259        }
260        let mut rng = SmallRng::from_entropy();
261        let chosen = (0..self.len() as u32).choose_multiple(&mut rng, n);
262        take(self, &UInt32Array::from(chosen), None).map(|arr| arr.as_fixed_size_list().clone())
263    }
264
265    fn convert_to_floating_point(&self) -> Result<FixedSizeListArray> {
266        match self.data_type() {
267            DataType::FixedSizeList(field, size) => match field.data_type() {
268                DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(self.clone()),
269                DataType::Int8 => Ok(Self::new(
270                    Arc::new(arrow_schema::Field::new(
271                        field.name(),
272                        DataType::Float32,
273                        field.is_nullable(),
274                    )),
275                    *size,
276                    Arc::new(Float32Array::from_iter_values(
277                        self.values()
278                            .as_any()
279                            .downcast_ref::<Int8Array>()
280                            .ok_or(ArrowError::ParseError(
                                "Failed to cast primitive array to Int8Type".to_string(),
282                            ))?
283                            .into_iter()
284                            .filter_map(|x| x.map(|y| y as f32)),
285                    )),
286                    self.nulls().cloned(),
287                )),
288                DataType::Int16 => Ok(Self::new(
289                    Arc::new(arrow_schema::Field::new(
290                        field.name(),
291                        DataType::Float32,
292                        field.is_nullable(),
293                    )),
294                    *size,
295                    Arc::new(Float32Array::from_iter_values(
296                        self.values()
297                            .as_any()
298                            .downcast_ref::<Int16Array>()
299                            .ok_or(ArrowError::ParseError(
                                "Failed to cast primitive array to Int16Type".to_string(),
301                            ))?
302                            .into_iter()
303                            .filter_map(|x| x.map(|y| y as f32)),
304                    )),
305                    self.nulls().cloned(),
306                )),
307                DataType::Int32 => Ok(Self::new(
308                    Arc::new(arrow_schema::Field::new(
309                        field.name(),
310                        DataType::Float32,
311                        field.is_nullable(),
312                    )),
313                    *size,
314                    Arc::new(Float32Array::from_iter_values(
315                        self.values()
316                            .as_any()
317                            .downcast_ref::<Int32Array>()
318                            .ok_or(ArrowError::ParseError(
                                "Failed to cast primitive array to Int32Type".to_string(),
320                            ))?
321                            .into_iter()
322                            .filter_map(|x| x.map(|y| y as f32)),
323                    )),
324                    self.nulls().cloned(),
325                )),
326                DataType::Int64 => Ok(Self::new(
327                    Arc::new(arrow_schema::Field::new(
328                        field.name(),
329                        DataType::Float64,
330                        field.is_nullable(),
331                    )),
332                    *size,
333                    Arc::new(Float64Array::from_iter_values(
334                        self.values()
335                            .as_any()
336                            .downcast_ref::<Int64Array>()
337                            .ok_or(ArrowError::ParseError(
                                "Failed to cast primitive array to Int64Type".to_string(),
339                            ))?
340                            .into_iter()
341                            .filter_map(|x| x.map(|y| y as f64)),
342                    )),
343                    self.nulls().cloned(),
344                )),
345                DataType::UInt8 => Ok(Self::new(
346                    Arc::new(arrow_schema::Field::new(
347                        field.name(),
348                        DataType::Float64,
349                        field.is_nullable(),
350                    )),
351                    *size,
352                    Arc::new(Float64Array::from_iter_values(
353                        self.values()
354                            .as_any()
355                            .downcast_ref::<UInt8Array>()
356                            .ok_or(ArrowError::ParseError(
                                "Failed to cast primitive array to UInt8Type".to_string(),
358                            ))?
359                            .into_iter()
360                            .filter_map(|x| x.map(|y| y as f64)),
361                    )),
362                    self.nulls().cloned(),
363                )),
364                DataType::UInt32 => Ok(Self::new(
365                    Arc::new(arrow_schema::Field::new(
366                        field.name(),
367                        DataType::Float64,
368                        field.is_nullable(),
369                    )),
370                    *size,
371                    Arc::new(Float64Array::from_iter_values(
372                        self.values()
373                            .as_any()
374                            .downcast_ref::<UInt32Array>()
375                            .ok_or(ArrowError::ParseError(
                                "Failed to cast primitive array to UInt32Type".to_string(),
377                            ))?
378                            .into_iter()
379                            .filter_map(|x| x.map(|y| y as f64)),
380                    )),
381                    self.nulls().cloned(),
382                )),
383                data_type => Err(ArrowError::ParseError(format!(
                    "Expected either a floating point or integer type, got {:?}",
385                    data_type
386                ))),
387            },
388            data_type => Err(ArrowError::ParseError(format!(
                "Expected a FixedSizeList, got {:?}",
390                data_type
391            ))),
392        }
393    }
394}
395
396/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
/// [`FixedSizeListArray`], panicking on failure.
398pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
399    arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
400}
401
402pub trait FixedSizeBinaryArrayExt {
    /// Create a [`FixedSizeBinaryArray`] from values and stride.
404    ///
405    /// ```
406    /// use arrow_array::{UInt8Array, FixedSizeBinaryArray};
407    /// use arrow_array::types::UInt8Type;
408    /// use lance_arrow::FixedSizeBinaryArrayExt;
409    ///
410    /// let int_values = UInt8Array::from_iter(0..10);
411    /// let fixed_size_list_arr = FixedSizeBinaryArray::try_new_from_values(&int_values, 2).unwrap();
412    /// assert_eq!(fixed_size_list_arr,
413    ///     FixedSizeBinaryArray::from(vec![
414    ///         Some(vec![0, 1].as_slice()),
415    ///         Some(vec![2, 3].as_slice()),
416    ///         Some(vec![4, 5].as_slice()),
417    ///         Some(vec![6, 7].as_slice()),
418    ///         Some(vec![8, 9].as_slice())
419    /// ]))
420    /// ```
421    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<FixedSizeBinaryArray>;
422}
423
424impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray {
425    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<Self> {
426        let data_type = DataType::FixedSizeBinary(stride);
427        let data = ArrayDataBuilder::new(data_type)
428            .len(values.len() / stride as usize)
429            .add_buffer(values.into_data().buffers()[0].clone())
430            .build()?;
431        Ok(Self::from(data))
432    }
433}
434
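/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
/// [`FixedSizeBinaryArray`], panicking on failure.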
435pub fn as_fixed_size_binary_array(arr: &dyn Array) -> &FixedSizeBinaryArray {
436    arr.as_any().downcast_ref::<FixedSizeBinaryArray>().unwrap()
437}
438
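/// Iterate over a Utf8 or LargeUtf8 array as `Option<&str>` values.
///
/// Panics if the array is neither Utf8 nor LargeUtf8.
///
/// A minimal example (illustrative):
///
/// ```
/// use arrow_array::StringArray;
/// use lance_arrow::iter_str_array;
///
/// let arr = StringArray::from(vec![Some("a"), None, Some("c")]);
/// let collected: Vec<Option<&str>> = iter_str_array(&arr).collect();
/// assert_eq!(collected, vec![Some("a"), None, Some("c")]);
/// ```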
439pub fn iter_str_array(arr: &dyn Array) -> Box<dyn Iterator<Item = Option<&str>> + '_> {
440    match arr.data_type() {
441        DataType::Utf8 => Box::new(arr.as_string::<i32>().iter()),
442        DataType::LargeUtf8 => Box::new(arr.as_string::<i64>().iter()),
443        _ => panic!("Expecting Utf8 or LargeUtf8, found {:?}", arr.data_type()),
444    }
445}
446
447/// Extends Arrow's [RecordBatch].
448pub trait RecordBatchExt {
    /// Append a new column to this [`RecordBatch`] and return a new [`RecordBatch`].
450    ///
451    /// ```
452    /// use std::sync::Arc;
453    /// use arrow_array::{RecordBatch, Int32Array, StringArray};
454    /// use arrow_schema::{Schema, Field, DataType};
455    /// use lance_arrow::*;
456    ///
457    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
458    /// let int_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
459    /// let record_batch = RecordBatch::try_new(schema, vec![int_arr.clone()]).unwrap();
460    ///
461    /// let new_field = Field::new("s", DataType::Utf8, true);
462    /// let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
463    /// let new_record_batch = record_batch.try_with_column(new_field, str_arr.clone()).unwrap();
464    ///
465    /// assert_eq!(
466    ///     new_record_batch,
467    ///     RecordBatch::try_new(
468    ///         Arc::new(Schema::new(
469    ///             vec![
470    ///                 Field::new("a", DataType::Int32, true),
471    ///                 Field::new("s", DataType::Utf8, true)
472    ///             ])
473    ///         ),
474    ///         vec![int_arr, str_arr],
475    ///     ).unwrap()
476    /// )
477    /// ```
478    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<RecordBatch>;
479
    /// Create a new [`RecordBatch`] with the column inserted at `index`.
481    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<RecordBatch>;
482
    /// Creates a new [`RecordBatch`] from the provided [`StructArray`].
    ///
    /// The fields of the [`StructArray`] need to match this [`RecordBatch`]'s schema.
486    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<RecordBatch>;
487
488    /// Merge with another [`RecordBatch`] and returns a new one.
489    ///
490    /// Fields are merged based on name.  First we iterate the left columns.  If a matching
491    /// name is found in the right then we merge the two columns.  If there is no match then
492    /// we add the left column to the output.
493    ///
494    /// To merge two columns we consider the type.  If both arrays are struct arrays we recurse.
495    /// Otherwise we use the left array.
496    ///
497    /// Afterwards we add all non-matching right columns to the output.
498    ///
499    /// Note: This method likely does not handle nested fields correctly and you may want to consider
500    /// using [`merge_with_schema`] instead.
501    /// ```
502    /// use std::sync::Arc;
503    /// use arrow_array::*;
504    /// use arrow_schema::{Schema, Field, DataType};
505    /// use lance_arrow::*;
506    ///
507    /// let left_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
508    /// let int_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
509    /// let left = RecordBatch::try_new(left_schema, vec![int_arr.clone()]).unwrap();
510    ///
511    /// let right_schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
512    /// let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
513    /// let right = RecordBatch::try_new(right_schema, vec![str_arr.clone()]).unwrap();
514    ///
515    /// let new_record_batch = left.merge(&right).unwrap();
516    ///
517    /// assert_eq!(
518    ///     new_record_batch,
519    ///     RecordBatch::try_new(
520    ///         Arc::new(Schema::new(
521    ///             vec![
522    ///                 Field::new("a", DataType::Int32, true),
523    ///                 Field::new("s", DataType::Utf8, true)
524    ///             ])
525    ///         ),
526    ///         vec![int_arr, str_arr],
527    ///     ).unwrap()
528    /// )
529    /// ```
530    ///
531    /// TODO: add merge nested fields support.
532    fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;
533
534    /// Create a batch by merging columns between two batches with a given schema.
535    ///
536    /// A reference schema is used to determine the proper ordering of nested fields.
537    ///
538    /// For each field in the reference schema we look for corresponding fields in
539    /// the left and right batches.  If a field is found in both batches we recursively merge
540    /// it.
541    ///
542    /// If a field is only in the left or right batch we take it as it is.
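    ///
    /// A sketch of how the reference schema drives the output ordering:
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
    /// use arrow_schema::{DataType, Field, Schema};
    /// use lance_arrow::RecordBatchExt;
    ///
    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
    /// let s: ArrayRef = Arc::new(StringArray::from(vec!["x", "y"]));
    /// let left = RecordBatch::try_new(
    ///     Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])),
    ///     vec![a],
    /// ).unwrap();
    /// let right = RecordBatch::try_new(
    ///     Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])),
    ///     vec![s],
    /// ).unwrap();
    ///
    /// // The reference schema controls the output field ordering.
    /// let reference = Schema::new(vec![
    ///     Field::new("s", DataType::Utf8, true),
    ///     Field::new("a", DataType::Int32, true),
    /// ]);
    /// let merged = left.merge_with_schema(&right, &reference).unwrap();
    /// assert_eq!(merged.schema().field(0).name(), "s");
    /// assert_eq!(merged.schema().field(1).name(), "a");
    /// ```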
543    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch>;
544
    /// Drop the column with the given name and return the new [`RecordBatch`].
546    ///
547    /// If the named column does not exist, it returns a copy of this [`RecordBatch`].
548    fn drop_column(&self, name: &str) -> Result<RecordBatch>;
549
550    /// Replace a column (specified by name) and return the new [`RecordBatch`].
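    ///
    /// A minimal example (illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{ArrayRef, Int32Array, RecordBatch};
    /// use arrow_schema::{DataType, Field, Schema};
    /// use lance_arrow::RecordBatchExt;
    ///
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    /// let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
    /// let batch = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();
    ///
    /// let new_col: ArrayRef = Arc::new(Int32Array::from(vec![3, 4]));
    /// let replaced = batch.replace_column_by_name("a", new_col.clone()).unwrap();
    /// assert_eq!(replaced, RecordBatch::try_new(schema, vec![new_col]).unwrap());
    /// ```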
551    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch>;
552
    /// Replace a column's data type and values (specified by name) and return the new [`RecordBatch`].
554    fn replace_column_schema_by_name(
555        &self,
556        name: &str,
557        new_data_type: DataType,
558        column: Arc<dyn Array>,
559    ) -> Result<RecordBatch>;
560
    /// Get a (potentially nested) column by its dot-separated qualified name, e.g. `"a.b.c"`.
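    ///
    /// A minimal example with a nested struct column (illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{ArrayRef, Int32Array, RecordBatch, StructArray};
    /// use arrow_schema::{DataType, Field, Schema};
    /// use lance_arrow::RecordBatchExt;
    ///
    /// let struct_arr = StructArray::from(vec![(
    ///     Arc::new(Field::new("x", DataType::Int32, true)),
    ///     Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
    /// )]);
    /// let schema = Arc::new(Schema::new(vec![Field::new(
    ///     "s",
    ///     struct_arr.data_type().clone(),
    ///     true,
    /// )]));
    /// let batch = RecordBatch::try_new(schema, vec![Arc::new(struct_arr) as ArrayRef]).unwrap();
    ///
    /// assert!(batch.column_by_qualified_name("s.x").is_some());
    /// assert!(batch.column_by_qualified_name("s.y").is_none());
    /// ```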
562    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef>;
563
564    /// Project the schema over the [RecordBatch].
565    fn project_by_schema(&self, schema: &Schema) -> Result<RecordBatch>;
566
    /// Get the metadata of the schema.
568    fn metadata(&self) -> &HashMap<String, String>;
569
570    /// Add metadata to the schema.
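    ///
    /// A minimal example (illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{ArrayRef, Int32Array, RecordBatch};
    /// use arrow_schema::{DataType, Field, Schema};
    /// use lance_arrow::RecordBatchExt;
    ///
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    /// let col: ArrayRef = Arc::new(Int32Array::from(vec![1]));
    /// let batch = RecordBatch::try_new(schema, vec![col]).unwrap();
    ///
    /// let batch = batch.add_metadata("key".to_string(), "value".to_string()).unwrap();
    /// assert_eq!(batch.metadata().get("key"), Some(&"value".to_string()));
    /// ```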
571    fn add_metadata(&self, key: String, value: String) -> Result<RecordBatch> {
572        let mut metadata = self.metadata().clone();
573        metadata.insert(key, value);
574        self.with_metadata(metadata)
575    }
576
577    /// Replace the schema metadata with the provided one.
578    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch>;
579
580    /// Take selected rows from the [RecordBatch].
581    fn take(&self, indices: &UInt32Array) -> Result<RecordBatch>;
582}
583
584impl RecordBatchExt for RecordBatch {
585    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<Self> {
586        let new_schema = Arc::new(self.schema().as_ref().try_with_column(field)?);
587        let mut new_columns = self.columns().to_vec();
588        new_columns.push(arr);
589        Self::try_new(new_schema, new_columns)
590    }
591
592    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<Self> {
593        let new_schema = Arc::new(self.schema().as_ref().try_with_column_at(index, field)?);
594        let mut new_columns = self.columns().to_vec();
595        new_columns.insert(index, arr);
596        Self::try_new(new_schema, new_columns)
597    }
598
599    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<Self> {
600        let schema = Arc::new(Schema::new_with_metadata(
601            arr.fields().to_vec(),
602            self.schema().metadata.clone(),
603        ));
604        let batch = Self::from(arr);
605        batch.with_schema(schema)
606    }
607
608    fn merge(&self, other: &Self) -> Result<Self> {
609        if self.num_rows() != other.num_rows() {
610            return Err(ArrowError::InvalidArgumentError(format!(
611                "Attempt to merge two RecordBatch with different sizes: {} != {}",
612                self.num_rows(),
613                other.num_rows()
614            )));
615        }
616        let left_struct_array: StructArray = self.clone().into();
617        let right_struct_array: StructArray = other.clone().into();
618        self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array))
619    }
620
621    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch> {
622        if self.num_rows() != other.num_rows() {
623            return Err(ArrowError::InvalidArgumentError(format!(
624                "Attempt to merge two RecordBatch with different sizes: {} != {}",
625                self.num_rows(),
626                other.num_rows()
627            )));
628        }
629        let left_struct_array: StructArray = self.clone().into();
630        let right_struct_array: StructArray = other.clone().into();
631        self.try_new_from_struct_array(merge_with_schema(
632            &left_struct_array,
633            &right_struct_array,
634            schema.fields(),
635        ))
636    }
637
638    fn drop_column(&self, name: &str) -> Result<Self> {
639        let mut fields = vec![];
640        let mut columns = vec![];
641        for i in 0..self.schema().fields.len() {
642            if self.schema().field(i).name() != name {
643                fields.push(self.schema().field(i).clone());
644                columns.push(self.column(i).clone());
645            }
646        }
647        Self::try_new(
648            Arc::new(Schema::new_with_metadata(
649                fields,
650                self.schema().metadata().clone(),
651            )),
652            columns,
653        )
654    }
655
656    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch> {
657        let mut columns = self.columns().to_vec();
658        let field_i = self
659            .schema()
660            .fields()
661            .iter()
662            .position(|f| f.name() == name)
663            .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
664        columns[field_i] = column;
665        Self::try_new(self.schema(), columns)
666    }
667
668    fn replace_column_schema_by_name(
669        &self,
670        name: &str,
671        new_data_type: DataType,
672        column: Arc<dyn Array>,
673    ) -> Result<RecordBatch> {
674        let fields = self
675            .schema()
676            .fields()
677            .iter()
678            .map(|x| {
679                if x.name() != name {
680                    x.clone()
681                } else {
682                    let new_field = Field::new(name, new_data_type.clone(), x.is_nullable());
683                    Arc::new(new_field)
684                }
685            })
686            .collect::<Vec<_>>();
687        let schema = Schema::new_with_metadata(fields, self.schema().metadata.clone());
688        let mut columns = self.columns().to_vec();
689        let field_i = self
690            .schema()
691            .fields()
692            .iter()
693            .position(|f| f.name() == name)
694            .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
695        columns[field_i] = column;
696        Self::try_new(Arc::new(schema), columns)
697    }
698
699    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef> {
700        let split = name.split('.').collect::<Vec<_>>();
701        if split.is_empty() {
702            return None;
703        }
704
705        self.column_by_name(split[0])
706            .and_then(|arr| get_sub_array(arr, &split[1..]))
707    }
708
709    fn project_by_schema(&self, schema: &Schema) -> Result<Self> {
710        let struct_array: StructArray = self.clone().into();
711        self.try_new_from_struct_array(project(&struct_array, schema.fields())?)
712    }
713
714    fn metadata(&self) -> &HashMap<String, String> {
715        self.schema_ref().metadata()
716    }
717
718    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch> {
719        let mut schema = self.schema_ref().as_ref().clone();
720        schema.metadata = metadata;
721        Self::try_new(schema.into(), self.columns().into())
722    }
723
724    fn take(&self, indices: &UInt32Array) -> Result<Self> {
725        let struct_array: StructArray = self.clone().into();
726        let taken = take(&struct_array, indices, None)?;
727        self.try_new_from_struct_array(taken.as_struct().clone())
728    }
729}
730
731fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
732    if fields.is_empty() {
733        return Ok(StructArray::new_empty_fields(
734            struct_array.len(),
735            struct_array.nulls().cloned(),
736        ));
737    }
738    let mut columns: Vec<ArrayRef> = vec![];
739    for field in fields.iter() {
740        if let Some(col) = struct_array.column_by_name(field.name()) {
741            match field.data_type() {
742                // TODO handle list-of-struct
743                DataType::Struct(subfields) => {
744                    let projected = project(col.as_struct(), subfields)?;
745                    columns.push(Arc::new(projected));
746                }
747                _ => {
748                    columns.push(col.clone());
749                }
750            }
751        } else {
752            return Err(ArrowError::SchemaError(format!(
753                "field {} does not exist in the RecordBatch",
754                field.name()
755            )));
756        }
757    }
758    StructArray::try_new(fields.clone(), columns, None)
759}
760
761fn lists_have_same_offsets_helper<T: OffsetSizeTrait>(left: &dyn Array, right: &dyn Array) -> bool {
762    let left_list: &GenericListArray<T> = left.as_list();
763    let right_list: &GenericListArray<T> = right.as_list();
764    left_list.offsets().inner() == right_list.offsets().inner()
765}
766
767fn merge_list_structs_helper<T: OffsetSizeTrait>(
768    left: &dyn Array,
769    right: &dyn Array,
770    items_field_name: impl Into<String>,
771    items_nullable: bool,
772) -> Arc<dyn Array> {
773    let left_list: &GenericListArray<T> = left.as_list();
774    let right_list: &GenericListArray<T> = right.as_list();
775    let left_struct = left_list.values();
776    let right_struct = right_list.values();
777    let left_struct_arr = left_struct.as_struct();
778    let right_struct_arr = right_struct.as_struct();
779    let merged_items = Arc::new(merge(left_struct_arr, right_struct_arr));
780    let items_field = Arc::new(Field::new(
781        items_field_name,
782        merged_items.data_type().clone(),
783        items_nullable,
784    ));
785    Arc::new(GenericListArray::<T>::new(
786        items_field,
787        left_list.offsets().clone(),
788        merged_items,
789        left_list.nulls().cloned(),
790    ))
791}
792
793fn merge_list_struct_null_helper<T: OffsetSizeTrait>(
794    left: &dyn Array,
795    right: &dyn Array,
796    not_null: &dyn Array,
797    items_field_name: impl Into<String>,
798) -> Arc<dyn Array> {
799    let left_list: &GenericListArray<T> = left.as_list::<T>();
800    let not_null_list = not_null.as_list::<T>();
801    let right_list = right.as_list::<T>();
802
803    let left_struct = left_list.values().as_struct();
804    let not_null_struct: &StructArray = not_null_list.values().as_struct();
805    let right_struct = right_list.values().as_struct();
806
807    let values_len = not_null_list.values().len();
808    let mut merged_fields =
809        Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());
810    let mut merged_columns =
811        Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());
812
813    for (_, field) in left_struct.columns().iter().zip(left_struct.fields()) {
814        merged_fields.push(field.clone());
815        if let Some(val) = not_null_struct.column_by_name(field.name()) {
816            merged_columns.push(val.clone());
817        } else {
818            merged_columns.push(new_null_array(field.data_type(), values_len))
819        }
820    }
821    for (_, field) in right_struct
822        .columns()
823        .iter()
824        .zip(right_struct.fields())
825        .filter(|(_, field)| left_struct.column_by_name(field.name()).is_none())
826    {
827        merged_fields.push(field.clone());
828        if let Some(val) = not_null_struct.column_by_name(field.name()) {
829            merged_columns.push(val.clone());
830        } else {
831            merged_columns.push(new_null_array(field.data_type(), values_len));
832        }
833    }
834
835    let merged_struct = Arc::new(StructArray::new(
836        Fields::from(merged_fields),
837        merged_columns,
838        not_null_struct.nulls().cloned(),
839    ));
840    let items_field = Arc::new(Field::new(
841        items_field_name,
842        merged_struct.data_type().clone(),
843        true,
844    ));
845    Arc::new(GenericListArray::<T>::new(
846        items_field,
847        not_null_list.offsets().clone(),
848        merged_struct,
849        not_null_list.nulls().cloned(),
850    ))
851}
852
853fn merge_list_struct_null(
854    left: &dyn Array,
855    right: &dyn Array,
856    not_null: &dyn Array,
857) -> Arc<dyn Array> {
858    match left.data_type() {
859        DataType::List(left_field) => {
860            merge_list_struct_null_helper::<i32>(left, right, not_null, left_field.name())
861        }
862        DataType::LargeList(left_field) => {
863            merge_list_struct_null_helper::<i64>(left, right, not_null, left_field.name())
864        }
865        _ => unreachable!(),
866    }
867}
868
869fn merge_list_struct(left: &dyn Array, right: &dyn Array) -> Arc<dyn Array> {
870    // Merging fields into a list<struct<...>> is tricky and can only succeed
871    // in two ways.  First, if both lists have the same offsets.  Second, if
872    // one of the lists is all-null
873    if left.null_count() == left.len() {
874        return merge_list_struct_null(left, right, right);
875    } else if right.null_count() == right.len() {
876        return merge_list_struct_null(left, right, left);
877    }
878    match (left.data_type(), right.data_type()) {
879        (DataType::List(left_field), DataType::List(_)) => {
880            if !lists_have_same_offsets_helper::<i32>(left, right) {
881                panic!("Attempt to merge list struct arrays which do not have same offsets");
882            }
883            merge_list_structs_helper::<i32>(
884                left,
885                right,
886                left_field.name(),
887                left_field.is_nullable(),
888            )
889        }
890        (DataType::LargeList(left_field), DataType::LargeList(_)) => {
891            if !lists_have_same_offsets_helper::<i64>(left, right) {
892                panic!("Attempt to merge list struct arrays which do not have same offsets");
893            }
894            merge_list_structs_helper::<i64>(
895                left,
896                right,
897                left_field.name(),
898                left_field.is_nullable(),
899            )
900        }
901        _ => unreachable!(),
902    }
903}
904
905fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
906    let mut fields: Vec<Field> = vec![];
907    let mut columns: Vec<ArrayRef> = vec![];
908    let right_fields = right_struct_array.fields();
909    let right_columns = right_struct_array.columns();
910
911    // iterate through the fields on the left hand side
912    for (left_field, left_column) in left_struct_array
913        .fields()
914        .iter()
915        .zip(left_struct_array.columns().iter())
916    {
917        match right_fields
918            .iter()
919            .position(|f| f.name() == left_field.name())
920        {
921            // if the field exists on the right hand side, merge them recursively if appropriate
922            Some(right_index) => {
923                let right_field = right_fields.get(right_index).unwrap();
924                let right_column = right_columns.get(right_index).unwrap();
925                // if both fields are struct, merge them recursively
926                match (left_field.data_type(), right_field.data_type()) {
927                    (DataType::Struct(_), DataType::Struct(_)) => {
928                        let left_sub_array = left_column.as_struct();
929                        let right_sub_array = right_column.as_struct();
930                        let merged_sub_array = merge(left_sub_array, right_sub_array);
931                        fields.push(Field::new(
932                            left_field.name(),
933                            merged_sub_array.data_type().clone(),
934                            left_field.is_nullable(),
935                        ));
936                        columns.push(Arc::new(merged_sub_array) as ArrayRef);
937                    }
938                    (DataType::List(left_list), DataType::List(right_list))
939                        if left_list.data_type().is_struct()
940                            && right_list.data_type().is_struct() =>
941                    {
942                        // If there is nothing to merge just use the left field
                        if left_list.data_type() == right_list.data_type() {
                            fields.push(left_field.as_ref().clone());
                            columns.push(left_column.clone());
                        } else {
                            // If we have two List<Struct> and they have different sets of fields
                            // then we can merge them if the offsets arrays are the same.
                            // Otherwise, we have to consider it an error.
                            let merged_sub_array = merge_list_struct(&left_column, &right_column);

                            fields.push(Field::new(
                                left_field.name(),
                                merged_sub_array.data_type().clone(),
                                left_field.is_nullable(),
                            ));
                            columns.push(merged_sub_array);
                        }
958                    }
959                    // otherwise, just use the field on the left hand side
960                    _ => {
961                        // TODO handle list-of-struct and other types
962                        fields.push(left_field.as_ref().clone());
963                        columns.push(left_column.clone());
964                    }
965                }
966            }
967            None => {
968                fields.push(left_field.as_ref().clone());
969                columns.push(left_column.clone());
970            }
971        }
972    }
973
974    // now iterate through the fields on the right hand side
975    right_fields
976        .iter()
977        .zip(right_columns.iter())
978        .for_each(|(field, column)| {
979            // add new columns on the right
980            if !left_struct_array
981                .fields()
982                .iter()
983                .any(|f| f.name() == field.name())
984            {
985                fields.push(field.as_ref().clone());
986                columns.push(column.clone() as ArrayRef);
987            }
988        });
989
990    let zipped: Vec<(FieldRef, ArrayRef)> = fields
991        .iter()
992        .cloned()
993        .map(Arc::new)
994        .zip(columns.iter().cloned())
995        .collect::<Vec<_>>();
996    StructArray::from(zipped)
997}
998
999fn merge_with_schema(
1000    left_struct_array: &StructArray,
1001    right_struct_array: &StructArray,
1002    fields: &Fields,
1003) -> StructArray {
1004    // Helper function that returns true if both types are struct or both are non-struct
1005    fn same_type_kind(left: &DataType, right: &DataType) -> bool {
1006        match (left, right) {
1007            (DataType::Struct(_), DataType::Struct(_)) => true,
1008            (DataType::Struct(_), _) => false,
1009            (_, DataType::Struct(_)) => false,
1010            _ => true,
1011        }
1012    }
1013
1014    let mut output_fields: Vec<Field> = Vec::with_capacity(fields.len());
1015    let mut columns: Vec<ArrayRef> = Vec::with_capacity(fields.len());
1016
1017    let left_fields = left_struct_array.fields();
1018    let left_columns = left_struct_array.columns();
1019    let right_fields = right_struct_array.fields();
1020    let right_columns = right_struct_array.columns();
1021
1022    for field in fields {
1023        let left_match_idx = left_fields.iter().position(|f| {
1024            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
1025        });
1026        let right_match_idx = right_fields.iter().position(|f| {
1027            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
1028        });
1029
1030        match (left_match_idx, right_match_idx) {
1031            (None, Some(right_idx)) => {
1032                output_fields.push(right_fields[right_idx].as_ref().clone());
1033                columns.push(right_columns[right_idx].clone());
1034            }
1035            (Some(left_idx), None) => {
1036                output_fields.push(left_fields[left_idx].as_ref().clone());
1037                columns.push(left_columns[left_idx].clone());
1038            }
1039            (Some(left_idx), Some(right_idx)) => {
1040                if let DataType::Struct(child_fields) = field.data_type() {
1041                    let left_sub_array = left_columns[left_idx].as_struct();
1042                    let right_sub_array = right_columns[right_idx].as_struct();
1043                    let merged_sub_array =
1044                        merge_with_schema(left_sub_array, right_sub_array, child_fields);
1045                    output_fields.push(Field::new(
1046                        field.name(),
1047                        merged_sub_array.data_type().clone(),
1048                        field.is_nullable(),
1049                    ));
1050                    columns.push(Arc::new(merged_sub_array) as ArrayRef);
1051                } else {
1052                    output_fields.push(left_fields[left_idx].as_ref().clone());
1053                    columns.push(left_columns[left_idx].clone());
1054                }
1055            }
1056            (None, None) => {
1057                // The field will not be included in the output
1058            }
1059        }
1060    }
1061
1062    let zipped: Vec<(FieldRef, ArrayRef)> = output_fields
1063        .into_iter()
1064        .map(Arc::new)
1065        .zip(columns)
1066        .collect::<Vec<_>>();
1067    StructArray::from(zipped)
1068}
1069
1070fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> {
1071    if components.is_empty() {
1072        return Some(array);
1073    }
1074    if !matches!(array.data_type(), DataType::Struct(_)) {
1075        return None;
1076    }
1077    let struct_arr = array.as_struct();
1078    struct_arr
1079        .column_by_name(components[0])
1080        .and_then(|arr| get_sub_array(arr, &components[1..]))
1081}
1082
1083/// Interleave multiple RecordBatches into a single RecordBatch.
1084///
1085/// Behaves like [`arrow::compute::interleave`], but for RecordBatches.
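///
/// A minimal example (illustrative); each index is a `(batch_index, row_index)` pair:
///
/// ```
/// use std::sync::Arc;
/// use arrow_array::{ArrayRef, Int32Array, RecordBatch};
/// use arrow_schema::{DataType, Field, Schema};
/// use lance_arrow::interleave_batches;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
/// let c1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let c2: ArrayRef = Arc::new(Int32Array::from(vec![3, 4]));
/// let batch1 = RecordBatch::try_new(schema.clone(), vec![c1]).unwrap();
/// let batch2 = RecordBatch::try_new(schema.clone(), vec![c2]).unwrap();
///
/// let out = interleave_batches(&[batch1, batch2], &[(1, 0), (0, 1), (1, 1)]).unwrap();
/// let expected: ArrayRef = Arc::new(Int32Array::from(vec![3, 2, 4]));
/// assert_eq!(out, RecordBatch::try_new(schema, vec![expected]).unwrap());
/// ```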
1086pub fn interleave_batches(
1087    batches: &[RecordBatch],
1088    indices: &[(usize, usize)],
1089) -> Result<RecordBatch> {
1090    let first_batch = batches.first().ok_or_else(|| {
1091        ArrowError::InvalidArgumentError("Cannot interleave zero RecordBatches".to_string())
1092    })?;
1093    let schema = first_batch.schema();
1094    let num_columns = first_batch.num_columns();
1095    let mut columns = Vec::with_capacity(num_columns);
1096    let mut chunks = Vec::with_capacity(batches.len());
1097
1098    for i in 0..num_columns {
1099        for batch in batches {
1100            chunks.push(batch.column(i).as_ref());
1101        }
1102        let new_column = interleave(&chunks, indices)?;
1103        columns.push(new_column);
1104        chunks.clear();
1105    }
1106
1107    RecordBatch::try_new(schema, columns)
1108}
1109
1110pub trait BufferExt {
    /// Create an `arrow_buffer::Buffer` from a `bytes::Bytes` object.
1112    ///
1113    /// The alignment must be specified (as `bytes_per_value`) since we want to make
1114    /// sure we can safely reinterpret the buffer.
1115    ///
1116    /// If the buffer is properly aligned this will be zero-copy.  If not, a copy
1117    /// will be made and an owned buffer returned.
1118    ///
1119    /// If `bytes_per_value` is not a power of two, then we assume the buffer is
1120    /// never going to be reinterpreted into another type and we can safely
1121    /// ignore the alignment.
1122    ///
1123    /// Yes, the method name is odd.  It's because there is already a `from_bytes`
    /// which converts from `arrow_buffer::bytes::Bytes` (not `bytes::Bytes`).
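    ///
    /// A minimal example (illustrative):
    ///
    /// ```
    /// use arrow_buffer::Buffer;
    /// use lance_arrow::BufferExt;
    ///
    /// let bytes = bytes::Bytes::from(vec![1u8, 0, 0, 0, 2, 0, 0, 0]);
    /// // Treat the data as 4-byte values; a copy is made only if it is misaligned.
    /// let buffer = Buffer::from_bytes_bytes(bytes, 4);
    /// assert_eq!(buffer.as_slice(), &[1u8, 0, 0, 0, 2, 0, 0, 0]);
    /// ```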
1125    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self;
1126
1127    /// Allocates a new properly aligned arrow buffer and copies `bytes` into it
1128    ///
1129    /// `size_bytes` can be larger than `bytes` and, if so, the trailing bytes will
1130    /// be zeroed out.
1131    ///
1132    /// # Panics
1133    ///
1134    /// Panics if `size_bytes` is less than `bytes.len()`
1135    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self;
1136}
1137
1138fn is_pwr_two(n: u64) -> bool {
1139    n & (n - 1) == 0
1140}
1141
1142impl BufferExt for arrow_buffer::Buffer {
1143    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self {
1144        if is_pwr_two(bytes_per_value) && bytes.as_ptr().align_offset(bytes_per_value as usize) != 0
1145        {
1146            // The original buffer is not aligned, cannot zero-copy
1147            let size_bytes = bytes.len();
1148            Self::copy_bytes_bytes(bytes, size_bytes)
1149        } else {
1150            // The original buffer is aligned, can zero-copy
            // SAFETY: the alignment is correct, so we can make this conversion
1152            unsafe {
1153                Self::from_custom_allocation(
1154                    NonNull::new(bytes.as_ptr() as _).expect("should be a valid pointer"),
1155                    bytes.len(),
1156                    Arc::new(bytes),
1157                )
1158            }
1159        }
1160    }
1161
1162    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self {
1163        assert!(size_bytes >= bytes.len());
1164        let mut buf = MutableBuffer::with_capacity(size_bytes);
1165        let to_fill = size_bytes - bytes.len();
1166        buf.extend(bytes);
1167        buf.extend(std::iter::repeat_n(0_u8, to_fill));
1168        Self::from(buf)
1169    }
1170}
1171
1172#[cfg(test)]
1173mod tests {
1174    use super::*;
1175    use arrow_array::{new_empty_array, new_null_array, Int32Array, ListArray, StringArray};
1176    use arrow_buffer::OffsetBuffer;
1177
1178    #[test]
1179    fn test_merge_recursive() {
1180        let a_array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
1181        let e_array = Int32Array::from(vec![Some(4), Some(5), Some(6)]);
1182        let c_array = Int32Array::from(vec![Some(7), Some(8), Some(9)]);
1183        let d_array = StringArray::from(vec![Some("a"), Some("b"), Some("c")]);
1184
1185        let left_schema = Schema::new(vec![
1186            Field::new("a", DataType::Int32, true),
1187            Field::new(
1188                "b",
1189                DataType::Struct(vec![Field::new("c", DataType::Int32, true)].into()),
1190                true,
1191            ),
1192        ]);
1193        let left_batch = RecordBatch::try_new(
1194            Arc::new(left_schema),
1195            vec![
1196                Arc::new(a_array.clone()),
1197                Arc::new(StructArray::from(vec![(
1198                    Arc::new(Field::new("c", DataType::Int32, true)),
1199                    Arc::new(c_array.clone()) as ArrayRef,
1200                )])),
1201            ],
1202        )
1203        .unwrap();
1204
1205        let right_schema = Schema::new(vec![
1206            Field::new("e", DataType::Int32, true),
1207            Field::new(
1208                "b",
1209                DataType::Struct(vec![Field::new("d", DataType::Utf8, true)].into()),
1210                true,
1211            ),
1212        ]);
1213        let right_batch = RecordBatch::try_new(
1214            Arc::new(right_schema),
1215            vec![
1216                Arc::new(e_array.clone()),
1217                Arc::new(StructArray::from(vec![(
1218                    Arc::new(Field::new("d", DataType::Utf8, true)),
1219                    Arc::new(d_array.clone()) as ArrayRef,
1220                )])) as ArrayRef,
1221            ],
1222        )
1223        .unwrap();
1224
1225        let merged_schema = Schema::new(vec![
1226            Field::new("a", DataType::Int32, true),
1227            Field::new(
1228                "b",
1229                DataType::Struct(
1230                    vec![
1231                        Field::new("c", DataType::Int32, true),
1232                        Field::new("d", DataType::Utf8, true),
1233                    ]
1234                    .into(),
1235                ),
1236                true,
1237            ),
1238            Field::new("e", DataType::Int32, true),
1239        ]);
1240        let merged_batch = RecordBatch::try_new(
1241            Arc::new(merged_schema),
1242            vec![
1243                Arc::new(a_array) as ArrayRef,
1244                Arc::new(StructArray::from(vec![
1245                    (
1246                        Arc::new(Field::new("c", DataType::Int32, true)),
1247                        Arc::new(c_array) as ArrayRef,
1248                    ),
1249                    (
1250                        Arc::new(Field::new("d", DataType::Utf8, true)),
1251                        Arc::new(d_array) as ArrayRef,
1252                    ),
1253                ])) as ArrayRef,
1254                Arc::new(e_array) as ArrayRef,
1255            ],
1256        )
1257        .unwrap();
1258
1259        let result = left_batch.merge(&right_batch).unwrap();
1260        assert_eq!(result, merged_batch);
1261    }
1262
1263    #[test]
1264    fn test_merge_with_schema() {
1265        fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) {
1266            let fields: Fields = names
1267                .iter()
1268                .zip(types)
1269                .map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false))
1270                .collect();
1271            let schema = Schema::new(vec![Field::new(
1272                "struct",
1273                DataType::Struct(fields.clone()),
1274                false,
1275            )]);
1276            let children = types.iter().map(new_empty_array).collect::<Vec<_>>();
1277            let batch = RecordBatch::try_new(
1278                Arc::new(schema.clone()),
1279                vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef],
1280            );
1281            (schema, batch.unwrap())
1282        }
1283
1284        let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]);
1285        let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]);
1286        let (output_schema, _) = test_batch(
1287            &["b", "a", "c"],
1288            &[DataType::Int64, DataType::Int32, DataType::Int32],
1289        );
1290
1291        // If we use merge_with_schema the schema is respected
1292        let merged = left_batch
1293            .merge_with_schema(&right_batch, &output_schema)
1294            .unwrap();
1295        assert_eq!(merged.schema().as_ref(), &output_schema);
1296
1297        // If we use merge we get first-come first-serve based on the left batch
1298        let (naive_schema, _) = test_batch(
1299            &["a", "b", "c"],
1300            &[DataType::Int32, DataType::Int64, DataType::Int32],
1301        );
1302        let merged = left_batch.merge(&right_batch).unwrap();
1303        assert_eq!(merged.schema().as_ref(), &naive_schema);
1304    }
1305
1306    #[test]
1307    fn test_merge_list_struct() {
1308        let x_field = Arc::new(Field::new("x", DataType::Int32, true));
1309        let y_field = Arc::new(Field::new("y", DataType::Int32, true));
1310        let x_struct_field = Arc::new(Field::new(
1311            "item",
1312            DataType::Struct(Fields::from(vec![x_field.clone()])),
1313            true,
1314        ));
1315        let y_struct_field = Arc::new(Field::new(
1316            "item",
1317            DataType::Struct(Fields::from(vec![y_field.clone()])),
1318            true,
1319        ));
1320        let both_struct_field = Arc::new(Field::new(
1321            "item",
1322            DataType::Struct(Fields::from(vec![x_field.clone(), y_field.clone()])),
1323            true,
1324        ));
1325        let left_schema = Schema::new(vec![Field::new(
1326            "list_struct",
1327            DataType::List(x_struct_field.clone()),
1328            true,
1329        )]);
1330        let right_schema = Schema::new(vec![Field::new(
1331            "list_struct",
1332            DataType::List(y_struct_field.clone()),
1333            true,
1334        )]);
1335        let both_schema = Schema::new(vec![Field::new(
1336            "list_struct",
1337            DataType::List(both_struct_field.clone()),
1338            true,
1339        )]);
1340
1341        let x = Arc::new(Int32Array::from(vec![1]));
1342        let y = Arc::new(Int32Array::from(vec![2]));
1343        let x_struct = Arc::new(StructArray::new(
1344            Fields::from(vec![x_field.clone()]),
1345            vec![x.clone()],
1346            None,
1347        ));
1348        let y_struct = Arc::new(StructArray::new(
1349            Fields::from(vec![y_field.clone()]),
1350            vec![y.clone()],
1351            None,
1352        ));
1353        let both_struct = Arc::new(StructArray::new(
1354            Fields::from(vec![x_field.clone(), y_field.clone()]),
1355            vec![x.clone(), y],
1356            None,
1357        ));
1358        let both_null_struct = Arc::new(StructArray::new(
1359            Fields::from(vec![x_field, y_field]),
1360            vec![x, Arc::new(new_null_array(&DataType::Int32, 1))],
1361            None,
1362        ));
1363        let offsets = OffsetBuffer::from_lengths([1]);
1364        let x_s_list = ListArray::new(x_struct_field, offsets.clone(), x_struct, None);
1365        let y_s_list = ListArray::new(y_struct_field, offsets.clone(), y_struct, None);
1366        let both_list = ListArray::new(
1367            both_struct_field.clone(),
1368            offsets.clone(),
1369            both_struct,
1370            None,
1371        );
1372        let both_null_list = ListArray::new(both_struct_field, offsets, both_null_struct, None);
1373        let x_batch =
1374            RecordBatch::try_new(Arc::new(left_schema), vec![Arc::new(x_s_list)]).unwrap();
1375        let y_batch = RecordBatch::try_new(
1376            Arc::new(right_schema.clone()),
1377            vec![Arc::new(y_s_list.clone())],
1378        )
1379        .unwrap();
1380        let merged = x_batch.merge(&y_batch).unwrap();
1381        let expected =
1382            RecordBatch::try_new(Arc::new(both_schema.clone()), vec![Arc::new(both_list)]).unwrap();
1383        assert_eq!(merged, expected);
1384
1385        let y_null_list = new_null_array(y_s_list.data_type(), 1);
1386        let y_null_batch =
1387            RecordBatch::try_new(Arc::new(right_schema), vec![Arc::new(y_null_list.clone())])
1388                .unwrap();
1389        let expected =
1390            RecordBatch::try_new(Arc::new(both_schema), vec![Arc::new(both_null_list)]).unwrap();
1391        let merged = x_batch.merge(&y_null_batch).unwrap();
1392        assert_eq!(merged, expected);
1393    }
1394
1395    #[test]
1396    fn test_take_record_batch() {
1397        let schema = Arc::new(Schema::new(vec![
1398            Field::new("a", DataType::Int32, true),
1399            Field::new("b", DataType::Utf8, true),
1400        ]));
1401        let batch = RecordBatch::try_new(
1402            schema.clone(),
1403            vec![
1404                Arc::new(Int32Array::from_iter_values(0..20)),
1405                Arc::new(StringArray::from_iter_values(
1406                    (0..20).map(|i| format!("str-{}", i)),
1407                )),
1408            ],
1409        )
1410        .unwrap();
1411        let taken = batch.take(&(vec![1_u32, 5_u32, 10_u32].into())).unwrap();
1412        assert_eq!(
1413            taken,
1414            RecordBatch::try_new(
1415                schema,
1416                vec![
1417                    Arc::new(Int32Array::from(vec![1, 5, 10])),
1418                    Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
1419                ],
1420            )
1421            .unwrap()
1422        )
1423    }
1424
1425    #[test]
1426    fn test_schema_project_by_schema() {
1427        let metadata = [("key".to_string(), "value".to_string())];
1428        let schema = Arc::new(
1429            Schema::new(vec![
1430                Field::new("a", DataType::Int32, true),
1431                Field::new("b", DataType::Utf8, true),
1432            ])
1433            .with_metadata(metadata.clone().into()),
1434        );
1435        let batch = RecordBatch::try_new(
1436            schema,
1437            vec![
1438                Arc::new(Int32Array::from_iter_values(0..20)),
1439                Arc::new(StringArray::from_iter_values(
1440                    (0..20).map(|i| format!("str-{}", i)),
1441                )),
1442            ],
1443        )
1444        .unwrap();
1445
1446        // Empty schema
1447        let empty_schema = Schema::empty();
1448        let empty_projected = batch.project_by_schema(&empty_schema).unwrap();
1449        let expected_schema = empty_schema.with_metadata(metadata.clone().into());
1450        assert_eq!(
1451            empty_projected,
1452            RecordBatch::from(StructArray::new_empty_fields(batch.num_rows(), None))
1453                .with_schema(Arc::new(expected_schema))
1454                .unwrap()
1455        );
1456
1457        // Re-ordered schema
1458        let reordered_schema = Schema::new(vec![
1459            Field::new("b", DataType::Utf8, true),
1460            Field::new("a", DataType::Int32, true),
1461        ]);
1462        let reordered_projected = batch.project_by_schema(&reordered_schema).unwrap();
1463        let expected_schema = Arc::new(reordered_schema.with_metadata(metadata.clone().into()));
1464        assert_eq!(
1465            reordered_projected,
1466            RecordBatch::try_new(
1467                expected_schema,
1468                vec![
1469                    Arc::new(StringArray::from_iter_values(
1470                        (0..20).map(|i| format!("str-{}", i)),
1471                    )),
1472                    Arc::new(Int32Array::from_iter_values(0..20)),
1473                ],
1474            )
1475            .unwrap()
1476        );
1477
1478        // Sub schema
1479        let sub_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1480        let sub_projected = batch.project_by_schema(&sub_schema).unwrap();
1481        let expected_schema = Arc::new(sub_schema.with_metadata(metadata.into()));
1482        assert_eq!(
1483            sub_projected,
1484            RecordBatch::try_new(
1485                expected_schema,
1486                vec![Arc::new(Int32Array::from_iter_values(0..20))],
1487            )
1488            .unwrap()
1489        );
1490    }
1491}