use crate::cast::AsArray;
use crate::{Array, ArrayRef, StructArray, new_empty_array};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
use std::ops::Index;
use std::sync::Arc;

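/// Trait for types that can read `RecordBatch`es, exposed as an `Iterator`
/// of `Result<RecordBatch, ArrowError>` together with the `Schema` the
/// yielded batches conform to.
///
/// For an adapter that builds a reader from an existing iterator and a known
/// schema, see `RecordBatchIterator` below.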
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    fn schema(&self) -> SchemaRef;
}

impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
    fn schema(&self) -> SchemaRef {
        self.as_ref().schema()
    }
}

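/// Trait for types that can write `RecordBatch`es.
///
/// Batches are handed to the writer one at a time via `write`; `close`
/// consumes the writer once all batches have been written, giving the
/// implementation a chance to flush and finalize any underlying output.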
pub trait RecordBatchWriter {
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    fn close(self) -> Result<(), ArrowError>;
}

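/// Creates an `Arc`ed array from a data type ident and a list of values,
/// e.g. `create_array!(Int32, [1, 2, 3])`. `Null` additionally takes a length
/// instead of a value list, and unsupported type idents fail at compile time.
///
/// A minimal usage sketch (the values are illustrative):
///
/// ```
/// use arrow_array::{ArrayRef, create_array};
///
/// // expands to `Arc::new(Int32Array::from(vec![1, 2, 3]))`
/// let array: ArrayRef = create_array!(Int32, [1, 2, 3]);
/// assert_eq!(array.len(), 3);
///
/// // string columns use the `Utf8` ident
/// let strings: ArrayRef = create_array!(Utf8, [Some("a"), None, Some("c")]);
/// assert_eq!(strings.null_count(), 1);
/// ```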
#[macro_export]
macro_rules! create_array {
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}

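/// Creates a `RecordBatch` from `(name, type ident, [values...])` triples,
/// building a schema with all fields nullable and the columns with
/// `create_array!`. Evaluates to `Result<RecordBatch, ArrowError>`.
///
/// A minimal usage sketch (column names and values are illustrative):
///
/// ```
/// use arrow_array::record_batch;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Utf8, ["x", "y", "z"])
/// )
/// .unwrap();
/// assert_eq!(batch.num_rows(), 3);
/// assert_eq!(batch.num_columns(), 2);
/// ```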
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        {
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ]));

            let batch = $crate::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            );

            batch
        }
    }
}

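/// A two-dimensional batch of column-oriented data with a defined schema: a
/// `SchemaRef` plus one equal-length `ArrayRef` column per field, along with
/// the shared row count.
///
/// Cloning is relatively cheap since the schema and column arrays are
/// reference counted.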
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    schema: SchemaRef,
    columns: Vec<Arc<dyn Array>>,

    row_count: usize,
}

impl RecordBatch {
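    /// Creates a `RecordBatch` from a schema and columns, validating that the
    /// number of columns matches the number of fields, that data types match,
    /// that all columns have the same length, and that non-nullable fields
    /// contain no nulls.
    ///
    /// A minimal sketch (field name and values are illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{Int32Array, RecordBatch};
    /// use arrow_schema::{DataType, Field, Schema};
    ///
    /// let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    /// let batch = RecordBatch::try_new(
    ///     schema,
    ///     vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    /// )
    /// .unwrap();
    /// assert_eq!(batch.num_rows(), 3);
    /// ```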
    pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
        let options = RecordBatchOptions::new();
        Self::try_new_impl(schema, columns, &options)
    }

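    /// Creates a `RecordBatch` from a schema, columns and row count without
    /// performing any of the checks done by `RecordBatch::try_new`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the columns are valid for the schema (count,
    /// data types, nullability) and that every column has exactly
    /// `row_count` rows; otherwise methods on the resulting batch may
    /// misbehave.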
    pub unsafe fn new_unchecked(
        schema: SchemaRef,
        columns: Vec<Arc<dyn Array>>,
        row_count: usize,
    ) -> Self {
        Self {
            schema,
            columns,
            row_count,
        }
    }

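    /// Creates a `RecordBatch` like `RecordBatch::try_new`, but with explicit
    /// `RecordBatchOptions`, for example to relax field-name matching or to
    /// supply a row count for a batch with no columns.
    ///
    /// A small sketch of the zero-column case (the row count is illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{RecordBatch, RecordBatchOptions};
    /// use arrow_schema::Schema;
    ///
    /// let options = RecordBatchOptions::new().with_row_count(Some(10));
    /// let batch =
    ///     RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
    /// assert_eq!(batch.num_rows(), 10);
    /// ```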
    pub fn try_new_with_options(
        schema: SchemaRef,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self, ArrowError> {
        Self::try_new_impl(schema, columns, options)
    }

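    /// Creates a `RecordBatch` with zero rows, whose columns are empty arrays
    /// matching the data types in `schema`.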
    pub fn new_empty(schema: SchemaRef) -> Self {
        let columns = schema
            .fields()
            .iter()
            .map(|field| new_empty_array(field.data_type()))
            .collect();

        RecordBatch {
            schema,
            columns,
            row_count: 0,
        }
    }

    fn try_new_impl(
        schema: SchemaRef,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self, ArrowError> {
        if schema.fields().len() != columns.len() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "number of columns({}) must match number of fields({}) in schema",
                columns.len(),
                schema.fields().len(),
            )));
        }

        let row_count = options
            .row_count
            .or_else(|| columns.first().map(|col| col.len()))
            .ok_or_else(|| {
                ArrowError::InvalidArgumentError(
                    "must either specify a row count or at least one column".to_string(),
                )
            })?;

        for (c, f) in columns.iter().zip(&schema.fields) {
            if !f.is_nullable() && c.null_count() > 0 {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "Column '{}' is declared as non-nullable but contains null values",
                    f.name()
                )));
            }
        }

        if columns.iter().any(|c| c.len() != row_count) {
            let err = match options.row_count {
                Some(_) => "all columns in a record batch must have the specified row count",
                None => "all columns in a record batch must have the same length",
            };
            return Err(ArrowError::InvalidArgumentError(err.to_string()));
        }

        let type_not_match = if options.match_field_names {
            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
        } else {
            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
                !col_type.equals_datatype(field_type)
            }
        };

        let not_match = columns
            .iter()
            .zip(schema.fields().iter())
            .map(|(col, field)| (col.data_type(), field.data_type()))
            .enumerate()
            .find(type_not_match);

        if let Some((i, (col_type, field_type))) = not_match {
            return Err(ArrowError::InvalidArgumentError(format!(
                "column types must match schema types, expected {field_type} but found {col_type} at column index {i}"
            )));
        }

        Ok(RecordBatch {
            schema,
            columns,
            row_count,
        })
    }

    pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
        (self.schema, self.columns, self.row_count)
    }

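    /// Overrides the schema of this `RecordBatch`, returning an error if the
    /// new schema is not a superset of the current one (as determined by
    /// `Schema::contains`). Useful for relaxing nullability or attaching
    /// schema metadata.
    ///
    /// A small sketch of relaxing nullability (names and values are illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{Int32Array, RecordBatch};
    /// use arrow_schema::{DataType, Field, Schema};
    ///
    /// let required = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    /// let nullable = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    ///
    /// let batch =
    ///     RecordBatch::try_new(required, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap();
    /// let batch = batch.with_schema(nullable).unwrap();
    /// assert!(batch.schema().field(0).is_nullable());
    /// ```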
    pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
        if !schema.contains(self.schema.as_ref()) {
            return Err(ArrowError::SchemaError(format!(
                "target schema is not superset of current schema target={schema} current={}",
                self.schema
            )));
        }

        Ok(Self {
            schema,
            columns: self.columns,
            row_count: self.row_count,
        })
    }

    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    pub fn schema_ref(&self) -> &SchemaRef {
        &self.schema
    }

    pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
        let schema = Arc::make_mut(&mut self.schema);
        &mut schema.metadata
    }

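    /// Projects the batch onto the columns at `indices`, returning a new
    /// `RecordBatch` containing only those columns and the corresponding
    /// subset of the schema.
    ///
    /// A minimal sketch using the `record_batch!` macro (names and values are illustrative):
    ///
    /// ```
    /// use arrow_array::record_batch;
    ///
    /// let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Utf8, ["x", "y", "z"])).unwrap();
    /// let projected = batch.project(&[1]).unwrap();
    /// assert_eq!(projected.num_columns(), 1);
    /// assert_eq!(projected.schema().field(0).name(), "b");
    /// ```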
    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
        let projected_schema = self.schema.project(indices)?;
        let batch_fields = indices
            .iter()
            .map(|f| {
                self.columns.get(*f).cloned().ok_or_else(|| {
                    ArrowError::SchemaError(format!(
                        "project index {} out of bounds, max field {}",
                        f,
                        self.columns.len()
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;

        unsafe {
            Ok(RecordBatch::new_unchecked(
                SchemaRef::new(projected_schema),
                batch_fields,
                self.row_count,
            ))
        }
    }

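    /// Flattens nested `Struct` columns, up to `max_level` levels deep, by
    /// replacing each struct column with one top-level column per leaf field
    /// and joining the names with `separator` (e.g. a struct column `a` with
    /// child `b` becomes column `a.b`). A `max_level` of `None` or `Some(0)`
    /// flattens all levels.
    ///
    /// A small sketch with a single struct column (names and values are illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
    /// use arrow_schema::{DataType, Field, Fields, Schema};
    ///
    /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", "Dog"]));
    /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![2_i64, 4]));
    ///
    /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
    /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
    ///
    /// let a = Arc::new(StructArray::from(vec![
    ///     (animals_field.clone(), animals.clone()),
    ///     (n_legs_field.clone(), n_legs.clone()),
    /// ]));
    ///
    /// let schema = Schema::new(vec![Field::new(
    ///     "a",
    ///     DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
    ///     false,
    /// )]);
    ///
    /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
    ///     .unwrap()
    ///     .normalize(".", None)
    ///     .unwrap();
    ///
    /// assert_eq!(normalized.schema().field(0).name(), "a.animals");
    /// assert_eq!(normalized.schema().field(1).name(), "a.n_legs");
    /// ```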
    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
        let max_level = match max_level.unwrap_or(usize::MAX) {
            0 => usize::MAX,
            val => val,
        };
        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
            .columns
            .iter()
            .zip(self.schema.fields())
            .rev()
            .map(|(c, f)| {
                let name_vec: Vec<&str> = vec![f.name()];
                (0, c, name_vec, f)
            })
            .collect();
        let mut columns: Vec<ArrayRef> = Vec::new();
        let mut fields: Vec<FieldRef> = Vec::new();

        while let Some((depth, c, name, field_ref)) = stack.pop() {
            match field_ref.data_type() {
                DataType::Struct(ff) if depth < max_level => {
                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
                        let mut name = name.clone();
                        name.push(separator);
                        name.push(fff.name());
                        stack.push((depth + 1, cff, name, fff))
                    }
                }
                _ => {
                    let updated_field = Field::new(
                        name.concat(),
                        field_ref.data_type().clone(),
                        field_ref.is_nullable(),
                    );
                    columns.push(c.clone());
                    fields.push(Arc::new(updated_field));
                }
            }
        }
        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
    }

    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    pub fn num_rows(&self) -> usize {
        self.row_count
    }

    pub fn column(&self, index: usize) -> &ArrayRef {
        &self.columns[index]
    }

    pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
        self.schema()
            .column_with_name(name)
            .map(|(index, _)| &self.columns[index])
    }

    pub fn columns(&self) -> &[ArrayRef] {
        &self.columns[..]
    }

    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
        let mut builder = SchemaBuilder::from(self.schema.as_ref());
        builder.remove(index);
        self.schema = Arc::new(builder.finish());
        self.columns.remove(index)
    }

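    /// Returns a new `RecordBatch` whose columns are `length`-row slices of
    /// this batch's columns starting at `offset`, sharing the same schema.
    ///
    /// # Panics
    ///
    /// Panics if `offset + length` is greater than `self.num_rows()`.
    ///
    /// A minimal sketch (values are illustrative):
    ///
    /// ```
    /// use arrow_array::record_batch;
    ///
    /// let batch = record_batch!(("a", Int32, [1, 2, 3, 4, 5])).unwrap();
    /// let slice = batch.slice(1, 3);
    /// assert_eq!(slice.num_rows(), 3);
    /// ```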
    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
        assert!((offset + length) <= self.num_rows());

        let columns = self
            .columns()
            .iter()
            .map(|column| column.slice(offset, length))
            .collect();

        Self {
            schema: self.schema.clone(),
            columns,
            row_count: length,
        }
    }

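    /// Creates a `RecordBatch` from an iterator of `(field_name, array)`
    /// pairs, marking each field nullable only if its array currently
    /// contains nulls.
    ///
    /// A minimal sketch (names and values are illustrative):
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
    ///
    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    /// let b: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"]));
    ///
    /// let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
    /// assert_eq!(batch.num_columns(), 2);
    /// assert!(!batch.schema().field(0).is_nullable());
    /// ```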
    pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
    where
        I: IntoIterator<Item = (F, ArrayRef)>,
        F: AsRef<str>,
    {
        let iter = value.into_iter().map(|(field_name, array)| {
            let nullable = array.null_count() > 0;
            (field_name, array, nullable)
        });

        Self::try_from_iter_with_nullable(iter)
    }

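    /// Creates a `RecordBatch` from an iterator of
    /// `(field_name, array, nullable)` triples, using the provided flag for
    /// each field's nullability instead of inferring it from the array.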
    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
    where
        I: IntoIterator<Item = (F, ArrayRef, bool)>,
        F: AsRef<str>,
    {
        let iter = value.into_iter();
        let capacity = iter.size_hint().0;
        let mut schema = SchemaBuilder::with_capacity(capacity);
        let mut columns = Vec::with_capacity(capacity);

        for (field_name, array, nullable) in iter {
            let field_name = field_name.as_ref();
            schema.push(Field::new(field_name, array.data_type().clone(), nullable));
            columns.push(array);
        }

        let schema = Arc::new(schema.finish());
        RecordBatch::try_new(schema, columns)
    }

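    /// Returns the total number of bytes of memory occupied by the arrays in
    /// this batch, i.e. the sum of `Array::get_array_memory_size` over all
    /// columns.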
    pub fn get_array_memory_size(&self) -> usize {
        self.columns()
            .iter()
            .map(|array| array.get_array_memory_size())
            .sum()
    }
}

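/// Options that control how a `RecordBatch` is validated when it is created,
/// used by `RecordBatch::try_new_with_options`.
///
/// A small sketch of the builder-style API (values are illustrative):
///
/// ```
/// use arrow_array::RecordBatchOptions;
///
/// let options = RecordBatchOptions::new()
///     .with_match_field_names(false)
///     .with_row_count(Some(20));
/// assert!(!options.match_field_names);
/// assert_eq!(options.row_count, Some(20));
/// ```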
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Match field names of structs and lists. If set to `false`, column and
    /// field data types need only be equal up to field names (compared with
    /// `DataType::equals_datatype`). Defaults to `true`.
    pub match_field_names: bool,

    /// Optional row count; useful for constructing a `RecordBatch` that has
    /// no columns. Defaults to `None`.
    pub row_count: Option<usize>,
}

impl RecordBatchOptions {
    pub fn new() -> Self {
        Self {
            match_field_names: true,
            row_count: None,
        }
    }
    pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
        self.row_count = row_count;
        self
    }
    pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
        self.match_field_names = match_field_names;
        self
    }
}
impl Default for RecordBatchOptions {
    fn default() -> Self {
        Self::new()
    }
}
impl From<StructArray> for RecordBatch {
    fn from(value: StructArray) -> Self {
        let row_count = value.len();
        let (fields, columns, nulls) = value.into_parts();
        assert_eq!(
            nulls.map(|n| n.null_count()).unwrap_or_default(),
            0,
            "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
        );

        RecordBatch {
            schema: Arc::new(Schema::new(fields)),
            row_count,
            columns,
        }
    }
}

impl From<&StructArray> for RecordBatch {
    fn from(struct_array: &StructArray) -> Self {
        struct_array.clone().into()
    }
}

/// Returns the column with the given name.
///
/// # Panics
///
/// Panics if there is no column with the given `name` in the schema.
impl Index<&str> for RecordBatch {
    type Output = ArrayRef;

    fn index(&self, name: &str) -> &Self::Output {
        self.column_by_name(name).unwrap()
    }
}

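/// A `RecordBatchReader` backed by an existing iterator of
/// `Result<RecordBatch, ArrowError>` items and a known schema.
///
/// A small sketch with an empty iterator (the schema is illustrative):
///
/// ```
/// use std::sync::Arc;
/// use arrow_array::{RecordBatchIterator, RecordBatchReader};
/// use arrow_schema::{DataType, Field, Schema};
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let mut reader = RecordBatchIterator::new(std::iter::empty(), schema.clone());
/// assert_eq!(reader.schema(), schema);
/// assert!(reader.next().is_none());
/// ```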
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    inner: I::IntoIter,
    inner_schema: SchemaRef,
}

impl<I> RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    pub fn new(iter: I, schema: SchemaRef) -> Self {
        Self {
            inner: iter.into_iter(),
            inner_schema: schema,
        }
    }
}

impl<I> Iterator for RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<I> RecordBatchReader for RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    fn schema(&self) -> SchemaRef {
        self.inner_schema.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
    };
    use arrow_buffer::{Buffer, ToByteSlice};
    use arrow_data::{ArrayData, ArrayDataBuilder};
    use arrow_schema::Fields;
    use std::collections::HashMap;

    #[test]
    fn create_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        check_batch(record_batch, 5)
    }

    #[test]
    fn create_string_view_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8View, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        assert_eq!(5, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(
            &DataType::Utf8View,
            record_batch.schema().field(1).data_type()
        );
        assert_eq!(5, record_batch.column(0).len());
        assert_eq!(5, record_batch.column(1).len());
    }

    #[test]
    fn byte_size_should_not_regress() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        assert_eq!(record_batch.get_array_memory_size(), 364);
    }

    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
        assert_eq!(num_rows, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
        assert_eq!(num_rows, record_batch.column(0).len());
        assert_eq!(num_rows, record_batch.column(1).len());
    }

    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let expected_schema = schema.clone();

        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        let offset = 2;
        let length = 5;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 5);

        let offset = 2;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 0);

        let offset = 2;
        let length = 10;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice_empty_batch() {
        let schema = Schema::empty();

        let record_batch = RecordBatch::new_empty(Arc::new(schema));

        let offset = 0;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(0, record_batch_slice.schema().fields().len());

        let offset = 1;
        let length = 2;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    #[test]
    fn create_record_batch_try_from_iter() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, false),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    #[test]
    fn create_record_batch_try_from_iter_with_nullable() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        let record_batch =
            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
                .expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    #[test]
    fn create_record_batch_schema_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);

        let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"
        );
    }

    #[test]
    fn create_record_batch_field_name_mismatch() {
        let fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
        ];
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));

        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.into_data())
        .len(2)
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ])))
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));

        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());

        let options = RecordBatchOptions {
            match_field_names: false,
            row_count: None,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }

    #[test]
    fn create_record_batch_record_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
        assert!(batch.is_err());
    }

    #[test]
    fn create_record_batch_from_struct_array() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::from(&struct_array);
        assert_eq!(2, batch.num_columns());
        assert_eq!(4, batch.num_rows());
        assert_eq!(
            struct_array.data_type(),
            &DataType::Struct(batch.schema().fields().clone())
        );
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
    }

    #[test]
    fn record_batch_equality() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_eq!(batch1, batch2);
    }

    #[test]
    fn record_batch_index_access() {
        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();

        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
    }

    #[test]
    fn record_batch_vals_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_column_names_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_column_number_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_row_count_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn normalize_simple() {
        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        let a = Arc::new(StructArray::from(vec![
            (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
            (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
            (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
        ]));

        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
                .expect("valid conversion")
                .normalize(".", Some(0))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("a.animals", animals.clone(), true),
            ("a.n_legs", n_legs.clone(), true),
            ("a.year", year.clone(), true),
            ("month", month.clone(), true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        assert_eq!(expected, normalized);
    }

    #[test]
    fn normalize_nested() {
        let a = Arc::new(Field::new("a", DataType::Int64, true));
        let b = Arc::new(Field::new("b", DataType::Int64, false));
        let c = Arc::new(Field::new("c", DataType::Int64, true));

        let one = Arc::new(Field::new(
            "1",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            false,
        ));
        let two = Arc::new(Field::new(
            "2",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            true,
        ));

        let exclamation = Arc::new(Field::new(
            "!",
            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
            false,
        ));

        let schema = Schema::new(vec![exclamation.clone()]);

        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
        let c_field = Int64Array::from(vec![None, Some(4)]);

        let one_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);
        let two_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);

        let exclamation_field = Arc::new(StructArray::from(vec![
            (one.clone(), Arc::new(one_field) as ArrayRef),
            (two.clone(), Arc::new(two_field) as ArrayRef),
        ]));

        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
                .expect("valid conversion")
                .normalize(".", Some(1))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            (
                "!.1",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                false,
            ),
            (
                "!.2",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                true,
            ),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);
    }

    #[test]
    fn normalize_empty() {
        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
            .normalize(".", Some(0))
            .expect("valid normalization");

        let expected = RecordBatch::new_empty(Arc::new(
            schema.normalize(".", Some(0)).expect("valid normalization"),
        ));

        assert_eq!(expected, normalized);
    }

    #[test]
    fn project() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
                .expect("valid conversion");

        let expected =
            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
    }

    #[test]
    fn project_empty() {
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

        let expected = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions {
                match_field_names: true,
                row_count: Some(3),
            },
        )
        .expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[]).unwrap());
    }

    #[test]
    fn test_no_column_record_batch() {
        let schema = Arc::new(Schema::empty());

        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
        assert!(
            err.to_string()
                .contains("must either specify a row count or at least one column")
        );

        let options = RecordBatchOptions::new().with_row_count(Some(10));

        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
        assert_eq!(ok.num_rows(), 10);

        let a = ok.slice(2, 5);
        assert_eq!(a.num_rows(), 5);

        let b = ok.slice(5, 0);
        assert_eq!(b.num_rows(), 0);

        assert_ne!(a, b);
        assert_eq!(b, RecordBatch::new_empty(schema))
    }

    #[test]
    fn test_nulls_in_non_nullable_field() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let maybe_batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
        );
        assert_eq!(
            "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
            format!("{}", maybe_batch.err().unwrap())
        );
    }
    #[test]
    fn test_record_batch_options() {
        let options = RecordBatchOptions::new()
            .with_match_field_names(false)
            .with_row_count(Some(20));
        assert!(!options.match_field_names);
        assert_eq!(options.row_count.unwrap(), 20)
    }

    #[test]
    #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
    fn test_from_struct() {
        let s = StructArray::from(ArrayData::new_null(
            &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
            2,
        ));
        let _ = RecordBatch::from(s);
    }

    #[test]
    fn test_with_schema() {
        let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let required_schema = Arc::new(required_schema);
        let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let nullable_schema = Arc::new(nullable_schema);

        let batch = RecordBatch::try_new(
            required_schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
        )
        .unwrap();

        let batch = batch.with_schema(nullable_schema.clone()).unwrap();

        batch.clone().with_schema(required_schema).unwrap_err();

        let metadata = vec![("foo".to_string(), "bar".to_string())]
            .into_iter()
            .collect();
        let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
        let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

        batch.with_schema(nullable_schema).unwrap_err();
    }

    #[test]
    fn test_boxed_reader() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let schema = Arc::new(schema);

        let reader = RecordBatchIterator::new(std::iter::empty(), schema);
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        fn get_size(reader: impl RecordBatchReader) -> usize {
            reader.size_hint().0
        }

        let size = get_size(reader);
        assert_eq!(size, 0);
    }

    #[test]
    fn test_remove_column_maintains_schema_metadata() {
        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let bool_array = BooleanArray::from(vec![true, false, false, true, true]);

        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("bool", DataType::Boolean, false),
        ])
        .with_metadata(metadata);

        let mut batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(id_array), Arc::new(bool_array)],
        )
        .unwrap();

        let _removed_column = batch.remove_column(0);
        assert_eq!(batch.schema().metadata().len(), 1);
        assert_eq!(
            batch.schema().metadata().get("foo").unwrap().as_str(),
            "bar"
        );
    }
}